From 8cbbbd950e34d8b161d70cf4348e1e2851c50d9d Mon Sep 17 00:00:00 2001 From: fary86 Date: Fri, 17 Apr 2020 01:12:34 +0800 Subject: [PATCH 001/142] Add cell name to error message --- mindspore/_checkparam.py | 128 ++++++++++++++++-- mindspore/nn/cell.py | 4 + mindspore/nn/dynamic_lr.py | 46 +++---- mindspore/nn/layer/basic.py | 6 +- mindspore/nn/layer/embedding.py | 4 +- mindspore/nn/layer/image.py | 22 +-- mindspore/nn/layer/lstm.py | 4 +- mindspore/nn/layer/pooling.py | 51 +++---- mindspore/nn/metrics/fbeta.py | 4 +- mindspore/nn/metrics/precision.py | 4 +- mindspore/nn/metrics/recall.py | 4 +- mindspore/nn/optim/adam.py | 34 ++--- mindspore/nn/optim/ftrl.py | 38 +++--- mindspore/nn/optim/lamb.py | 34 ++--- mindspore/nn/optim/optimizer.py | 4 +- mindspore/nn/optim/rmsprop.py | 6 +- mindspore/nn/optim/sgd.py | 4 +- mindspore/ops/op_info_register.py | 4 +- mindspore/train/amp.py | 18 +-- mindspore/train/loss_scale_manager.py | 4 +- tests/ut/python/nn/test_dynamic_lr.py | 12 +- tests/ut/python/nn/test_psnr.py | 2 +- tests/ut/python/nn/test_ssim.py | 4 +- tests/ut/python/ops/test_nn_ops.py | 4 +- .../python/pynative_mode/nn/test_pooling.py | 2 +- 25 files changed, 272 insertions(+), 175 deletions(-) diff --git a/mindspore/_checkparam.py b/mindspore/_checkparam.py index cb3dbc0d50..e9a928461f 100644 --- a/mindspore/_checkparam.py +++ b/mindspore/_checkparam.py @@ -17,7 +17,7 @@ import re from enum import Enum from functools import reduce from itertools import repeat -from collections import Iterable +from collections.abc import Iterable import numpy as np from mindspore import log as logger @@ -98,7 +98,7 @@ class Validator: """validator for checking input parameters""" @staticmethod - def check(arg_name, arg_value, value_name, value, rel=Rel.EQ, prim_name=None): + def check(arg_name, arg_value, value_name, value, rel=Rel.EQ, prim_name=None, excp_cls=ValueError): """ Method for judging relation between two int values or list/tuple made up of ints. @@ -108,8 +108,8 @@ class Validator: rel_fn = Rel.get_fns(rel) if not rel_fn(arg_value, value): rel_str = Rel.get_strs(rel).format(f'{value_name}: {value}') - msg_prefix = f'For {prim_name} the' if prim_name else "The" - raise ValueError(f'{msg_prefix} `{arg_name}` should be {rel_str}, but got {arg_value}.') + msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The" + raise excp_cls(f'{msg_prefix} `{arg_name}` should be {rel_str}, but got {arg_value}.') @staticmethod def check_integer(arg_name, arg_value, value, rel, prim_name): @@ -118,8 +118,17 @@ class Validator: type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool) if type_mismatch or not rel_fn(arg_value, value): rel_str = Rel.get_strs(rel).format(value) - raise ValueError(f'For {prim_name} the `{arg_name}` should be an int and must {rel_str},' - f' but got {arg_value}.') + msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The" + raise ValueError(f'{msg_prefix} `{arg_name}` should be an int and must {rel_str}, but got {arg_value}.') + return arg_value + + @staticmethod + def check_number(arg_name, arg_value, value, rel, prim_name): + """Integer value judgment.""" + rel_fn = Rel.get_fns(rel) + if not rel_fn(arg_value, value): + rel_str = Rel.get_strs(rel).format(value) + raise ValueError(f'For \'{prim_name}\' the `{arg_name}` must {rel_str}, but got {arg_value}.') return arg_value @staticmethod @@ -133,9 +142,46 @@ class Validator: f' but got {arg_value}.') return arg_value + @staticmethod + def check_number_range(arg_name, arg_value, lower_limit, upper_limit, rel, prim_name): + """Method for checking whether a numeric value is in some range.""" + rel_fn = Rel.get_fns(rel) + if not rel_fn(arg_value, lower_limit, upper_limit): + rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit) + raise ValueError(f'For \'{prim_name}\' the `{arg_name}` should be in range {rel_str}, but got {arg_value}.') + return arg_value + + @staticmethod + def check_string(arg_name, arg_value, valid_values, prim_name): + """Checks whether a string is in some value list""" + if isinstance(arg_value, str) and arg_value in valid_values: + return arg_value + if len(valid_values) == 1: + raise ValueError(f'For \'{prim_name}\' the `{arg_name}` should be str and must be {valid_values[0]},' + f' but got {arg_value}.') + raise ValueError(f'For \'{prim_name}\' the `{arg_name}` should be str and must be one of {valid_values},' + f' but got {arg_value}.') + + @staticmethod + def check_pad_value_by_mode(pad_mode, padding, prim_name): + """Validates value of padding according to pad_mode""" + if pad_mode != 'pad' and padding != 0: + raise ValueError(f"For '{prim_name}', padding must be zero when pad_mode is '{pad_mode}'.") + return padding + + @staticmethod + def check_float_positive(arg_name, arg_value, prim_name): + """Float type judgment.""" + msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The" + if isinstance(arg_value, float): + if arg_value > 0: + return arg_value + raise ValueError(f"{msg_prefix} `{arg_name}` must be positive, but got {arg_value}.") + raise TypeError(f"{msg_prefix} `{arg_name}` must be float.") + @staticmethod def check_subclass(arg_name, type_, template_type, prim_name): - """Check whether some type is sublcass of another type""" + """Checks whether some type is sublcass of another type""" if not isinstance(template_type, Iterable): template_type = (template_type,) if not any([mstype.issubclass_(type_, x) for x in template_type]): @@ -143,16 +189,44 @@ class Validator: raise TypeError(f'For \'{prim_name}\' the type of `{arg_name}` should be subclass' f' of {",".join((str(x) for x in template_type))}, but got {type_str}.') + @staticmethod + def check_const_input(arg_name, arg_value, prim_name): + """Check valid value.""" + if arg_value is None: + raise ValueError(f'For \'{prim_name}\' the `{arg_name}` must be a const input, but got {arg_value}.') + + @staticmethod + def check_scalar_type_same(args, valid_values, prim_name): + """check whether the types of inputs are the same.""" + def _check_tensor_type(arg): + arg_key, arg_val = arg + elem_type = arg_val + if not elem_type in valid_values: + raise TypeError(f'For \'{prim_name}\' type of `{arg_key}` should be in {valid_values},' + f' but `{arg_key}` is {elem_type}.') + return (arg_key, elem_type) + + def _check_types_same(arg1, arg2): + arg1_name, arg1_type = arg1 + arg2_name, arg2_type = arg2 + if arg1_type != arg2_type: + raise TypeError(f'For \'{prim_name}\' type of `{arg2_name}` should be same as `{arg1_name}`,' + f' but `{arg1_name}` is {arg1_type} and `{arg2_name}` is {arg2_type}.') + return arg1 + + elem_types = map(_check_tensor_type, args.items()) + reduce(_check_types_same, elem_types) + @staticmethod def check_tensor_type_same(args, valid_values, prim_name): - """check whether the element types of input tensors are the same.""" + """Checks whether the element types of input tensors are the same.""" def _check_tensor_type(arg): arg_key, arg_val = arg Validator.check_subclass(arg_key, arg_val, mstype.tensor, prim_name) elem_type = arg_val.element_type() if not elem_type in valid_values: raise TypeError(f'For \'{prim_name}\' element type of `{arg_key}` should be in {valid_values},' - f' but `{arg_key}` is {elem_type}.') + f' but element type of `{arg_key}` is {elem_type}.') return (arg_key, elem_type) def _check_types_same(arg1, arg2): @@ -168,8 +242,13 @@ class Validator: @staticmethod - def check_scalar_or_tensor_type_same(args, valid_values, prim_name): - """check whether the types of inputs are the same. if the input args are tensors, check their element types""" + def check_scalar_or_tensor_type_same(args, valid_values, prim_name, allow_mix=False): + """ + Checks whether the types of inputs are the same. If the input args are tensors, checks their element types. + + If `allow_mix` is True, Tensor(float32) and float32 are type compatible, otherwise an exception will be raised. + """ + def _check_argument_type(arg): arg_key, arg_val = arg if isinstance(arg_val, type(mstype.tensor)): @@ -188,6 +267,9 @@ class Validator: arg2_type = arg2_type.element_type() elif not (isinstance(arg1_type, type(mstype.tensor)) or isinstance(arg2_type, type(mstype.tensor))): pass + elif allow_mix: + arg1_type = arg1_type.element_type() if isinstance(arg1_type, type(mstype.tensor)) else arg1_type + arg2_type = arg2_type.element_type() if isinstance(arg2_type, type(mstype.tensor)) else arg2_type else: excp_flag = True @@ -199,13 +281,14 @@ class Validator: @staticmethod def check_value_type(arg_name, arg_value, valid_types, prim_name): - """Check whether a values is instance of some types.""" + """Checks whether a value is instance of some types.""" + valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,) def raise_error_msg(): """func for raising error message when check failed""" type_names = [t.__name__ for t in valid_types] num_types = len(valid_types) - raise TypeError(f'For \'{prim_name}\' the type of `{arg_name}` should be ' - f'{"one of " if num_types > 1 else ""}' + msg_prefix = f'For \'{prim_name}\' the' if prim_name else 'The' + raise TypeError(f'{msg_prefix} type of `{arg_name}` should be {"one of " if num_types > 1 else ""}' f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.') # Notice: bool is subclass of int, so `check_value_type('x', True, [int])` will check fail, and @@ -216,6 +299,23 @@ class Validator: return arg_value raise_error_msg() + @staticmethod + def check_type_name(arg_name, arg_type, valid_types, prim_name): + """Checks whether a type in some specified types""" + valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,) + def get_typename(t): + return t.__name__ if hasattr(t, '__name__') else str(t) + + if arg_type in valid_types: + return arg_type + type_names = [get_typename(t) for t in valid_types] + msg_prefix = f'For \'{prim_name}\' the' if prim_name else 'The' + if len(valid_types) == 1: + raise ValueError(f'{msg_prefix} type of `{arg_name}` should be {type_names[0]},' + f' but got {get_typename(arg_type)}.') + raise ValueError(f'{msg_prefix} type of `{arg_name}` should be one of {type_names},' + f' but got {get_typename(arg_type)}.') + class ParamValidator: """Parameter validator. NOTICE: this class will be replaced by `class Validator`""" diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py index 088f3f3e57..3fda46e7bb 100755 --- a/mindspore/nn/cell.py +++ b/mindspore/nn/cell.py @@ -103,6 +103,10 @@ class Cell: def parameter_layout_dict(self): return self._parameter_layout_dict + @property + def cls_name(self): + return self.__class__.__name__ + @parameter_layout_dict.setter def parameter_layout_dict(self, value): if not isinstance(value, dict): diff --git a/mindspore/nn/dynamic_lr.py b/mindspore/nn/dynamic_lr.py index cf25f1f50e..0c5a160380 100644 --- a/mindspore/nn/dynamic_lr.py +++ b/mindspore/nn/dynamic_lr.py @@ -15,7 +15,7 @@ """dynamic learning rate""" import math -from mindspore._checkparam import ParamValidator as validator +from mindspore._checkparam import Validator as validator from mindspore._checkparam import Rel @@ -43,16 +43,16 @@ def piecewise_constant_lr(milestone, learning_rates): >>> lr = piecewise_constant_lr(milestone, learning_rates) [0.1, 0.1, 0.05, 0.05, 0.05, 0.01, 0.01, 0.01, 0.01, 0.01] """ - validator.check_type('milestone', milestone, (tuple, list)) - validator.check_type('learning_rates', learning_rates, (tuple, list)) + validator.check_value_type('milestone', milestone, (tuple, list), None) + validator.check_value_type('learning_rates', learning_rates, (tuple, list), None) if len(milestone) != len(learning_rates): raise ValueError('The size of `milestone` must be same with the size of `learning_rates`.') lr = [] last_item = 0 for i, item in enumerate(milestone): - validator.check_integer(f'milestone[{i}]', item, 0, Rel.GT) - validator.check_type(f'learning_rates[{i}]', learning_rates[i], [float]) + validator.check_integer(f'milestone[{i}]', item, 0, Rel.GT, None) + validator.check_value_type(f'learning_rates[{i}]', learning_rates[i], [float], None) if item < last_item: raise ValueError(f'The value of milestone[{i}] must be greater than milestone[{i - 1}]') lr += [learning_rates[i]] * (item - last_item) @@ -62,12 +62,12 @@ def piecewise_constant_lr(milestone, learning_rates): def _check_inputs(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair): - validator.check_integer('total_step', total_step, 0, Rel.GT) - validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT) - validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT) - validator.check_float_positive('learning_rate', learning_rate) - validator.check_float_positive('decay_rate', decay_rate) - validator.check_type('is_stair', is_stair, [bool]) + validator.check_integer('total_step', total_step, 0, Rel.GT, None) + validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT, None) + validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT, None) + validator.check_float_positive('learning_rate', learning_rate, None) + validator.check_float_positive('decay_rate', decay_rate, None) + validator.check_value_type('is_stair', is_stair, [bool], None) def exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair=False): @@ -228,11 +228,11 @@ def cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch): >>> lr = cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch) [0.1, 0.1, 0.05500000000000001, 0.05500000000000001, 0.01, 0.01] """ - validator.check_float_positive('min_lr', min_lr) - validator.check_float_positive('max_lr', max_lr) - validator.check_integer('total_step', total_step, 0, Rel.GT) - validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT) - validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT) + validator.check_float_positive('min_lr', min_lr, None) + validator.check_float_positive('max_lr', max_lr, None) + validator.check_integer('total_step', total_step, 0, Rel.GT, None) + validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT, None) + validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT, None) delta = 0.5 * (max_lr - min_lr) lr = [] @@ -279,13 +279,13 @@ def polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_e >>> lr = polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch, power) [0.1, 0.1, 0.07363961030678928, 0.07363961030678928, 0.01, 0.01] """ - validator.check_float_positive('learning_rate', learning_rate) - validator.check_float_positive('end_learning_rate', end_learning_rate) - validator.check_integer('total_step', total_step, 0, Rel.GT) - validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT) - validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT) - validator.check_type('power', power, [float]) - validator.check_type('update_decay_epoch', update_decay_epoch, [bool]) + validator.check_float_positive('learning_rate', learning_rate, None) + validator.check_float_positive('end_learning_rate', end_learning_rate, None) + validator.check_integer('total_step', total_step, 0, Rel.GT, None) + validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT, None) + validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT, None) + validator.check_value_type('power', power, [float], None) + validator.check_value_type('update_decay_epoch', update_decay_epoch, [bool], None) function = lambda x, y: (x, min(x, y)) if update_decay_epoch: diff --git a/mindspore/nn/layer/basic.py b/mindspore/nn/layer/basic.py index 5ac52acac7..2449eea9b4 100644 --- a/mindspore/nn/layer/basic.py +++ b/mindspore/nn/layer/basic.py @@ -25,7 +25,7 @@ from mindspore.common.parameter import Parameter from mindspore._extends import cell_attr_register from ..cell import Cell from .activation import get_activation -from ..._checkparam import ParamValidator as validator +from ..._checkparam import Validator as validator class Dropout(Cell): @@ -73,7 +73,7 @@ class Dropout(Cell): super(Dropout, self).__init__() if keep_prob <= 0 or keep_prob > 1: raise ValueError("dropout probability should be a number in range (0, 1], but got {}".format(keep_prob)) - validator.check_subclass("dtype", dtype, mstype.number_type) + validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name) self.keep_prob = Tensor(keep_prob) self.seed0 = seed0 self.seed1 = seed1 @@ -421,7 +421,7 @@ class Pad(Cell): super(Pad, self).__init__() self.mode = mode self.paddings = paddings - validator.check_string('mode', self.mode, ["CONSTANT", "REFLECT", "SYMMETRIC"]) + validator.check_string('mode', self.mode, ["CONSTANT", "REFLECT", "SYMMETRIC"], self.cls_name) if not isinstance(paddings, tuple): raise TypeError('Paddings must be tuple type.') for item in paddings: diff --git a/mindspore/nn/layer/embedding.py b/mindspore/nn/layer/embedding.py index dfa8e66469..24b94f2f3c 100755 --- a/mindspore/nn/layer/embedding.py +++ b/mindspore/nn/layer/embedding.py @@ -19,7 +19,7 @@ from mindspore.ops import operations as P from mindspore.common.parameter import Parameter from mindspore.common.initializer import initializer from ..cell import Cell -from ..._checkparam import ParamValidator as validator +from ..._checkparam import Validator as validator class Embedding(Cell): @@ -59,7 +59,7 @@ class Embedding(Cell): """ def __init__(self, vocab_size, embedding_size, use_one_hot=False, embedding_table='normal', dtype=mstype.float32): super(Embedding, self).__init__() - validator.check_subclass("dtype", dtype, mstype.number_type) + validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name) self.vocab_size = vocab_size self.embedding_size = embedding_size self.use_one_hot = use_one_hot diff --git a/mindspore/nn/layer/image.py b/mindspore/nn/layer/image.py index 72c4c6d8e2..b46ac4cd6e 100644 --- a/mindspore/nn/layer/image.py +++ b/mindspore/nn/layer/image.py @@ -19,7 +19,7 @@ from mindspore.common.tensor import Tensor from mindspore.ops import operations as P from mindspore.ops import functional as F from mindspore.ops.primitive import constexpr -from mindspore._checkparam import ParamValidator as validator +from mindspore._checkparam import Validator as validator from mindspore._checkparam import Rel from ..cell import Cell @@ -134,15 +134,15 @@ class SSIM(Cell): """ def __init__(self, max_val=1.0, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03): super(SSIM, self).__init__() - validator.check_type('max_val', max_val, [int, float]) - validator.check('max_val', max_val, '', 0.0, Rel.GT) + validator.check_value_type('max_val', max_val, [int, float], self.cls_name) + validator.check_number('max_val', max_val, 0.0, Rel.GT, self.cls_name) self.max_val = max_val - self.filter_size = validator.check_integer('filter_size', filter_size, 1, Rel.GE) - self.filter_sigma = validator.check_float_positive('filter_sigma', filter_sigma) - validator.check_type('k1', k1, [float]) - self.k1 = validator.check_number_range('k1', k1, 0.0, 1.0, Rel.INC_NEITHER) - validator.check_type('k2', k2, [float]) - self.k2 = validator.check_number_range('k2', k2, 0.0, 1.0, Rel.INC_NEITHER) + self.filter_size = validator.check_integer('filter_size', filter_size, 1, Rel.GE, self.cls_name) + self.filter_sigma = validator.check_float_positive('filter_sigma', filter_sigma, self.cls_name) + validator.check_value_type('k1', k1, [float], self.cls_name) + self.k1 = validator.check_number_range('k1', k1, 0.0, 1.0, Rel.INC_NEITHER, self.cls_name) + validator.check_value_type('k2', k2, [float], self.cls_name) + self.k2 = validator.check_number_range('k2', k2, 0.0, 1.0, Rel.INC_NEITHER, self.cls_name) self.mean = P.DepthwiseConv2dNative(channel_multiplier=1, kernel_size=filter_size) def construct(self, img1, img2): @@ -231,8 +231,8 @@ class PSNR(Cell): """ def __init__(self, max_val=1.0): super(PSNR, self).__init__() - validator.check_type('max_val', max_val, [int, float]) - validator.check('max_val', max_val, '', 0.0, Rel.GT) + validator.check_value_type('max_val', max_val, [int, float], self.cls_name) + validator.check_number('max_val', max_val, 0.0, Rel.GT, self.cls_name) self.max_val = max_val def construct(self, img1, img2): diff --git a/mindspore/nn/layer/lstm.py b/mindspore/nn/layer/lstm.py index cef926d365..84c156a1c2 100755 --- a/mindspore/nn/layer/lstm.py +++ b/mindspore/nn/layer/lstm.py @@ -17,7 +17,7 @@ from mindspore.ops import operations as P from mindspore.nn.cell import Cell from mindspore.common.parameter import Parameter from mindspore.common.initializer import initializer -from mindspore._checkparam import ParamValidator as validator +from mindspore._checkparam import Validator as validator class LSTM(Cell): @@ -114,7 +114,7 @@ class LSTM(Cell): self.hidden_size = hidden_size self.num_layers = num_layers self.has_bias = has_bias - self.batch_first = validator.check_type("batch_first", batch_first, [bool]) + self.batch_first = validator.check_value_type("batch_first", batch_first, [bool], self.cls_name) self.dropout = float(dropout) self.bidirectional = bidirectional diff --git a/mindspore/nn/layer/pooling.py b/mindspore/nn/layer/pooling.py index 746b6d240f..53d97807cf 100644 --- a/mindspore/nn/layer/pooling.py +++ b/mindspore/nn/layer/pooling.py @@ -14,8 +14,7 @@ # ============================================================================ """pooling""" from mindspore.ops import operations as P -from mindspore._checkparam import ParamValidator as validator -from mindspore._checkparam import Rel +from mindspore._checkparam import Validator as validator from ... import context from ..cell import Cell @@ -24,35 +23,27 @@ class _PoolNd(Cell): """N-D AvgPool""" def __init__(self, kernel_size, stride, pad_mode): - name = self.__class__.__name__ super(_PoolNd, self).__init__() - validator.check_type('kernel_size', kernel_size, [int, tuple]) - validator.check_type('stride', stride, [int, tuple]) - self.pad_mode = validator.check_string('pad_mode', pad_mode.upper(), ['VALID', 'SAME']) - - if isinstance(kernel_size, int): - validator.check_integer("kernel_size", kernel_size, 1, Rel.GE) - else: - if (len(kernel_size) != 2 or - (not isinstance(kernel_size[0], int)) or - (not isinstance(kernel_size[1], int)) or - kernel_size[0] <= 0 or - kernel_size[1] <= 0): - raise ValueError(f'The kernel_size passed to cell {name} should be an positive int number or' - f'a tuple of two positive int numbers, but got {kernel_size}') - self.kernel_size = kernel_size - - if isinstance(stride, int): - validator.check_integer("stride", stride, 1, Rel.GE) - else: - if (len(stride) != 2 or - (not isinstance(stride[0], int)) or - (not isinstance(stride[1], int)) or - stride[0] <= 0 or - stride[1] <= 0): - raise ValueError(f'The stride passed to cell {name} should be an positive int number or' - f'a tuple of two positive int numbers, but got {stride}') - self.stride = stride + self.pad_mode = validator.check_string('pad_mode', pad_mode.upper(), ['VALID', 'SAME'], self.cls_name) + + def _check_int_or_tuple(arg_name, arg_value): + validator.check_value_type(arg_name, arg_value, [int, tuple], self.cls_name) + error_msg = f'For \'{self.cls_name}\' the {arg_name} should be an positive int number or ' \ + f'a tuple of two positive int numbers, but got {arg_value}' + if isinstance(arg_value, int): + if arg_value <= 0: + raise ValueError(error_msg) + elif len(arg_value) == 2: + for item in arg_value: + if isinstance(item, int) and item > 0: + continue + raise ValueError(error_msg) + else: + raise ValueError(error_msg) + return arg_value + + self.kernel_size = _check_int_or_tuple('kernel_size', kernel_size) + self.stride = _check_int_or_tuple('stride', stride) def construct(self, *inputs): pass diff --git a/mindspore/nn/metrics/fbeta.py b/mindspore/nn/metrics/fbeta.py index 68df4318b0..3ae5c44bc2 100755 --- a/mindspore/nn/metrics/fbeta.py +++ b/mindspore/nn/metrics/fbeta.py @@ -15,7 +15,7 @@ """Fbeta.""" import sys import numpy as np -from mindspore._checkparam import ParamValidator as validator +from mindspore._checkparam import Validator as validator from .metric import Metric @@ -104,7 +104,7 @@ class Fbeta(Metric): Returns: Float, computed result. """ - validator.check_type("average", average, [bool]) + validator.check_value_type("average", average, [bool], self.__class__.__name__) if self._class_num == 0: raise RuntimeError('Input number of samples can not be 0.') diff --git a/mindspore/nn/metrics/precision.py b/mindspore/nn/metrics/precision.py index ad7b6c576f..633b9f8e2c 100644 --- a/mindspore/nn/metrics/precision.py +++ b/mindspore/nn/metrics/precision.py @@ -17,7 +17,7 @@ import sys import numpy as np -from mindspore._checkparam import ParamValidator as validator +from mindspore._checkparam import Validator as validator from .evaluation import EvaluationBase @@ -136,7 +136,7 @@ class Precision(EvaluationBase): if self._class_num == 0: raise RuntimeError('Input number of samples can not be 0.') - validator.check_type("average", average, [bool]) + validator.check_value_type("average", average, [bool], self.__class__.__name__) result = self._true_positives / (self._positives + self.eps) if average: diff --git a/mindspore/nn/metrics/recall.py b/mindspore/nn/metrics/recall.py index 45ebf0d7db..da06321aa3 100644 --- a/mindspore/nn/metrics/recall.py +++ b/mindspore/nn/metrics/recall.py @@ -17,7 +17,7 @@ import sys import numpy as np -from mindspore._checkparam import ParamValidator as validator +from mindspore._checkparam import Validator as validator from .evaluation import EvaluationBase @@ -136,7 +136,7 @@ class Recall(EvaluationBase): if self._class_num == 0: raise RuntimeError('Input number of samples can not be 0.') - validator.check_type("average", average, [bool]) + validator.check_value_type("average", average, [bool], self.__class__.__name__) result = self._true_positives / (self._actual_positives + self.eps) if average: diff --git a/mindspore/nn/optim/adam.py b/mindspore/nn/optim/adam.py index eb4e33751f..65f8ec678b 100755 --- a/mindspore/nn/optim/adam.py +++ b/mindspore/nn/optim/adam.py @@ -22,7 +22,7 @@ from mindspore.ops import composite as C from mindspore.ops import functional as F from mindspore.common.parameter import Parameter from mindspore.common.tensor import Tensor -from mindspore._checkparam import ParamValidator as validator +from mindspore._checkparam import Validator as validator from mindspore._checkparam import Rel from .optimizer import Optimizer @@ -78,16 +78,16 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, grad return next_v -def _check_param_value(beta1, beta2, eps, weight_decay): +def _check_param_value(beta1, beta2, eps, weight_decay, prim_name): """Check the type of inputs.""" - validator.check_type("beta1", beta1, [float]) - validator.check_type("beta2", beta2, [float]) - validator.check_type("eps", eps, [float]) - validator.check_type("weight_dacay", weight_decay, [float]) - validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER) - validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER) - validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER) - validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT) + validator.check_value_type("beta1", beta1, [float], prim_name) + validator.check_value_type("beta2", beta2, [float], prim_name) + validator.check_value_type("eps", eps, [float], prim_name) + validator.check_value_type("weight_dacay", weight_decay, [float], prim_name) + validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name) + validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name) + validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name) + validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, prim_name) @adam_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor", @@ -168,11 +168,11 @@ class Adam(Optimizer): use_nesterov=False, weight_decay=0.0, loss_scale=1.0, decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name): super(Adam, self).__init__(learning_rate, params, weight_decay, loss_scale, decay_filter) - _check_param_value(beta1, beta2, eps, weight_decay) - validator.check_type("use_locking", use_locking, [bool]) - validator.check_type("use_nesterov", use_nesterov, [bool]) - validator.check_type("loss_scale", loss_scale, [float]) - validator.check_number_range("loss_scale", loss_scale, 1.0, float("inf"), Rel.INC_LEFT) + _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name) + validator.check_value_type("use_locking", use_locking, [bool], self.cls_name) + validator.check_value_type("use_nesterov", use_nesterov, [bool], self.cls_name) + validator.check_value_type("loss_scale", loss_scale, [float], self.cls_name) + validator.check_number_range("loss_scale", loss_scale, 1.0, float("inf"), Rel.INC_LEFT, self.cls_name) self.beta1 = Tensor(beta1, mstype.float32) self.beta2 = Tensor(beta2, mstype.float32) @@ -241,7 +241,7 @@ class AdamWeightDecay(Optimizer): """ def __init__(self, params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-6, weight_decay=0.0): super(AdamWeightDecay, self).__init__(learning_rate, params) - _check_param_value(beta1, beta2, eps, weight_decay) + _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name) self.lr = Tensor(np.array([learning_rate]).astype(np.float32)) self.beta1 = Tensor(np.array([beta1]).astype(np.float32)) self.beta2 = Tensor(np.array([beta2]).astype(np.float32)) @@ -304,7 +304,7 @@ class AdamWeightDecayDynamicLR(Optimizer): eps=1e-6, weight_decay=0.0): super(AdamWeightDecayDynamicLR, self).__init__(learning_rate, params) - _check_param_value(beta1, beta2, eps, weight_decay) + _check_param_value(beta1, beta2, eps, weight_decay, self.cls_name) # turn them to scalar when me support scalar/tensor mix operations self.global_step = Parameter(initializer(0, [1]), name="global_step") diff --git a/mindspore/nn/optim/ftrl.py b/mindspore/nn/optim/ftrl.py index ee8fc9355f..d08dd6cf4c 100644 --- a/mindspore/nn/optim/ftrl.py +++ b/mindspore/nn/optim/ftrl.py @@ -18,7 +18,7 @@ from mindspore.common.initializer import initializer from mindspore.common.parameter import Parameter from mindspore.common import Tensor import mindspore.common.dtype as mstype -from mindspore._checkparam import ParamValidator as validator +from mindspore._checkparam import Validator as validator from mindspore._checkparam import Rel from .optimizer import Optimizer, apply_decay, grad_scale @@ -30,29 +30,30 @@ def _tensor_run_opt(opt, learning_rate, l1, l2, lr_power, linear, gradient, weig success = F.depend(success, opt(weight, moment, linear, gradient, learning_rate, l1, l2, lr_power)) return success -def _check_param(initial_accum, learning_rate, lr_power, l1, l2, use_locking, loss_scale=1.0, weight_decay=0.0): - validator.check_type("initial_accum", initial_accum, [float]) - validator.check("initial_accum", initial_accum, "", 0.0, Rel.GE) +def _check_param(initial_accum, learning_rate, lr_power, l1, l2, use_locking, loss_scale=1.0, weight_decay=0.0, + prim_name=None): + validator.check_value_type("initial_accum", initial_accum, [float], prim_name) + validator.check_number("initial_accum", initial_accum, 0.0, Rel.GE, prim_name) - validator.check_type("learning_rate", learning_rate, [float]) - validator.check("learning_rate", learning_rate, "", 0.0, Rel.GT) + validator.check_value_type("learning_rate", learning_rate, [float], prim_name) + validator.check_number("learning_rate", learning_rate, 0.0, Rel.GT, prim_name) - validator.check_type("lr_power", lr_power, [float]) - validator.check("lr_power", lr_power, "", 0.0, Rel.LE) + validator.check_value_type("lr_power", lr_power, [float], prim_name) + validator.check_number("lr_power", lr_power, 0.0, Rel.LE, prim_name) - validator.check_type("l1", l1, [float]) - validator.check("l1", l1, "", 0.0, Rel.GE) + validator.check_value_type("l1", l1, [float], prim_name) + validator.check_number("l1", l1, 0.0, Rel.GE, prim_name) - validator.check_type("l2", l2, [float]) - validator.check("l2", l2, "", 0.0, Rel.GE) + validator.check_value_type("l2", l2, [float], prim_name) + validator.check_number("l2", l2, 0.0, Rel.GE, prim_name) - validator.check_type("use_locking", use_locking, [bool]) + validator.check_value_type("use_locking", use_locking, [bool], prim_name) - validator.check_type("loss_scale", loss_scale, [float]) - validator.check("loss_scale", loss_scale, "", 1.0, Rel.GE) + validator.check_value_type("loss_scale", loss_scale, [float], prim_name) + validator.check_number("loss_scale", loss_scale, 1.0, Rel.GE, prim_name) - validator.check_type("weight_decay", weight_decay, [float]) - validator.check("weight_decay", weight_decay, "", 0.0, Rel.GE) + validator.check_value_type("weight_decay", weight_decay, [float], prim_name) + validator.check_number("weight_decay", weight_decay, 0.0, Rel.GE, prim_name) class FTRL(Optimizer): @@ -94,7 +95,8 @@ class FTRL(Optimizer): use_locking=False, loss_scale=1.0, weight_decay=0.0): super(FTRL, self).__init__(learning_rate, params) - _check_param(initial_accum, learning_rate, lr_power, l1, l2, use_locking, loss_scale, weight_decay) + _check_param(initial_accum, learning_rate, lr_power, l1, l2, use_locking, loss_scale, weight_decay, + self.cls_name) self.moments = self.parameters.clone(prefix="moments", init=initial_accum) self.linear = self.parameters.clone(prefix="linear", init='zeros') self.l1 = l1 diff --git a/mindspore/nn/optim/lamb.py b/mindspore/nn/optim/lamb.py index e74d6fc6a8..afcbf8cda4 100755 --- a/mindspore/nn/optim/lamb.py +++ b/mindspore/nn/optim/lamb.py @@ -21,7 +21,7 @@ from mindspore.ops import composite as C from mindspore.ops import functional as F from mindspore.common.parameter import Parameter from mindspore.common.tensor import Tensor -from mindspore._checkparam import ParamValidator as validator +from mindspore._checkparam import Validator as validator from mindspore._checkparam import Rel from .optimizer import Optimizer from .. import layer @@ -109,23 +109,23 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, para def _check_param_value(decay_steps, warmup_steps, start_learning_rate, - end_learning_rate, power, beta1, beta2, eps, weight_decay): + end_learning_rate, power, beta1, beta2, eps, weight_decay, prim_name): """Check the type of inputs.""" - validator.check_type("decay_steps", decay_steps, [int]) - validator.check_type("warmup_steps", warmup_steps, [int]) - validator.check_type("start_learning_rate", start_learning_rate, [float]) - validator.check_type("end_learning_rate", end_learning_rate, [float]) - validator.check_type("power", power, [float]) - validator.check_type("beta1", beta1, [float]) - validator.check_type("beta2", beta2, [float]) - validator.check_type("eps", eps, [float]) - validator.check_type("weight_dacay", weight_decay, [float]) - validator.check_number_range("decay_steps", decay_steps, 1, float("inf"), Rel.INC_LEFT) - validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER) - validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER) - validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER) - validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT) + validator.check_value_type("decay_steps", decay_steps, [int], prim_name) + validator.check_value_type("warmup_steps", warmup_steps, [int], prim_name) + validator.check_value_type("start_learning_rate", start_learning_rate, [float], prim_name) + validator.check_value_type("end_learning_rate", end_learning_rate, [float], prim_name) + validator.check_value_type("power", power, [float], prim_name) + validator.check_value_type("beta1", beta1, [float], prim_name) + validator.check_value_type("beta2", beta2, [float], prim_name) + validator.check_value_type("eps", eps, [float], prim_name) + validator.check_value_type("weight_dacay", weight_decay, [float], prim_name) + validator.check_number_range("decay_steps", decay_steps, 1, float("inf"), Rel.INC_LEFT, prim_name) + validator.check_number_range("beta1", beta1, 0.0, 1.0, Rel.INC_NEITHER, prim_name) + validator.check_number_range("beta2", beta2, 0.0, 1.0, Rel.INC_NEITHER, prim_name) + validator.check_number_range("eps", eps, 0.0, float("inf"), Rel.INC_NEITHER, prim_name) + validator.check_number_range("weight_decay", weight_decay, 0.0, float("inf"), Rel.INC_LEFT, prim_name) class Lamb(Optimizer): @@ -182,7 +182,7 @@ class Lamb(Optimizer): super(Lamb, self).__init__(start_learning_rate, params) _check_param_value(decay_steps, warmup_steps, start_learning_rate, end_learning_rate, - power, beta1, beta2, eps, weight_decay) + power, beta1, beta2, eps, weight_decay, self.cls_name) # turn them to scalar when me support scalar/tensor mix operations self.global_step = Parameter(initializer(0, [1]), name="global_step") diff --git a/mindspore/nn/optim/optimizer.py b/mindspore/nn/optim/optimizer.py index 6c6d14ed7a..95e6ea7933 100755 --- a/mindspore/nn/optim/optimizer.py +++ b/mindspore/nn/optim/optimizer.py @@ -22,7 +22,7 @@ from mindspore.ops import functional as F, composite as C, operations as P from mindspore.nn.cell import Cell from mindspore.common.parameter import Parameter, ParameterTuple from mindspore.common.initializer import initializer -from mindspore._checkparam import ParamValidator as validator +from mindspore._checkparam import Validator as validator from mindspore._checkparam import Rel from mindspore.common.tensor import Tensor from mindspore import log as logger @@ -63,7 +63,7 @@ class Optimizer(Cell): self.gather = None self.assignadd = None self.global_step = None - validator.check_number_range("learning rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT) + validator.check_number_range("learning rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name) else: self.dynamic_lr = True self.gather = P.GatherV2() diff --git a/mindspore/nn/optim/rmsprop.py b/mindspore/nn/optim/rmsprop.py index a68dc6f7c4..97d7538a26 100644 --- a/mindspore/nn/optim/rmsprop.py +++ b/mindspore/nn/optim/rmsprop.py @@ -14,7 +14,7 @@ # ============================================================================ """rmsprop""" from mindspore.ops import functional as F, composite as C, operations as P -from mindspore._checkparam import ParamValidator as validator +from mindspore._checkparam import Validator as validator from .optimizer import Optimizer rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt") @@ -144,8 +144,8 @@ class RMSProp(Optimizer): self.decay = decay self.epsilon = epsilon - validator.check_type("use_locking", use_locking, [bool]) - validator.check_type("centered", centered, [bool]) + validator.check_value_type("use_locking", use_locking, [bool], self.cls_name) + validator.check_value_type("centered", centered, [bool], self.cls_name) self.centered = centered if centered: self.opt = P.ApplyCenteredRMSProp(use_locking) diff --git a/mindspore/nn/optim/sgd.py b/mindspore/nn/optim/sgd.py index 983be4bf80..db0775e023 100755 --- a/mindspore/nn/optim/sgd.py +++ b/mindspore/nn/optim/sgd.py @@ -15,7 +15,7 @@ """sgd""" from mindspore.ops import functional as F, composite as C, operations as P from mindspore.common.parameter import Parameter -from mindspore._checkparam import ParamValidator as validator +from mindspore._checkparam import Validator as validator from .optimizer import Optimizer sgd_opt = C.MultitypeFuncGraph("sgd_opt") @@ -100,7 +100,7 @@ class SGD(Optimizer): raise ValueError("dampening should be at least 0.0, but got dampening {}".format(dampening)) self.dampening = dampening - validator.check_type("nesterov", nesterov, [bool]) + validator.check_value_type("nesterov", nesterov, [bool], self.cls_name) self.nesterov = nesterov self.opt = P.SGD(dampening, weight_decay, nesterov) diff --git a/mindspore/ops/op_info_register.py b/mindspore/ops/op_info_register.py index 7dd7a9f729..90b6e1aadd 100644 --- a/mindspore/ops/op_info_register.py +++ b/mindspore/ops/op_info_register.py @@ -19,7 +19,7 @@ import os import json import inspect from mindspore._c_expression import Oplib -from mindspore._checkparam import ParamValidator as validator +from mindspore._checkparam import Validator as validator # path of built-in op info register. BUILT_IN_OPS_REGISTER_PATH = "mindspore/ops/_op_impl" @@ -43,7 +43,7 @@ def op_info_register(op_info): op_info_real = json.dumps(op_info) else: op_info_real = op_info - validator.check_type("op_info", op_info_real, [str]) + validator.check_value_type("op_info", op_info_real, [str], None) op_lib = Oplib() file_path = os.path.realpath(inspect.getfile(func)) # keep the path custom ops implementation. diff --git a/mindspore/train/amp.py b/mindspore/train/amp.py index c4c115ef27..66e08874b2 100644 --- a/mindspore/train/amp.py +++ b/mindspore/train/amp.py @@ -16,7 +16,7 @@ from easydict import EasyDict as edict from .. import nn -from .._checkparam import ParamValidator as validator +from .._checkparam import Validator as validator from .._checkparam import Rel from ..common import dtype as mstype from ..nn.wrap.cell_wrapper import _VirtualDatasetCell @@ -73,14 +73,14 @@ def _check_kwargs(key_words): raise ValueError(f"Unsupported arg '{arg}'") if 'cast_model_type' in key_words: - validator.check('cast_model_type', key_words['cast_model_type'], - [mstype.float16, mstype.float32], Rel.IN) + validator.check_type_name('cast_model_type', key_words['cast_model_type'], + [mstype.float16, mstype.float32], None) if 'keep_batchnorm_fp32' in key_words: - validator.check_isinstance('keep_batchnorm_fp32', key_words['keep_batchnorm_fp32'], bool) + validator.check_value_type('keep_batchnorm_fp32', key_words['keep_batchnorm_fp32'], bool, None) if 'loss_scale_manager' in key_words: loss_scale_manager = key_words['loss_scale_manager'] if loss_scale_manager: - validator.check_isinstance('loss_scale_manager', loss_scale_manager, LossScaleManager) + validator.check_value_type('loss_scale_manager', loss_scale_manager, LossScaleManager, None) def _add_loss_network(network, loss_fn, cast_model_type): @@ -97,7 +97,7 @@ def _add_loss_network(network, loss_fn, cast_model_type): label = _mp_cast_helper(mstype.float32, label) return self._loss_fn(F.cast(out, mstype.float32), label) - validator.check_isinstance('loss_fn', loss_fn, nn.Cell) + validator.check_value_type('loss_fn', loss_fn, nn.Cell, None) if cast_model_type == mstype.float16: network = WithLossCell(network, loss_fn) else: @@ -126,9 +126,9 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs): loss_scale_manager (Union[None, LossScaleManager]): If None, not scale the loss, or else scale the loss by LossScaleManager. If set, overwrite the level setting. """ - validator.check_isinstance('network', network, nn.Cell) - validator.check_isinstance('optimizer', optimizer, nn.Optimizer) - validator.check('level', level, "", ['O0', 'O2'], Rel.IN) + validator.check_value_type('network', network, nn.Cell, None) + validator.check_value_type('optimizer', optimizer, nn.Optimizer, None) + validator.check('level', level, "", ['O0', 'O2'], Rel.IN, None) _check_kwargs(kwargs) config = dict(_config_level[level], **kwargs) config = edict(config) diff --git a/mindspore/train/loss_scale_manager.py b/mindspore/train/loss_scale_manager.py index 5650c58f62..c8c28a72cb 100644 --- a/mindspore/train/loss_scale_manager.py +++ b/mindspore/train/loss_scale_manager.py @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================ """Loss scale manager abstract class.""" -from .._checkparam import ParamValidator as validator +from .._checkparam import Validator as validator from .._checkparam import Rel from .. import nn @@ -97,7 +97,7 @@ class DynamicLossScaleManager(LossScaleManager): if init_loss_scale < 1.0: raise ValueError("Loss scale value should be > 1") self.loss_scale = init_loss_scale - validator.check_integer("scale_window", scale_window, 0, Rel.GT) + validator.check_integer("scale_window", scale_window, 0, Rel.GT, self.__class__.__name__) self.scale_window = scale_window if scale_factor <= 0: raise ValueError("Scale factor should be > 1") diff --git a/tests/ut/python/nn/test_dynamic_lr.py b/tests/ut/python/nn/test_dynamic_lr.py index cb959956d6..96f9d5afde 100644 --- a/tests/ut/python/nn/test_dynamic_lr.py +++ b/tests/ut/python/nn/test_dynamic_lr.py @@ -32,7 +32,7 @@ power = 0.5 class TestInputs: def test_milestone1(self): milestone1 = 1 - with pytest.raises(ValueError): + with pytest.raises(TypeError): dr.piecewise_constant_lr(milestone1, learning_rates) def test_milestone2(self): @@ -46,12 +46,12 @@ class TestInputs: def test_learning_rates1(self): lr = True - with pytest.raises(ValueError): + with pytest.raises(TypeError): dr.piecewise_constant_lr(milestone, lr) def test_learning_rates2(self): lr = [1, 2, 1] - with pytest.raises(ValueError): + with pytest.raises(TypeError): dr.piecewise_constant_lr(milestone, lr) def test_learning_rate_type(self): @@ -158,7 +158,7 @@ class TestInputs: def test_is_stair(self): is_stair = 1 - with pytest.raises(ValueError): + with pytest.raises(TypeError): dr.exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch, is_stair) def test_min_lr_type(self): @@ -183,12 +183,12 @@ class TestInputs: def test_power(self): power1 = True - with pytest.raises(ValueError): + with pytest.raises(TypeError): dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch, power1) def test_update_decay_epoch(self): update_decay_epoch = 1 - with pytest.raises(ValueError): + with pytest.raises(TypeError): dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch, power, update_decay_epoch) diff --git a/tests/ut/python/nn/test_psnr.py b/tests/ut/python/nn/test_psnr.py index 5a908b308d..32e7b570aa 100644 --- a/tests/ut/python/nn/test_psnr.py +++ b/tests/ut/python/nn/test_psnr.py @@ -52,7 +52,7 @@ def test_psnr_max_val_negative(): def test_psnr_max_val_bool(): max_val = True - with pytest.raises(ValueError): + with pytest.raises(TypeError): net = PSNRNet(max_val) def test_psnr_max_val_zero(): diff --git a/tests/ut/python/nn/test_ssim.py b/tests/ut/python/nn/test_ssim.py index a698b59f69..cf946a1617 100644 --- a/tests/ut/python/nn/test_ssim.py +++ b/tests/ut/python/nn/test_ssim.py @@ -51,7 +51,7 @@ def test_ssim_max_val_negative(): def test_ssim_max_val_bool(): max_val = True - with pytest.raises(ValueError): + with pytest.raises(TypeError): net = SSIMNet(max_val) def test_ssim_max_val_zero(): @@ -92,4 +92,4 @@ def test_ssim_k1_k2_wrong_value(): with pytest.raises(ValueError): net = SSIMNet(k2=0.0) with pytest.raises(ValueError): - net = SSIMNet(k2=-1.0) \ No newline at end of file + net = SSIMNet(k2=-1.0) diff --git a/tests/ut/python/ops/test_nn_ops.py b/tests/ut/python/ops/test_nn_ops.py index 7364893503..d28852ed8e 100644 --- a/tests/ut/python/ops/test_nn_ops.py +++ b/tests/ut/python/ops/test_nn_ops.py @@ -577,14 +577,14 @@ test_cases_for_verify_exception = [ ('MaxPool2d_ValueError_2', { 'block': ( lambda _: nn.MaxPool2d(kernel_size=120, stride=True, pad_mode="valid"), - {'exception': ValueError}, + {'exception': TypeError}, ), 'desc_inputs': [Tensor(np.random.randn(32, 3, 112, 112).astype(np.float32).transpose(0, 3, 1, 2))], }), ('MaxPool2d_ValueError_3', { 'block': ( lambda _: nn.MaxPool2d(kernel_size=3, stride=True, pad_mode="valid"), - {'exception': ValueError}, + {'exception': TypeError}, ), 'desc_inputs': [Tensor(np.random.randn(32, 3, 112, 112).astype(np.float32).transpose(0, 3, 1, 2))], }), diff --git a/tests/ut/python/pynative_mode/nn/test_pooling.py b/tests/ut/python/pynative_mode/nn/test_pooling.py index bb1822f8a8..f8df3ada3f 100644 --- a/tests/ut/python/pynative_mode/nn/test_pooling.py +++ b/tests/ut/python/pynative_mode/nn/test_pooling.py @@ -38,7 +38,7 @@ def test_avgpool2d_error_input(): """ test_avgpool2d_error_input """ kernel_size = 5 stride = 2.3 - with pytest.raises(ValueError): + with pytest.raises(TypeError): nn.AvgPool2d(kernel_size, stride) From b1f5e44cd4e386724e65506a580c8786faa826e2 Mon Sep 17 00:00:00 2001 From: Chong Date: Thu, 16 Apr 2020 10:22:39 +0200 Subject: [PATCH 002/142] improve parser --- .../auto_parallel/rec_core/rec_parse_graph.cc | 307 +++++------------- .../auto_parallel/rec_core/rec_parse_graph.h | 37 +-- .../ccsrc/parallel/step_auto_parallel.cc | 4 +- 3 files changed, 88 insertions(+), 260 deletions(-) diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc index 3ff3473298..44d3642b9c 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc +++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc @@ -39,304 +39,149 @@ const TensorParam MakeTensor(int n, int c, int h, int w) { return tensor; } -bool IsInList(const std::string& name, const std::vector& list) { - return std::find(list.begin(), list.end(), name) != list.end(); -} - Graph::NodeType MakeNewOperator(std::vector> ops, size_t iter_ops) { Graph::NodeType NewOp; - NewOp.name = ops[iter_ops]->cnode_name(); + NewOp.name = ops[iter_ops]->name(); NewOp.info = InfoType::kApplication; auto op_type = ops[iter_ops]->type(); auto idx = DictOpType.find(op_type); if (idx == DictOpType.end()) { NewOp.apply.op_type = OperatorType::kRecUnkownType; - MS_LOG(INFO) << "Unknown type in rec_parse_graph::MakeNewOperator"; + MS_LOG(INFO) << "Unknown operator type."; } else { NewOp.apply.op_type = DictOpType.at(op_type); } - if ((NewOp.apply.op_type == OperatorType::kRecMatMul) || (NewOp.apply.op_type == OperatorType::kRecBiasAdd) || - (NewOp.apply.op_type == OperatorType::kRecReshape)) { - NewOp.tensor_parm = MakeTensor(1, 1, ops[iter_ops]->outputs_tensor_info()[0].shape()[0], - ops[iter_ops]->outputs_tensor_info()[0].shape()[1]); - } else if ((NewOp.apply.op_type == OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits) || - (NewOp.apply.op_type == OperatorType::kRecUnkownType)) { - NewOp.tensor_parm = MakeTensor(1, 1, 1, 1); - } else { + if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 4) { NewOp.tensor_parm = MakeTensor( ops[iter_ops]->outputs_tensor_info()[0].shape()[0], ops[iter_ops]->outputs_tensor_info()[0].shape()[1], ops[iter_ops]->outputs_tensor_info()[0].shape()[2], ops[iter_ops]->outputs_tensor_info()[0].shape()[3]); + } else if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 2) { + NewOp.tensor_parm = Fill2DTensor(ops, iter_ops, NewOp); + } else if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 1) { + NewOp.tensor_parm = MakeTensor(1, 1, 1, ops[iter_ops]->outputs_tensor_info()[0].shape()[0]); + } else if (ops[iter_ops]->outputs_tensor_info()[0].shape().size() == 0) { + NewOp.tensor_parm = MakeTensor(1, 1, 1, 1); + } else { + MS_LOG(ERROR) << "Tensor's shape is unknown."; } + NewOp.apply = CompleteOperatorInputs(ops, iter_ops, NewOp); return NewOp; } -Graph::NodeType MakeNewTensor(std::vector> ops, const size_t iter_ops, - const std::string& input, const size_t iter_input_tensors, std::shared_ptr graph, - size_t current_op_index) { - Graph::NodeType NewTensor; - NewTensor.name = input; - NewTensor.info = InfoType::kConstant; - - if (ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape().size() == 4) { - NewTensor.tensor_parm = MakeTensor(ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[0], - ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[1], - ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[2], - ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[3]); - } else if (ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape().size() == 2) { - Fill2DTensor(ops, iter_ops, graph, iter_input_tensors, current_op_index, NewTensor); - } else if (ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape().size() == 1) { - NewTensor.tensor_parm = MakeTensor(1, 1, 1, ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[0]); - } else if (ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape().size() == 0) { - NewTensor.tensor_parm = MakeTensor(1, 1, 1, 1); - } else { - MS_LOG(ERROR) << "Tensor's shape unknown in rec_parse_graph::MakeNewTensor"; - } - return NewTensor; -} - -void Fill2DTensor(const std::vector>& ops, const size_t iter_ops, - const std::shared_ptr graph, const size_t iter_input_tensors, const size_t current_op_index, - Graph::NodeType NewTensor) { - if (graph->nodes[current_op_index].apply.op_type == OperatorType::kRecMatMul) { +TensorParam Fill2DTensor(const std::vector>& ops, const size_t iter_ops, + Graph::NodeType NewTensor) { + if (NewTensor.apply.op_type == OperatorType::kRecMatMul) { auto attrs = ops[iter_ops]->attrs(); bool transpose_a = attrs[TRANSPOSE_A]->cast()->value(); bool transpose_b = attrs[TRANSPOSE_B]->cast()->value(); - if (transpose_a && (iter_input_tensors == 0)) { - NewTensor.tensor_parm = MakeTensor(1, 1, ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[1], - ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[0]); - } else if (transpose_b && (iter_input_tensors == 1)) { - NewTensor.tensor_parm = MakeTensor(1, 1, ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[1], - ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[0]); + if (transpose_a) { + NewTensor.tensor_parm = MakeTensor(1, 1, ops[iter_ops]->inputs_tensor_info()[0].shape()[1], + ops[iter_ops]->inputs_tensor_info()[0].shape()[0]); + } else if (transpose_b) { + NewTensor.tensor_parm = MakeTensor(1, 1, ops[iter_ops]->inputs_tensor_info()[0].shape()[1], + ops[iter_ops]->inputs_tensor_info()[0].shape()[0]); } else { - NewTensor.tensor_parm = MakeTensor(1, 1, ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[0], - ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[1]); + NewTensor.tensor_parm = MakeTensor(1, 1, ops[iter_ops]->inputs_tensor_info()[0].shape()[0], + ops[iter_ops]->inputs_tensor_info()[0].shape()[1]); } } else { - NewTensor.tensor_parm = MakeTensor(1, 1, ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[0], - ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[1]); + NewTensor.tensor_parm = MakeTensor(1, 1, ops[iter_ops]->inputs_tensor_info()[0].shape()[0], + ops[iter_ops]->inputs_tensor_info()[0].shape()[1]); } + return NewTensor.tensor_parm; } -void CompleteOperatorInputs(std::vector> ops, size_t iter_ops, size_t iter_input_tensors, - size_t current_op_index, std::shared_ptr graph) { - if (ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape().size() == 4) { - graph->nodes[current_op_index].apply.arguments[iter_input_tensors] = - MakeTensor(ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[0], - ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[1], - ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[2], - ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[3]); - } else if (ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape().size() == 2) { - Complete2DInputs(ops, iter_ops, graph, iter_input_tensors, current_op_index); - } else if (ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape().size() == 1) { - graph->nodes[current_op_index].apply.arguments[iter_input_tensors] = - MakeTensor(1, 1, 1, ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[0]); - } else if (ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape().size() == 0) { - graph->nodes[current_op_index].apply.arguments[iter_input_tensors] = MakeTensor(1, 1, 1, 1); - } else { - MS_LOG(ERROR) << "Tensor's shape unknown in rec_parse_graph::MakeNewTensor"; +OperatorRec CompleteOperatorInputs(const std::vector>& ops, const size_t iter_ops, + Graph::NodeType NewTensor) { + for (size_t iter_input_tensors = 0; iter_input_tensors < ops[iter_ops]->inputs_tensor_info().size(); + iter_input_tensors++) { + if (ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape().size() == 4) { + NewTensor.apply.arguments[iter_input_tensors] = + MakeTensor(ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[0], + ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[1], + ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[2], + ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[3]); + } else if (ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape().size() == 2) { + NewTensor.apply.arguments[iter_input_tensors] = Complete2DInputs(ops, iter_ops, iter_input_tensors, NewTensor); + } else if (ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape().size() == 1) { + NewTensor.apply.arguments[iter_input_tensors] = + MakeTensor(1, 1, 1, ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[0]); + } else if (ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape().size() == 0) { + NewTensor.apply.arguments[iter_input_tensors] = MakeTensor(1, 1, 1, 1); + } else { + MS_LOG(ERROR) << "Tensor's shape is unknown."; + } } + return NewTensor.apply; } -void Complete2DInputs(const std::vector>& ops, const size_t iter_ops, - const std::shared_ptr graph, const size_t iter_input_tensors, - const size_t current_op_index) { - if (graph->nodes[current_op_index].apply.op_type == OperatorType::kRecMatMul) { +TensorParam Complete2DInputs(const std::vector>& ops, const size_t iter_ops, + const size_t iter_input_tensors, Graph::NodeType NewTensor) { + if (NewTensor.apply.op_type == OperatorType::kRecMatMul) { auto attrs = ops[iter_ops]->attrs(); bool transpose_a = attrs[TRANSPOSE_A]->cast()->value(); bool transpose_b = attrs[TRANSPOSE_B]->cast()->value(); if (transpose_a && (iter_input_tensors == 0)) { - graph->nodes[current_op_index].apply.arguments[iter_input_tensors] = + NewTensor.apply.arguments[iter_input_tensors] = MakeTensor(1, 1, ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[1], ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[0]); } else if (transpose_b && (iter_input_tensors == 1)) { - graph->nodes[current_op_index].apply.arguments[iter_input_tensors] = + NewTensor.apply.arguments[iter_input_tensors] = MakeTensor(1, 1, ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[1], ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[0]); } else { - graph->nodes[current_op_index].apply.arguments[iter_input_tensors] = + NewTensor.apply.arguments[iter_input_tensors] = MakeTensor(1, 1, ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[0], ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[1]); } } else { - graph->nodes[current_op_index].apply.arguments[iter_input_tensors] = + NewTensor.apply.arguments[iter_input_tensors] = MakeTensor(1, 1, ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[0], ops[iter_ops]->inputs_tensor_info()[iter_input_tensors].shape()[1]); } -} - -void MakeEdge(std::shared_ptr graph, const size_t input_index, const size_t current_op_index) { - graph->nodes[input_index].node_out.push_back(current_op_index); - graph->nodes[current_op_index].node_in.push_back(input_index); -} - -void ModifyTensorToOperator(std::shared_ptr graph, const size_t current_op_index, const size_t iter_ops, - std::vector> ops) { - graph->nodes[current_op_index].info = InfoType::kApplication; - std::string op_type = ops[iter_ops]->type(); - auto idx = DictOpType.find(op_type); - if (idx == DictOpType.end()) { - graph->nodes[current_op_index].apply.op_type = OperatorType::kRecUnkownType; - MS_LOG(INFO) << "Unknown type in rec_parse_graph::ModifyTensorToOperator"; - } else { - graph->nodes[current_op_index].apply.op_type = DictOpType.at(op_type); - } - - if ((graph->nodes[current_op_index].apply.op_type == OperatorType::kRecMatMul) || - (graph->nodes[current_op_index].apply.op_type == OperatorType::kRecBiasAdd) || - (graph->nodes[current_op_index].apply.op_type == OperatorType::kRecReshape)) { - graph->nodes[current_op_index].tensor_parm = MakeTensor(1, 1, ops[iter_ops]->outputs_tensor_info()[0].shape()[0], - ops[iter_ops]->outputs_tensor_info()[0].shape()[1]); - } else if ((graph->nodes[current_op_index].apply.op_type == OperatorType::kRecSparseSoftmaxCrossEntropyWithLogits) || - (graph->nodes[current_op_index].apply.op_type == OperatorType::kRecUnkownType)) { - graph->nodes[current_op_index].tensor_parm = MakeTensor(1, 1, 1, 1); - } else { - graph->nodes[current_op_index].tensor_parm = MakeTensor( - ops[iter_ops]->outputs_tensor_info()[0].shape()[0], ops[iter_ops]->outputs_tensor_info()[0].shape()[1], - ops[iter_ops]->outputs_tensor_info()[0].shape()[2], ops[iter_ops]->outputs_tensor_info()[0].shape()[3]); - } + return NewTensor.apply.arguments[iter_input_tensors]; } std::shared_ptr ParseGraph(const std::vector>& ops, - const std::vector>& input_tensor_names, - const std::shared_ptr>& ops_nodes_list) { - std::vector current_graph; + const std::vector>& input_tensor_names) { std::shared_ptr graph(new Graph); if (ops.size() > SIZE_MAX / 2) { MS_LOG(EXCEPTION) << "Total number of operators is bigger than " << SIZE_MAX / 2; } - for (size_t iter_ops = ops.size(); iter_ops > 0; iter_ops--) { - if (IsInList(ops[iter_ops - 1]->cnode_name(), current_graph)) { - size_t current_op_index = static_cast(std::distance( - current_graph.begin(), std::find(current_graph.begin(), current_graph.end(), ops[iter_ops]->cnode_name()))); - std::vector::iterator itr = ops_nodes_list->insert(ops_nodes_list->begin(), current_op_index); - if (itr != ops_nodes_list->begin()) { - MS_LOG(EXCEPTION) << "Iterator error."; - } - ModifyTensorToOperator(graph, current_op_index, iter_ops - 1, ops); - LinkOps(graph, ops, input_tensor_names, current_graph, iter_ops - 1, current_op_index); - } else { - Graph::NodeType NewOp = MakeNewOperator(ops, iter_ops - 1); - current_graph.push_back(NewOp.name); - graph->nodes.push_back(NewOp); - size_t current_op_index = graph->nodes.size() - 1; - std::vector::iterator itr = ops_nodes_list->insert(ops_nodes_list->begin(), current_op_index); - if (itr != ops_nodes_list->begin()) { - MS_LOG(EXCEPTION) << "Iterator error."; - } - LinkOps(graph, ops, input_tensor_names, current_graph, iter_ops - 1, current_op_index); - } + for (size_t iter_ops = 0; iter_ops < ops.size(); iter_ops++) { + Graph::NodeType NewOp = MakeNewOperator(ops, iter_ops); + graph->nodes.push_back(NewOp); } - return graph; -} - -void LinkOps(std::shared_ptr graph, std::vector> ops, - const std::vector>& input_tensor_names, std::vector current_graph, - const size_t iter_ops, const size_t current_op_index) { - for (size_t iter_input_tensors = 0; - iter_input_tensors < std::min(input_tensor_names[iter_ops].size(), ops[iter_ops]->inputs_tensor_info().size()); - iter_input_tensors++) { - std::string input = input_tensor_names[iter_ops][iter_input_tensors]; - if (IsInList(input, current_graph)) { - size_t input_index = static_cast( - std::distance(current_graph.begin(), std::find(current_graph.begin(), current_graph.end(), input))); - MakeEdge(graph, input_index, current_op_index); - CompleteOperatorInputs(ops, iter_ops, iter_input_tensors, current_op_index, graph); - } else { - Graph::NodeType NewTensor = MakeNewTensor(ops, iter_ops, input, iter_input_tensors, graph, current_op_index); - current_graph.push_back(NewTensor.name); - graph->nodes.push_back(NewTensor); - size_t input_index = graph->nodes.size() - 1; - CompleteOperatorInputs(ops, iter_ops, iter_input_tensors, current_op_index, graph); - MakeEdge(graph, input_index, current_op_index); - } + MakeEdge(input_tensor_names, graph); - if (graph->nodes[current_op_index].apply.op_type == OperatorType::kRecBatchNorm) { - break; - } - } + return graph; } -void Eliminate_Aux(const size_t node_index, std::shared_ptr graph, - const std::shared_ptr>> eli_list) { - if ((graph->nodes[node_index].apply.op_type == OperatorType::kRecUnkownType) || - (graph->nodes[node_index].apply.op_type == OperatorType::kRecReLU)) { - size_t input_index = (graph->nodes[node_index].node_in)[0]; - std::vector outputs = graph->nodes[node_index].node_out; - - std::vector eli; - eli.push_back(node_index); - eli.push_back(input_index); - for (size_t i = 0; i < outputs.size(); i++) { - eli.push_back(i); - } - eli_list->push_back(eli); - - for (size_t i = 1; i < (size_t)graph->nodes[node_index].node_in.size(); i++) { - std::vector tmp; - tmp.push_back(node_index); - tmp.push_back((graph->nodes[node_index].node_in)[i]); - eli_list->push_back(tmp); - } - - auto it = find(graph->nodes[input_index].node_out.begin(), graph->nodes[input_index].node_out.end(), node_index); - std::vector::iterator itr = graph->nodes[input_index].node_out.erase(it); - if (itr != it) { - MS_LOG(EXCEPTION) << "Iterator error."; - } - for (auto output : outputs) { - graph->nodes[input_index].node_out.push_back(output); - } - for (auto& output_index : outputs) { - auto itt = find(graph->nodes[output_index].node_in.begin(), graph->nodes[output_index].node_in.end(), node_index); - graph->nodes[output_index] - .node_in[static_cast(std::distance(graph->nodes[output_index].node_in.begin(), itt))] = input_index; +void MakeEdge(const std::vector>& input_tensor_names, std::shared_ptr graph) { + for (size_t iter_i = 0; iter_i < input_tensor_names.size(); iter_i++) { + for (size_t iter_j = 1; iter_j < input_tensor_names[iter_i].size(); iter_j++) { + size_t head_node_index = GetIndexInInputTensorNames(input_tensor_names, input_tensor_names[iter_i][iter_j]); + if (head_node_index < SIZE_MAX / 2 && head_node_index != iter_i) { + graph->nodes[iter_i].node_in.push_back(head_node_index); + graph->nodes[head_node_index].node_out.push_back(iter_i); + } } } } -std::shared_ptr EliminateGraph(const std::shared_ptr graph, - std::shared_ptr>> eli_list, - std::shared_ptr> index_list) { - for (size_t node_index = 0; node_index < (size_t)graph->nodes.size(); node_index++) { - if (graph->nodes[node_index].info == InfoType::kApplication) { - Eliminate_Aux(node_index, graph, eli_list); - } - } - - index_list->reserve(graph->nodes.size()); - for (size_t i = 0; i < (size_t)graph->nodes.size(); i++) { - index_list->push_back(i); - } - - for (size_t i = 0; i < (size_t)eli_list->size(); i++) { - index_list->at((eli_list->at(i)[0])) = SIZE_MAX; - for (size_t j = eli_list->at(i)[0] + 1; j < (size_t)index_list->size(); j++) { - index_list->at(j)--; +size_t GetIndexInInputTensorNames(const std::vector>& input_tensor_name, + const std::string& input_name) { + for (size_t index = 0; index < input_tensor_name.size(); index++) { + if (input_tensor_name[index][0] == input_name) { + return index; } } - - std::shared_ptr new_graph(new Graph); - for (size_t i = 0; i < (size_t)(graph->nodes.size() - eli_list->size()); i++) { - Graph::NodeType NewOp; - new_graph->nodes.push_back(NewOp); - } - - for (size_t i = 0; i < (size_t)graph->nodes.size(); i++) { - if (index_list->at(i) > SIZE_MAX / 2) continue; - new_graph->nodes[index_list->at(i)] = graph->nodes[i]; - for (size_t j = 0; j < (size_t)new_graph->nodes[index_list->at(i)].node_in.size(); j++) { - new_graph->nodes[index_list->at(i)].node_in[j] = index_list->at(new_graph->nodes[index_list->at(i)].node_in[j]); - } - for (size_t j = 0; j < (size_t)new_graph->nodes[index_list->at(i)].node_out.size(); j++) { - new_graph->nodes[index_list->at(i)].node_out[j] = index_list->at(new_graph->nodes[index_list->at(i)].node_out[j]); - } - } - - return new_graph; + MS_LOG(INFO) << "Get index failed, using SIZE_MAX insted"; + return SIZE_MAX; } } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h index 7dfca86a21..0d719c33d8 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h +++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h @@ -43,39 +43,24 @@ const std::map DictOpType{ const TensorParam MakeTensor(int n, int c, int h, int w); -bool IsInList(const std::string& name, const std::vector& list); - Graph::NodeType MakeNewOperator(std::vector> ops, size_t iter_ops); -Graph::NodeType MakeNewTensor(std::vector> ops, const size_t iter_ops, - const std::string& input, const size_t iter_input_tensors, std::shared_ptr graph, - size_t current_op_index); -void Fill2DTensor(const std::vector>& ops, const size_t iter_ops, - const std::shared_ptr graph, const size_t iter_input_tensors, const size_t current_op_index, - Graph::NodeType NewTensor); -void CompleteOperatorInputs(std::vector> ops, size_t iter_ops, size_t iter_input_tensors, - size_t current_op_index, std::shared_ptr graph); -void Complete2DInputs(const std::vector>& ops, const size_t iter_ops, - const std::shared_ptr graph, const size_t iter_input_tensors, - const size_t current_op_index); -void MakeEdge(std::shared_ptr graph, const size_t input_index, const size_t current_op_index); +TensorParam Fill2DTensor(const std::vector>& ops, const size_t iter_ops, + Graph::NodeType NewTensor); + +OperatorRec CompleteOperatorInputs(const std::vector>& ops, const size_t iter_ops, + Graph::NodeType NewTensor); -void ModifyTensorToOperator(std::shared_ptr graph, const size_t current_op_index, const size_t iter_ops, - std::vector> ops); +TensorParam Complete2DInputs(const std::vector>& ops, const size_t iter_ops, + const size_t iter_input_tensor, Graph::NodeType NewTensor); std::shared_ptr ParseGraph(const std::vector>& ops, - const std::vector>& input_tensor_names, - const std::shared_ptr>& ops_nodes_list); + const std::vector>& input_tensor_names); -void LinkOps(std::shared_ptr graph, std::vector> ops, - const std::vector>& input_tensor_names, std::vector current_graph, - const size_t iter_ops, const size_t current_op_index); +void MakeEdge(const std::vector>& input_tensor_names, std::shared_ptr graph); -std::shared_ptr EliminateGraph(const std::shared_ptr graph, - std::shared_ptr>> eli_list, - std::shared_ptr> index_list); -void Eliminate_Aux(const size_t node_index, std::shared_ptr graph, - const std::shared_ptr>> eli_list); +size_t GetIndexInInputTensorNames(const std::vector>& input_tensor_names, + const std::string& input_name); } // namespace parallel } // namespace mindspore #endif // PARALLEL_AUTO_PARALLEL_REC_PARSE_GRAPH_H_ diff --git a/mindspore/ccsrc/parallel/step_auto_parallel.cc b/mindspore/ccsrc/parallel/step_auto_parallel.cc index 1d52eac82d..d4822b8309 100644 --- a/mindspore/ccsrc/parallel/step_auto_parallel.cc +++ b/mindspore/ccsrc/parallel/step_auto_parallel.cc @@ -461,7 +461,6 @@ Status ConstructCostGraphNodes(const std::vector &all_nodes, const F // Needed by rec_parser operator_info->set_type(prim->name()); std::vector inputs_tensor_name = ExtractInputsTensorName(cnode); - operator_info->set_cnode_name(cnode->ToString()); entire_costgraph->AddOperator(operator_info); (void)cnode->set_operator_info(operator_info); @@ -934,9 +933,8 @@ Status ParallelStrategyRecSearch(const std::vector &all_nodes, const std::shared_ptr> index_list(new std::vector); std::shared_ptr>> eli_list(new std::vector>); - std::shared_ptr graph = ParseGraph(ops, input_tensor_names, ops_nodes_list); + std::shared_ptr graph = ParseGraph(ops, input_tensor_names); - graph = EliminateGraph(graph, eli_list, index_list); size_t num_device = g_device_manager->DeviceNum(); if (PartitionForAllDevices(num_device, graph) == SUCCESS) { MS_LOG(INFO) << "Partition Success With " << num_device << " devices."; From a91f82d79f158ffd04bf1f06bdb60d23a59f3475 Mon Sep 17 00:00:00 2001 From: candanzg Date: Sat, 18 Apr 2020 20:32:00 +0800 Subject: [PATCH 003/142] Supplement summary log Signed-off-by: candanzg --- mindspore/train/summary/summary_record.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/mindspore/train/summary/summary_record.py b/mindspore/train/summary/summary_record.py index d96ac4773a..3dbe31f0e4 100644 --- a/mindspore/train/summary/summary_record.py +++ b/mindspore/train/summary/summary_record.py @@ -46,10 +46,14 @@ def _cache_summary_tensor_data(summary): class SummaryRecord: """ - Summary log record. - SummaryRecord is used to record the summary value. - The API will create an event file in a given directory and add summaries and events to it. + + Note: + The API will create an event file in a given directory and add summaries and events to it. + It writes the event log to a file by executing the record method. In addition, + if the SummaryRecord object is created and the summary operator is used in the network, + even if the record method is not called, the event in the cache will be written to the + file at the end of execution or when the summary is closed. Args: log_dir (str): The log_dir is a directory location to save the summary. From 79c4312e892f99e19672e9a04e6a11d6a98b7c9b Mon Sep 17 00:00:00 2001 From: candanzg Date: Sat, 18 Apr 2020 16:33:30 +0800 Subject: [PATCH 004/142] NotEqual op auto cast Signed-off-by: candanzg --- tests/ut/python/ops/test_ops.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index 078ada8406..eee5080a6c 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -401,6 +401,11 @@ test_case_math_ops = [ 'block': P.NotEqual(), 'desc_inputs': [[4, 1], [2, 3, 4, 5]], 'desc_bprop': [Tensor(np.ones((2, 3, 4, 5), np.bool_))]}), + ('NotEqual_0', { + 'block': P.NotEqual(), + 'desc_inputs': [ 1, [2, 3, 4, 5]], + 'desc_bprop': [Tensor(np.ones((2, 3, 4, 5), np.bool_))], + 'skip': ['backward']}), ('Greater', { 'block': P.Greater(), 'desc_inputs': [[2, 3, 4, 1], [4, 5]], From de7457c9f58972e3f8a1307dfbd2a2338ea0af67 Mon Sep 17 00:00:00 2001 From: candanzg Date: Sun, 19 Apr 2020 11:46:47 +0800 Subject: [PATCH 005/142] fixed bug for makedirs in python Signed-off-by: candanzg --- mindspore/train/_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mindspore/train/_utils.py b/mindspore/train/_utils.py index 85b7629002..7bc07b126e 100644 --- a/mindspore/train/_utils.py +++ b/mindspore/train/_utils.py @@ -87,7 +87,7 @@ def _make_directory(path: str): # All exceptions need to be caught because create directory maybe have some limit(permissions) logger.debug("The directory(%s) doesn't exist, will create it", path) try: - os.makedirs(path) + os.makedirs(path, exist_ok=True) real_path = path except PermissionError as e: logger.error("No write permission on the directory(%r), error = %r", path, e) From 43a2e998331a79b969540deeaeb6b227ae95bb1f Mon Sep 17 00:00:00 2001 From: Junhan Hu Date: Thu, 16 Apr 2020 13:42:16 -0400 Subject: [PATCH 006/142] Add python sampler support for CPP dataset --- .../ccsrc/dataset/api/python_bindings.cc | 5 + .../datasetops/source/sampler/CMakeLists.txt | 1 + .../source/sampler/python_sampler.cc | 83 +++++++++++++++++ .../source/sampler/python_sampler.h | 58 ++++++++++++ .../datasetops/source/sampler/sampler.cc | 3 - .../source/sampler/sequential_sampler.cc | 1 + mindspore/dataset/__init__.py | 2 +- mindspore/dataset/engine/datasets.py | 2 +- mindspore/dataset/engine/samplers.py | 92 +++++++++++++++++-- mindspore/dataset/engine/validators.py | 8 +- tests/ut/python/dataset/test_sampler.py | 57 ++++++++++++ 11 files changed, 296 insertions(+), 16 deletions(-) create mode 100644 mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.cc create mode 100644 mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.h diff --git a/mindspore/ccsrc/dataset/api/python_bindings.cc b/mindspore/ccsrc/dataset/api/python_bindings.cc index 076f2ecc36..6bacd67396 100644 --- a/mindspore/ccsrc/dataset/api/python_bindings.cc +++ b/mindspore/ccsrc/dataset/api/python_bindings.cc @@ -53,6 +53,7 @@ #include "dataset/engine/datasetops/source/sampler/sequential_sampler.h" #include "dataset/engine/datasetops/source/sampler/subset_random_sampler.h" #include "dataset/engine/datasetops/source/sampler/weighted_random_sampler.h" +#include "dataset/engine/datasetops/source/sampler/python_sampler.h" #include "dataset/engine/datasetops/source/tf_reader_op.h" #include "dataset/engine/jagged_connector.h" #include "dataset/kernels/data/to_float16_op.h" @@ -415,6 +416,7 @@ void bindSamplerOps(py::module *m) { (void)py::class_>(*m, "SequentialSampler") .def(py::init<>()); + (void)py::class_>(*m, "SubsetRandomSampler") .def(py::init>(), py::arg("indices")); @@ -425,6 +427,9 @@ void bindSamplerOps(py::module *m) { (void)py::class_>(*m, "WeightedRandomSampler") .def(py::init, int64_t, bool>(), py::arg("weights"), py::arg("numSamples"), py::arg("replacement")); + + (void)py::class_>(*m, "PythonSampler") + .def(py::init(), py::arg("pySampler")); } void bindInfoObjects(py::module *m) { diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/CMakeLists.txt b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/CMakeLists.txt index 5d55c8276a..b084e1c125 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/CMakeLists.txt +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/CMakeLists.txt @@ -1,6 +1,7 @@ add_library(engine-datasetops-source-sampler OBJECT distributed_sampler.cc pk_sampler.cc + python_sampler.cc random_sampler.cc sampler.cc sequential_sampler.cc diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.cc new file mode 100644 index 0000000000..464717feb4 --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.cc @@ -0,0 +1,83 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "dataset/engine/datasetops/source/sampler/python_sampler.h" + +#include + +namespace mindspore { +namespace dataset { + +PythonSampler::PythonSampler(py::object py_sampler_instance, int64_t samples_per_buffer) + : Sampler(samples_per_buffer), py_sampler_instance(py_sampler_instance), need_to_reset_(false) {} + +Status PythonSampler::GetNextBuffer(std::unique_ptr *out_buffer) { + if (need_to_reset_) { + (*out_buffer) = std::make_unique(0, DataBuffer::kDeBFlagEOE); + } else { + std::shared_ptr sample_ids; + { + py::gil_scoped_acquire gil_acquire; + (*out_buffer) = std::make_unique(0, DataBuffer::kDeBFlagNone); + if (Py_IsInitialized() == 0) { + return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); + } + try { + py::object py_ret = py_sampler_instance.attr("_get_indices")(); + py::array np_sample_ids = py_ret.cast(); + Tensor::CreateTensor(&sample_ids, np_sample_ids); // copy numpy to tensor + } catch (const py::error_already_set &e) { + return Status(StatusCode::kPyFuncException, e.what()); + } + } + TensorRow row(1, sample_ids); + (*out_buffer)->set_tensor_table(std::make_unique(1, row)); + need_to_reset_ = true; + } + return Status::OK(); +} + +Status PythonSampler::InitSampler() { + CHECK_FAIL_RETURN_UNEXPECTED(num_rows_ > 0, "ERROR num_rows_ should be greater than 0"); + { + py::gil_scoped_acquire gil_acquire; + if (Py_IsInitialized() == 0) { + return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); + } + try { + py_sampler_instance.attr("_handshake")(num_rows_, num_samples_); + } catch (const py::error_already_set &e) { + return Status(StatusCode::kPyFuncException, e.what()); + } + } + return Status::OK(); +} + +Status PythonSampler::Reset() { + CHECK_FAIL_RETURN_UNEXPECTED(need_to_reset_, "ERROR Reset() called not at end of an epoch"); + need_to_reset_ = false; + py::gil_scoped_acquire gil_acquire; + if (Py_IsInitialized() == 0) { + return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); + } + try { + py_sampler_instance.attr("reset")(); + } catch (const py::error_already_set &e) { + return Status(StatusCode::kPyFuncException, e.what()); + } + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.h b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.h new file mode 100644 index 0000000000..b8734fee6a --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.h @@ -0,0 +1,58 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_PYTHON_SAMPLER_H_ +#define DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_PYTHON_SAMPLER_H_ + +#include +#include + +#include "dataset/engine/datasetops/source/sampler/sampler.h" + +namespace mindspore { +namespace dataset { +class PythonSampler : public Sampler { + public: + // Constructor + // @param int64_t samplesPerBuffer - Num of Sampler Ids to fetch via 1 GetNextBuffer call + explicit PythonSampler(py::object py_sampler_instance, + int64_t samples_per_buffer = std::numeric_limits::max()); + + // Destructor. + ~PythonSampler() = default; + + // Initialize the sampler. + // @return Status + Status InitSampler() override; + + // for next epoch of sampleIds + // @return - The error code return + Status Reset() override; + + // Op calls this to get next Buffer that contains all the sampleIds + // @param std::unique_ptr pBuffer - Buffer to be returned to StorageOp + // @param int32_t workerId - not meant to be used + // @return - The error code return + Status GetNextBuffer(std::unique_ptr *out_buffer) override; + + private: + bool need_to_reset_; // Whether Reset() should be called before calling GetNextBuffer() + + py::object py_sampler_instance; // The handle to the py_sampler python object +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_DATASETOPS_SOURCE_SAMPLER_PYTHON_SAMPLER_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc index 3c3f5f48e8..9fe752448a 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sampler.cc @@ -48,9 +48,6 @@ Status Sampler::GetAllIdsThenReset(py::array *data) { std::unique_ptr db; std::shared_ptr sample_ids; - // check samples_per_buffer is properly set and doesn't overflow - CHECK_FAIL_RETURN_UNEXPECTED(samples_per_buffer_ + 1 > 1, "samples_per_buffer invalid"); - // A call to derived class to get sample ids wrapped inside a buffer RETURN_IF_NOT_OK(GetNextBuffer(&db)); // Get the only tensor inside the buffer that contains the actual SampleIds for the entire epoch diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.cc index a3c4fe2256..6ed06b527f 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/sequential_sampler.cc @@ -42,6 +42,7 @@ Status SequentialSampler::GetNextBuffer(std::unique_ptr *out_buffer) } Status SequentialSampler::InitSampler() { + num_samples_ = (num_samples_ <= 0) ? num_rows_ : num_samples_; // if num_samples < 0, try if num_rows is set CHECK_FAIL_RETURN_UNEXPECTED(num_samples_ > 0 && samples_per_buffer_ > 0, "Fail to init Sequential Sampler"); samples_per_buffer_ = samples_per_buffer_ > num_samples_ ? num_samples_ : samples_per_buffer_; return Status::OK(); diff --git a/mindspore/dataset/__init__.py b/mindspore/dataset/__init__.py index 479c66045f..bff23b7abf 100644 --- a/mindspore/dataset/__init__.py +++ b/mindspore/dataset/__init__.py @@ -23,7 +23,7 @@ from .engine.datasets import StorageDataset, TFRecordDataset, ImageFolderDataset GeneratorDataset, ManifestDataset, Cifar10Dataset, Cifar100Dataset, VOCDataset, CelebADataset, Schema, \ Shuffle, zip from .engine.samplers import DistributedSampler, PKSampler, RandomSampler, SequentialSampler, SubsetRandomSampler, \ - WeightedRandomSampler + WeightedRandomSampler, Sampler from .engine.serializer_deserializer import serialize, deserialize, show __all__ = ["config", "ImageFolderDatasetV2", "MnistDataset", "StorageDataset", diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 8de56a6dff..71df50ac4a 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -2032,7 +2032,7 @@ class GeneratorDataset(SourceDataset): if self.sampler is not None and hasattr(source, "__getitem__"): if isinstance(self.sampler, (samplers.SequentialSampler, samplers.DistributedSampler, samplers.RandomSampler, samplers.SubsetRandomSampler, - samplers.WeightedRandomSampler)): + samplers.WeightedRandomSampler, samplers.Sampler)): if num_samples is None: num_samples = len(source) sampler_instance = self.sampler.create() diff --git a/mindspore/dataset/engine/samplers.py b/mindspore/dataset/engine/samplers.py index 0bba559210..421a03ab8d 100644 --- a/mindspore/dataset/engine/samplers.py +++ b/mindspore/dataset/engine/samplers.py @@ -16,11 +16,90 @@ Sampler module provides several samplers to generate sampling data from dataset. There are following samplers: DistributedSampler, PKSampler, RandomSampler, SequentialSampler, SubsetRandomSampler, WeightedRandomSampler. +User can also define custom sampler by extending from Sampler class. """ import mindspore._c_dataengine as cde +import numpy as np -class DistributedSampler(): + +class Sampler: + """ + Base class for user defined sampler. + User defined sampler can be used with any existing dataset with sampler support. + + An required _iter_() method should by overridden by user for sample index generation. + An optional reset() method can be overridden for per repeat reset, + + dataset_size and num_samples will be set by dataset once a dataset iterator is created. + + Examples: + >>> import mindspore.dataset as ds + >>> + >>> class ReverseSampler(ds,Sampler): + >>> def __iter__(self): + >>> for i in range(self.dataset_size - 1, -1, -1): + >>> yield i + >>> + >>> ds = ds.ImageFolderDatasetV2(path, sampler=ReverseSampler()) + """ + + def __init__(self): + self.dataset_size = 0 + self.num_samples = 0 + + def __iter__(self): + """ + User defined iterator, must be overridden. + _handshake is guaranteed to be called prior to iterator construction + + """ + raise NotImplementedError + + def reset(self): + """ + Per repeat reset callback, override this method if necessary + """ + + # Initialization handshake callback + # Do not override this method! + def _handshake(self, ds_size, num_samples): + self.dataset_size = ds_size + self.num_samples = num_samples + + # Indices fetcher + # Do not override this method! + def _get_indices(self): + sampler_iter = iter(self) + ret = [] + for _ in range(self.num_samples): + try: + idx = next(sampler_iter) + ret.append(idx) + except StopIteration: + break + return np.array(ret) + + # Instance fetcher + # Do not override this method! + def create(self): + return cde.PythonSampler(self) + + +class BuiltinSampler: + """ + Base class for BuiltinSampler. + + User should not extend this class. + """ + def __init__(self): + pass + + def create(self): + pass + + +class DistributedSampler(BuiltinSampler): """ Sampler that access a shard of the dataset. @@ -65,7 +144,7 @@ class DistributedSampler(): return cde.DistributedSampler(self.num_shards, self.shard_id, self.shuffle, self.seed) -class PKSampler(): +class PKSampler(BuiltinSampler): """ Samples K elements for each P class in the dataset. @@ -106,7 +185,7 @@ class PKSampler(): return cde.PKSampler(self.num_val, self.shuffle) -class RandomSampler(): +class RandomSampler(BuiltinSampler): """ Samples the elements randomly. @@ -147,7 +226,7 @@ class RandomSampler(): return cde.RandomSampler(self.replacement, self.num_samples) -class SequentialSampler(): +class SequentialSampler(BuiltinSampler): """ Samples the dataset elements sequentially, same as not having a sampler. @@ -165,7 +244,7 @@ class SequentialSampler(): return cde.SequentialSampler() -class SubsetRandomSampler(): +class SubsetRandomSampler(BuiltinSampler): """ Samples the elements randomly from a sequence of indices. @@ -196,7 +275,8 @@ class SubsetRandomSampler(): def _create_for_minddataset(self): return cde.MindrecordSubsetRandomSampler(self.indices) -class WeightedRandomSampler(): + +class WeightedRandomSampler(BuiltinSampler): """ Samples the elements from [0, len(weights) - 1] randomly with the given weights (probabilities). diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index b74e913202..ff56652bcb 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -297,9 +297,7 @@ def check_sampler_shuffle_shard_options(param_dict): shuffle, sampler = param_dict.get('shuffle'), param_dict.get('sampler') num_shards, shard_id = param_dict.get('num_shards'), param_dict.get('shard_id') - if sampler is not None and not isinstance(sampler, ( - samplers.DistributedSampler, samplers.PKSampler, samplers.RandomSampler, samplers.SequentialSampler, - samplers.SubsetRandomSampler, samplers.WeightedRandomSampler)): + if sampler is not None and not isinstance(sampler, (samplers.BuiltinSampler, samplers.Sampler)): raise ValueError("sampler is not a valid Sampler type.") if sampler is not None: @@ -579,11 +577,11 @@ def check_generatordataset(method): raise ValueError("PKSampler is not supported by GeneratorDataset") if not isinstance(sampler, (samplers.SequentialSampler, samplers.DistributedSampler, samplers.RandomSampler, samplers.SubsetRandomSampler, - samplers.WeightedRandomSampler)): + samplers.WeightedRandomSampler, samplers.Sampler)): try: iter(sampler) except TypeError: - raise TypeError("sampler should be either iterable or from dataset.samplers.py") + raise TypeError("sampler should be either iterable or from mindspore.dataset.samplers") return method(*args, **kwargs) diff --git a/tests/ut/python/dataset/test_sampler.py b/tests/ut/python/dataset/test_sampler.py index 7a58249f9c..4efca6f818 100644 --- a/tests/ut/python/dataset/test_sampler.py +++ b/tests/ut/python/dataset/test_sampler.py @@ -14,6 +14,7 @@ # ============================================================================== import mindspore.dataset as ds from mindspore import log as logger +import numpy as np # test5trainimgs.json contains 5 images whose un-decoded shape is [83554, 54214, 65512, 54214, 64631] @@ -107,8 +108,64 @@ def test_sampler_py_api(): sampler.get_indices() +def test_python_sampler(): + manifest_file = "../data/dataset/testManifestData/test5trainimgs.json" + map = {(172876, 0): 0, (54214, 0): 1, (54214, 1): 2, (173673, 0): 3, (64631, 1): 4} + + class Sp1(ds.Sampler): + def __iter__(self): + return iter([i for i in range(self.dataset_size)]) + + class Sp2(ds.Sampler): + def __init__(self): + super(Sp2, self).__init__() + # at this stage, self.dataset_size and self.num_samples are not yet known + self.cnt = 0 + + def __iter__(self): # first epoch, all 0, second epoch all 1, third all 2 etc.. ... + return iter([self.cnt for i in range(self.num_samples)]) + + def reset(self): + self.cnt = (self.cnt + 1) % self.dataset_size + + def test_config(num_samples, num_repeats, sampler): + data1 = ds.ManifestDataset(manifest_file, num_samples=num_samples, sampler=sampler) + if num_repeats is not None: + data1 = data1.repeat(num_repeats) + res = [] + for item in data1.create_dict_iterator(): + logger.info("item[image].shape[0]: {}, item[label].item(): {}" + .format(item["image"].shape[0], item["label"].item())) + res.append(map[(item["image"].shape[0], item["label"].item())]) + # print(res) + return res + + def test_generator(): + class MySampler(ds.Sampler): + def __iter__(self): + for i in range(99, -1, -1): + yield i + + data1 = ds.GeneratorDataset([(np.array(i),) for i in range(100)], ["data"], sampler = MySampler()) + i = 99 + for data in data1: + assert data[0] == (np.array(i),) + i = i - 1 + + assert test_config(5, 2, Sp1()) == [0, 1, 2, 3, 4, 0, 1, 2, 3, 4] + assert test_config(2, 6, Sp2()) == [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 0, 0] + test_generator() + + sp1 = Sp1().create() + sp1.set_num_rows(5) + sp1.set_num_samples(5) + sp1.initialize() + assert list(sp1.get_indices()) == [0, 1, 2, 3, 4] + + if __name__ == '__main__': test_sequential_sampler(True) test_random_sampler(True) test_random_sampler_multi_iter(True) test_sampler_py_api() + test_python_sampler() \ No newline at end of file From d0ae610832c71821acadd4fa5d3baa7d0bba0ace Mon Sep 17 00:00:00 2001 From: Junhan Hu Date: Sun, 19 Apr 2020 20:35:32 -0400 Subject: [PATCH 007/142] Review --- .../engine/datasetops/source/sampler/python_sampler.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.cc index 464717feb4..1747040141 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/sampler/python_sampler.cc @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -40,6 +40,8 @@ Status PythonSampler::GetNextBuffer(std::unique_ptr *out_buffer) { Tensor::CreateTensor(&sample_ids, np_sample_ids); // copy numpy to tensor } catch (const py::error_already_set &e) { return Status(StatusCode::kPyFuncException, e.what()); + } catch (const py::cast_error &e) { + return Status(StatusCode::kPyFuncException, "Python Sampler iterator should return integer index"); } } TensorRow row(1, sample_ids); From 3f087dba1ab5063e142b48db480b4eea0c5cb49c Mon Sep 17 00:00:00 2001 From: candanzg Date: Fri, 17 Apr 2020 14:23:35 +0800 Subject: [PATCH 008/142] Tensor assign syntax: 1) A[B]=U 2) A[A>n]=U A.shape == B.shape U is a scalar or Tensor(size==1) B is Tensor(dtype=bool) n is a Number Signed-off-by: candanzg --- mindspore/_extends/parse/resources.py | 2 +- .../ops/composite/multitype_ops/__init__.py | 2 + .../multitype_ops/_multitype_ops_util.py | 45 ++++ .../composite/multitype_ops/setitem_impl.py | 194 ++++++++++++++++++ mindspore/ops/functional.py | 5 + tests/ut/python/ops/test_tensor_slice.py | 96 +++++++++ 6 files changed, 343 insertions(+), 1 deletion(-) create mode 100644 mindspore/ops/composite/multitype_ops/_multitype_ops_util.py create mode 100644 mindspore/ops/composite/multitype_ops/setitem_impl.py diff --git a/mindspore/_extends/parse/resources.py b/mindspore/_extends/parse/resources.py index 9fb357597e..c2c2716697 100644 --- a/mindspore/_extends/parse/resources.py +++ b/mindspore/_extends/parse/resources.py @@ -83,6 +83,7 @@ convert_object_map = { T.mul: multitype_ops.mul, T.truediv: multitype_ops.div, T.getitem: multitype_ops.getitem, + T.setitem: multitype_ops.setitem, T.floordiv: multitype_ops.floordiv, T.mod: multitype_ops.mod, T.pow: multitype_ops.pow_, @@ -118,7 +119,6 @@ convert_object_map = { T.iter: M.ms_iter, T.next: M.ms_next, T.hasnext: M.hasnext, - T.setitem: M.setitem, T.make_tuple: F.make_tuple, T.make_dict: F.make_dict, diff --git a/mindspore/ops/composite/multitype_ops/__init__.py b/mindspore/ops/composite/multitype_ops/__init__.py index 40bf71d49a..b7f4f671b8 100644 --- a/mindspore/ops/composite/multitype_ops/__init__.py +++ b/mindspore/ops/composite/multitype_ops/__init__.py @@ -23,6 +23,7 @@ from .pow_impl import pow_ from .floordiv_impl import floordiv from .mod_impl import mod from .getitem_impl import getitem +from .setitem_impl import setitem from .zeros_like_impl import zeros_like from .ones_like_impl import ones_like from .equal_impl import equal @@ -55,6 +56,7 @@ __all__ = [ 'greater_equal', 'negative', 'getitem', + 'setitem', 'logical_and', 'logical_or', 'logical_not' diff --git a/mindspore/ops/composite/multitype_ops/_multitype_ops_util.py b/mindspore/ops/composite/multitype_ops/_multitype_ops_util.py new file mode 100644 index 0000000000..b3687c553c --- /dev/null +++ b/mindspore/ops/composite/multitype_ops/_multitype_ops_util.py @@ -0,0 +1,45 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""constexpr util""" + +from ...primitive import constexpr + + +@constexpr +def is_same_type(inst, type_): + """ + Check whether an object is an instance of a target type. + + Inputs: + inst (mindspore.dtype): Inspected type. + type_ (mindspore.dtype): Target type. + + Outputs: + bool, the check result. + """ + return inst == type_ + + +@constexpr +def error_msg(msg="", format_values=""): + """ + Used to throw exception information. + + Inputs: + msg (str): information content. + """ + + raise ValueError(msg.format(*format_values)) diff --git a/mindspore/ops/composite/multitype_ops/setitem_impl.py b/mindspore/ops/composite/multitype_ops/setitem_impl.py new file mode 100644 index 0000000000..31c96932c5 --- /dev/null +++ b/mindspore/ops/composite/multitype_ops/setitem_impl.py @@ -0,0 +1,194 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Implementation for setitem.""" + +from ...composite import base +from ....common import dtype as mstype +from ... import functional as F +from . import _multitype_ops_util as mult_util + +setitem = base.MultitypeFuncGraph('setitem') + +@setitem.register("List", "Number", "String") +def _list_setitem_with_string(data, number_index, value): + """ + Assign value to list. + + Inputs: + data (list): Data of type lis. + number_index (Number): Index of data. + value (String): Value given. + + Outputs: + List, type is same as the element type of data. + """ + return F.list_setitem(data, number_index, value) + + +@setitem.register("List", "Number", "Number") +def _list_setitem_with_number(data, number_index, value): + """ + Assign value to list. + + Inputs: + data (list): Data of type lis. + number_index (Number): Index of data. + value (Number): Value given. + + Outputs: + List, type is same as the element type of data. + """ + return F.list_setitem(data, number_index, value) + + +@setitem.register("List", "Number", "Tensor") +def _list_setitem_with_Tensor(data, number_index, value): + """ + Assign value to list. + + Inputs: + data (list): Data of type lis. + number_index (Number): Index of data. + value (Tensor): Value given. + + Outputs: + List, type is same as the element type of data. + """ + return F.list_setitem(data, number_index, value) + + +@setitem.register("List", "Number", "List") +def _list_setitem_with_List(data, number_index, value): + """ + Assign value to list. + + Inputs: + data (list): Data of type lis. + number_index (Number): Index of data. + value (List): Value given. + + Outputs: + List, type is same as the element type of data. + """ + return F.list_setitem(data, number_index, value) + + +@setitem.register("Dictionary", "String", "Tensor") +def _dict_setitem_with_tensor(data, key, value): + """ + Assign value to dictionary. + + Inputs: + data (Dictionary): Data of type dict. + key (str): Key of the data. + value (Tensor): Value given. + + Outputs: + Dict, type is as same as the element type of data. + """ + return F.dict_setitem(data, key, value) + + +@setitem.register("Dictionary", "String", "Number") +def _dict_setitem_with_number(data, key, value): + """ + Assign value to dictionary. + + Inputs: + data (Dictionary): Data of type dict. + key (str): Key of the data. + value (Number): Value given. + + Outputs: + Dict, type is as same as the element type of data. + """ + return F.dict_setitem(data, key, value) + + +@setitem.register("Tensor", "Tensor", "Tensor") +def _tensor_setitem_by_tensor_v1(data, index, value_tensor): + """ + Tensor assignment. + + Note: + Syntax support: A[B] = U and A[A>n] = U. + Restraint condition: 1) A, U is a Tensor, and B is a bool Tensor. + 2) A.shape == B.shape + 3) U.size == 1 + 4) n is a number + + Inputs: + data (Tensor): Assigned tensor. + index (Tensor): Tensor of bool type. + value_tensor (Tensor): Tensor with size 1. + + Outputs: + Tensor, element type and shape is same as data. + """ + index_dtype = F.dtype(index) + index_shape = F.shape(index) + is_bool = mult_util.is_same_type(index_dtype, mstype.bool_) + if not is_bool: + return mult_util.error_msg( + "The tensor index should be a bool type tensor. {} type tensor is not supported yet.", (index_dtype,)) + data_shape = F.shape(data) + if index_shape != data_shape: + return mult_util.error_msg( + "The tensor(shape={}) and tensor index(shape={}) should be the same shape.", (data_shape, index_shape)) + size = F.size(value_tensor) + if size != 1: + return mult_util.error_msg( + "When assign value is a tensor, its size should be 1, but current size is {}.", (size,)) + dtype = F.dtype(data) + u_cast = F.cast(value_tensor, dtype) + one_data = F.ones_like(data) + u = F.tensor_mul(one_data, u_cast) + return F.select(index, u, data) + + +@setitem.register("Tensor", "Tensor", "Number") +def _tensor_setitem_by_tensor_v2(data, index, value): + """ + Tensor assignment. + + Note: + Syntax support: A[B] = u and A[A>n] = u. + Restraint condition: 1) A is a Tensor, and B is a bool Tensor. + 2) A.shape == B.shape + 3) u is a scalar + 4) n is a number + + Inputs: + data (Tensor): Assigned tensor. + index (Tensor): Tensor of bool type. + value_tensor (Number): Assignment value. + + Outputs: + Tensor, element type and shape is same as data. + """ + index_dtype = F.dtype(index) + index_shape = F.shape(index) + is_bool = mult_util.is_same_type(index_dtype, mstype.bool_) + if not is_bool: + return mult_util.error_msg( + "The tensor index should be a bool type tensor. {} type tensor is not supported yet.", (index_dtype,)) + shape = F.shape(data) + if index_shape != shape: + return mult_util.error_msg( + "The tensor(shape={}) and tensor index(shape={}) should be the same shape.", (shape, index_shape)) + dtype = F.dtype(data) + u = F.fill(dtype, shape, value) + return F.select(index, u, data) diff --git a/mindspore/ops/functional.py b/mindspore/ops/functional.py index 611c569553..0ed750beb1 100644 --- a/mindspore/ops/functional.py +++ b/mindspore/ops/functional.py @@ -31,6 +31,9 @@ dtype = P.DType() issubclass_ = P.IsSubClass() isinstance_ = P.IsInstance() fill = P.Fill() +select = P.Select() +size = P.Size() +ones_like = P.OnesLike() shape = P.Shape() rank = P.Rank() reshape = P.Reshape() @@ -68,7 +71,9 @@ scalar_cast = P.ScalarCast() tuple_setitem = Primitive('tuple_setitem') tuple_getitem = Primitive('tuple_getitem') list_getitem = Primitive('list_getitem') +list_setitem = Primitive('list_setitem') dict_getitem = Primitive('dict_getitem') +dict_setitem = Primitive('dict_setitem') tuple_div = Primitive("tuple_div") tuple_len = Primitive("tuple_len") tuple_reversed = Primitive("tuple_reversed") diff --git a/tests/ut/python/ops/test_tensor_slice.py b/tests/ut/python/ops/test_tensor_slice.py index 6200d4e163..a88a2d8322 100644 --- a/tests/ut/python/ops/test_tensor_slice.py +++ b/tests/ut/python/ops/test_tensor_slice.py @@ -18,6 +18,7 @@ import pytest from mindspore import Tensor from mindspore import context +from mindspore import dtype as mstype from mindspore.nn import Cell from ....mindspore_test_framework.mindspore_test import mindspore_test @@ -79,7 +80,102 @@ class NetWorkReduceToScalar(Cell): return ret +class TensorAssignWithBoolTensorIndex(Cell): + def __init__(self): + super(TensorAssignWithBoolTensorIndex, self).__init__() + self.t = Tensor(np.arange(6).reshape([2,3]), dtype = mstype.float64) + + def construct(self, a, b, c, u_tensor, _scalar): + a[c] = u_scalar + a[b] = u_tensor + z = a + self.t + return z + + +class TensorAssignWithBoolTensorIndexError(Cell): + def __init__(self): + super(TensorAssignWithBoolTensorIndexError, self).__init__() + + def construct(self, a, b, c, u_tensor): + a[b][c] = u_tensor + return a + + +class TensorAssignWithBoolTensorIndex2(Cell): + def __init__(self): + super(TensorAssignWithBoolTensorIndex2, self).__init__() + self.t = Tensor(np.arange(6).reshape([2,3]), dtype = mstype.float64) + + def construct(self, a, u_tensor, _scalar): + a[a>8] = u_tensor + a[a>=6] = u_scalar + a[a<3] = u_scalar + a[a<=5] = u_tensor + a[a==5] = u_scalar + z = a + self.t + return z + + +class TensorAssignWithBoolTensorIndex2Error(Cell): + def __init__(self): + super(TensorAssignWithBoolTensorIndex2Error, self).__init__() + + def construct(self, a, u_tensor): + a[a>8][a>5] = u_tensor + return a + + +a = np.random.uniform(1,10,[2,3]) +b = a > 5 +c = a < 3 +Ta = Tensor(a) +Tb = Tensor(b) +Tc = Tensor(c) +Td = Tensor([True, True]) +u_tensor = Tensor([1]) +u_tensor_error = Tensor([1, 2]) +u_scalar = 5 + + +def test_tensor_assign_bool_index(): + net1 = TensorAssignWithBoolTensorIndex() + net2 = TensorAssignWithBoolTensorIndex2() + + net1(Ta, Tb, Tc, u_tensor, u_scalar) + with pytest.raises(ValueError): + net1(Ta, Td, Tc, u_tensor, u_scalar) + with pytest.raises(ValueError): + net1(Ta, u_tensor, Tc, u_tensor, u_scalar) + with pytest.raises(ValueError): + net1(Ta, Tb, Td, u_tensor, u_scalar) + with pytest.raises(ValueError): + net1(Ta, Tb, Ta, u_tensor, u_scalar) + with pytest.raises(ValueError): + net1(Ta, Tb, Tc, u_tensor_error, u_scalar) + #net1(Ta, u_tensor, Tc, u_tensor_error, u_scalar) + with pytest.raises(ValueError): + net2(Ta, u_tensor_error, u_scalar) + net3 = TensorAssignWithBoolTensorIndexError() + with pytest.raises(AttributeError): + net3(Ta, Tb, Tc, u_tensor) + with pytest.raises(AttributeError): + net3(Ta, Tb, Tc, u_scalar) + net4 = TensorAssignWithBoolTensorIndex2Error() + with pytest.raises(AttributeError): + net4(Ta, u_tensor) + with pytest.raises(AttributeError): + net4(Ta, u_scalar) + + test_cases = [ + ('TensorAssignWithBoolTensorIndex', { + 'block': TensorAssignWithBoolTensorIndex(), + 'desc_inputs': [Ta, Tb, Tc, u_tensor, u_scalar], + }), + ('TensorAssignWithBoolTensorIndex2', { + 'block': TensorAssignWithBoolTensorIndex2(), + 'desc_inputs': [Ta, u_tensor, u_scalar], + }), ('SlicePositive', { 'block': NetWorkSlicePositive(), 'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))], From 0d208e00bdf2835fcf9e99455a6db9fbafce3d2c Mon Sep 17 00:00:00 2001 From: Ziyan Date: Wed, 1 Apr 2020 21:03:56 +0800 Subject: [PATCH 009/142] Model ALLTOALL as a single operator in cost model; scale the ALLTOALL, ALLGATHER, and REDUCESCATTER with different factors; change the BETA and GAMMA value in cost model. --- .../parallel/auto_parallel/graph_costmodel.h | 2 +- .../redistribution_operator_infer.cc | 33 +++++++++++------ .../redistribution_operator_infer.h | 4 ++- .../tensor_layout/tensor_redistribution.cc | 36 +++++++++++-------- .../tensor_layout/tensor_redistribution.h | 4 ++- .../parallel/test_auto_parallel_resnet.py | 4 +-- .../parallel/test_auto_parallel_two_matmul.py | 2 +- 7 files changed, 55 insertions(+), 30 deletions(-) diff --git a/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h b/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h index b6591c0741..e701a377b9 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h +++ b/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h @@ -34,7 +34,7 @@ namespace parallel { #define OPERATOR_TO_OPERATOR_CONNECTOR "-" #define DEFAULT_DEVICE_MEMORY_CAPACITY (1024.0 * 1024.0 * 1024.0 * 16.0) #define DEFAULT_COST_MODEL_ALPHA 1.0 -#define DEFAULT_COST_MODEL_BETA 260.0 +#define DEFAULT_COST_MODEL_BETA 400.0 #define DEFAULT_COST_MODEL_GAMMA 0.001 #define DEFAULT_COST_MODEL_SIMPLIFY_CALCULATION true #define DEFAULT_COST_MODEL_COMMUNI_THRESHOLD 2048.0 diff --git a/mindspore/ccsrc/parallel/tensor_layout/redistribution_operator_infer.cc b/mindspore/ccsrc/parallel/tensor_layout/redistribution_operator_infer.cc index b4ec6a016f..ac768c19f9 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/redistribution_operator_infer.cc +++ b/mindspore/ccsrc/parallel/tensor_layout/redistribution_operator_infer.cc @@ -23,7 +23,7 @@ namespace mindspore { namespace parallel { Status RedistributionOperatorInfer::Init(const TensorLayout& tensor_layout, const Map& out_tensor_map, - RankList dev_list) { + RankList dev_list, bool is_cost_model) { in_tensor_map_ = tensor_layout.tensor_map(); dev_mat_ = tensor_layout.device_arrangement(); @@ -51,6 +51,8 @@ Status RedistributionOperatorInfer::Init(const TensorLayout& tensor_layout, cons for (int32_t item : map) { map_[key++] = item; } + + is_cost_model_ = is_cost_model; return Status::SUCCESS; } @@ -130,15 +132,26 @@ Status RedistributionOperatorInfer::InferPermuteByAxis() { std::any_of(map_.begin(), map_.end(), [out_dim](const RedistributionOperatorMap::value_type& a) { return a.second == out_dim; })) { int32_t cat_dim = in_tensor_map_.GetIndexByValue(out_dim); - Args args_allconcat = {cat_dim, out_dim, dev_mat_.GetDimByReverseIdx(IntToUint(out_dim))}; - Args args_allsplit = {dev_mat_.GetDimByReverseIdx(IntToUint(out_dim)), UintToInt(index), out_dim}; - if (InsertOperator(CONCAT_BY_AXIS, args_allconcat) == Status::FAILED) { - MS_LOG(ERROR) << "Insert ConcatByAxis Error!"; - return Status::FAILED; - } - if (InsertOperator(SPLIT_BY_AXIS, args_allsplit) == Status::FAILED) { - MS_LOG(ERROR) << "Insert SplitByAxis Error!"; - return Status::FAILED; + int32_t dev_num = dev_mat_.GetDimByReverseIdx(IntToUint(out_dim)); + if (is_cost_model_) { + int32_t dev_dim = in_tensor_map_.GetDimByIdx(IntToUint(cat_dim)); + Args args_alltoall = {dev_mat_.GetDimByReverseIdx(IntToUint(dev_dim)), UintToInt(index), cat_dim, dev_dim, + dev_num}; + if (InsertOperator(PERMUTE_BY_AXIS, args_alltoall) == Status::FAILED) { + MS_LOG(ERROR) << "Insert PermuteByAxis Error!"; + return Status::FAILED; + } + } else { + Args args_allconcat = {cat_dim, out_dim, dev_num}; + Args args_allsplit = {dev_num, UintToInt(index), out_dim}; + if (InsertOperator(CONCAT_BY_AXIS, args_allconcat) == Status::FAILED) { + MS_LOG(ERROR) << "Insert ConcatByAxis Error!"; + return Status::FAILED; + } + if (InsertOperator(SPLIT_BY_AXIS, args_allsplit) == Status::FAILED) { + MS_LOG(ERROR) << "Insert SplitByAxis Error!"; + return Status::FAILED; + } } (void)map_.erase(iter++); map_[IntToSize(cat_dim)] = NONE; diff --git a/mindspore/ccsrc/parallel/tensor_layout/redistribution_operator_infer.h b/mindspore/ccsrc/parallel/tensor_layout/redistribution_operator_infer.h index b4ec0c4633..8fd953572a 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/redistribution_operator_infer.h +++ b/mindspore/ccsrc/parallel/tensor_layout/redistribution_operator_infer.h @@ -40,7 +40,8 @@ class RedistributionOperatorInfer { public: const int NONE = -1; explicit RedistributionOperatorInfer(bool construct_op_flag = true) : construct_op_flag_(construct_op_flag) {} - Status Init(const TensorLayout& tensor_layout, const Map& out_tensor_map, RankList dev_list); + Status Init(const TensorLayout& tensor_layout, const Map& out_tensor_map, RankList dev_list, + bool is_cost_model = false); ~RedistributionOperatorInfer() = default; OperatorList operator_list() const { return operator_list_; } OperatorVector operator_vector() const { return operator_vector_; } @@ -67,6 +68,7 @@ class RedistributionOperatorInfer { ConstructOperator constructor_; RankList dev_list_; bool construct_op_flag_; + bool is_cost_model_; }; } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.cc b/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.cc index d8eef7e7a5..460cd9d1bd 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.cc +++ b/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.cc @@ -40,7 +40,7 @@ Status TensorRedistribution::Init(const TensorLayout& from, const TensorLayout& return Status::SUCCESS; } -RedistributionOpListPtr TensorRedistribution::InferTensorRedistributionOperatorList() { +RedistributionOpListPtr TensorRedistribution::InferTensorRedistributionOperatorList(bool is_cost_model) { // Step 1: Match device arrangement between from_ and to_ RedistributionLayoutTransfer layout_transfer; Status status = layout_transfer.Init(from_, to_); @@ -62,7 +62,7 @@ RedistributionOpListPtr TensorRedistribution::InferTensorRedistributionOperatorL MS_LOG(DEBUG) << "reshape to_ " << to_.ToString(); // Step 2: Infer redistribution and insert operators RedistributionOperatorInfer operator_infer(construct_op_flag_); - if (operator_infer.Init(from_layout, to_layout.tensor_map(), dev_list_) == Status::FAILED) { + if (operator_infer.Init(from_layout, to_layout.tensor_map(), dev_list_, is_cost_model) == Status::FAILED) { MS_LOG(ERROR) << "Init operatorInfer failed!"; return nullptr; } @@ -138,7 +138,7 @@ Status TensorRedistribution::InferReshape(const TensorLayout& from_layout, const } Status TensorRedistribution::ComputeCost() { - RedistributionOpListPtr redistribution_oplist_ptr = InferTensorRedistributionOperatorList(); + RedistributionOpListPtr redistribution_oplist_ptr = InferTensorRedistributionOperatorList(true); if (redistribution_oplist_ptr == nullptr) { MS_LOG(ERROR) << "Failure: InferTensorRedistribution failed"; return Status::FAILED; @@ -151,14 +151,22 @@ Status TensorRedistribution::ComputeCost() { std::accumulate(slice_shape.begin(), slice_shape.end(), static_cast(1.0), std::multiplies()); std::string str = op.first; if (str == PERMUTE_BY_AXIS) { - // The shape does not change after PermuteByAxis operation. - // communication cost = all_to_all + all_to_all = 2 * slice_shape - // computation cost = slice_shape - forward_comm_cost_ += prod; - backward_comm_cost_ += prod; - comm_cost_ += 2.0 * prod; - computation_cost_ += prod; - memory_cost_ += prod; + // Since AlltoAll is a virtual operator, the expanded operators are used here to compute cost. + // communication cost = all_gather + reduce_scatter = before_slice_shape + after_slice_shape + forward_comm_cost_ += prod * ALLTOALL_SCALE_FACTOR; + backward_comm_cost_ += prod * ALLTOALL_SCALE_FACTOR; + comm_cost_ += 2.0 * prod * ALLTOALL_SCALE_FACTOR; + int32_t concat_dim = op.second[2]; + if (concat_dim == 0) { + // memory cost = all_gather + computation_cost_ += prod; + memory_cost_ += prod; + } else { + // memory cost = all_gather + split + concat + int32_t dev_num = op.second[4]; + computation_cost_ += (prod + prod * dev_num + prod * dev_num); + memory_cost_ += (prod * dev_num + prod * dev_num + prod); + } } else if (str == CONCAT_BY_AXIS) { // communication cost = all_gather + reduce_scatter = before_slice_shape + after_slice_shape // computation cost = before_slice_shape @@ -168,9 +176,9 @@ Status TensorRedistribution::ComputeCost() { } double dev_num = op.second[2]; // here, communication cost = all_gather + reduce_scatter - forward_comm_cost_ += prod * dev_num; - backward_comm_cost_ += prod; - comm_cost_ += prod * (dev_num + 1.0); + forward_comm_cost_ += prod * dev_num * ALLGATHER_REDUCESCATTER_SCALE_FACTOR; + backward_comm_cost_ += prod * ALLGATHER_REDUCESCATTER_SCALE_FACTOR; + comm_cost_ += prod * (dev_num + 1.0) * ALLGATHER_REDUCESCATTER_SCALE_FACTOR; int32_t concat_dim = op.second[0]; if (concat_dim == 0) { // computation cost = all_gather diff --git a/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.h b/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.h index ebaccadf53..71d4a02701 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.h +++ b/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.h @@ -33,6 +33,8 @@ namespace mindspore { namespace parallel { +constexpr double ALLTOALL_SCALE_FACTOR = 2.0; +constexpr double ALLGATHER_REDUCESCATTER_SCALE_FACTOR = 0.5; class TensorRedistribution { public: explicit TensorRedistribution(bool construct_op_flag = true, bool keep_reshape = false) @@ -46,7 +48,7 @@ class TensorRedistribution { keep_reshape_(keep_reshape) {} Status Init(const TensorLayout& from, const TensorLayout& to, const RankList& dev_list); ~TensorRedistribution() = default; - RedistributionOpListPtr InferTensorRedistributionOperatorList(); + RedistributionOpListPtr InferTensorRedistributionOperatorList(bool is_cost_model = false); OperatorList operator_list() const { return operator_list_; } bool reshape_flag() const { return reshape_flag_; } Status ComputeCost(); diff --git a/tests/ut/python/parallel/test_auto_parallel_resnet.py b/tests/ut/python/parallel/test_auto_parallel_resnet.py index 9b4e1fda23..ae7bd952d9 100644 --- a/tests/ut/python/parallel/test_auto_parallel_resnet.py +++ b/tests/ut/python/parallel/test_auto_parallel_resnet.py @@ -304,7 +304,7 @@ def train_32k_8p(epoch_size=3, batch_size=32, num_classes=32768): def test_train_32k_8p_fusion1(epoch_size=3, batch_size=32, num_classes=32768): #1048576 #131072 #32768 #8192 - cost_model_context.set_cost_model_context(costmodel_gamma=0.001, costmodel_beta=260.0) + cost_model_context.set_cost_model_context(costmodel_gamma=0.001, costmodel_beta=400.0) cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_algorithm=1) cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_times=2) cost_model_context.set_cost_model_context(costmodel_allreduce_fusion_tail_percent=0.5) @@ -651,7 +651,7 @@ def test_train_32k_8p_fusion2(epoch_size=3, batch_size=32, num_classes=32768): # def test_train_64k_8p(epoch_size=3, batch_size=32, num_classes=65536): #1048576 #131072 #32768 #8192 dev_num = 8 context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=dev_num) - cost_model_context.set_cost_model_context(costmodel_gamma=0.001, costmodel_beta=260.0) + cost_model_context.set_cost_model_context(costmodel_gamma=0.001, costmodel_beta=400.0) set_algo_parameters(elementwise_op_strategy_follow=True) resset_op_id() np.random.seed(6) diff --git a/tests/ut/python/parallel/test_auto_parallel_two_matmul.py b/tests/ut/python/parallel/test_auto_parallel_two_matmul.py index db6190ab89..848c8025cb 100644 --- a/tests/ut/python/parallel/test_auto_parallel_two_matmul.py +++ b/tests/ut/python/parallel/test_auto_parallel_two_matmul.py @@ -86,7 +86,7 @@ def test_two_matmul(): costmodel_alpha = cost_model_context.get_cost_model_context("costmodel_alpha") assert costmodel_alpha == 1.0 costmodel_beta = cost_model_context.get_cost_model_context("costmodel_beta") - assert costmodel_beta == 260.0 + assert costmodel_beta == 400.0 costmodel_gamma = cost_model_context.get_cost_model_context("costmodel_gamma") assert costmodel_gamma == 0.001 costmodel_communi_threshold = cost_model_context.get_cost_model_context("costmodel_communi_threshold") From 8c424785fb0a937ce7c31fe080a7768556c6f7d8 Mon Sep 17 00:00:00 2001 From: liubuyu Date: Sat, 18 Apr 2020 14:27:02 +0800 Subject: [PATCH 010/142] remove reshape pair --- .../ascend/ascend_backend_optimization.cc | 2 + .../ascend/ir_fusion/remove_reshape_pair.cc | 55 +++++++++++++++++++ .../ascend/ir_fusion/remove_reshape_pair.h | 43 +++++++++++++++ 3 files changed, 100 insertions(+) create mode 100644 mindspore/ccsrc/pre_activate/ascend/ir_fusion/remove_reshape_pair.cc create mode 100644 mindspore/ccsrc/pre_activate/ascend/ir_fusion/remove_reshape_pair.h diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc index 7a35627e25..6c245d7548 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc @@ -45,6 +45,7 @@ #include "pre_activate/ascend/ir_fusion/mul_add_fusion.h" #include "pre_activate/ascend/ir_fusion/mul_addn_fusion.h" #include "pre_activate/ascend/ir_fusion/matmul_biasadd_fusion.h" +#include "pre_activate/ascend/ir_fusion/remove_reshape_pair.h" #include "pre_activate/ascend/format_type/insert_trans_op.h" #include "pre_activate/pass/getitem_tuple.h" #include "pre_activate/pass/optimize_dependence.h" @@ -113,6 +114,7 @@ void AscendDataLayout(const std::shared_ptr &kernel_graph) data_layout_pm->AddPass(std::make_shared()); data_layout_pm->AddPass(std::make_shared()); data_layout_pm->AddPass(std::make_shared()); + data_layout_pm->AddPass(std::make_shared()); data_layout_pm->AddPass(std::make_shared()); data_layout_pm->AddPass(std::make_shared()); data_layout_pm->AddPass(std::make_shared()); diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/remove_reshape_pair.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/remove_reshape_pair.cc new file mode 100644 index 0000000000..5e265f2cf1 --- /dev/null +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/remove_reshape_pair.cc @@ -0,0 +1,55 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "pre_activate/ascend/ir_fusion/remove_reshape_pair.h" +#include +#include "session/anf_runtime_algorithm.h" +#include "utils/utils.h" +#include "operator/ops.h" + +namespace mindspore { +namespace opt { +const BaseRef RemoveReshapePair::DefinePattern() const { + const auto prim_reshape = std::make_shared(prim::kPrimReshape->name()); + VectorRef reshape({prim_reshape, input_varptr_}); + + return VectorRef({prim::kPrimReshape, reshape}); +} + +const AnfNodePtr RemoveReshapePair::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &equiv) const { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(equiv); + auto manager = func_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + auto reshape_op_1 = CheckAnfNodeIfCNodeAndInputSize(node, kBackendReshapeInputNum); + MS_EXCEPTION_IF_NULL(reshape_op_1); + // If reshape operator used by more than one other operators, reshape operator cant not be deleted directly + auto users = manager->node_users()[reshape_op_1]; + if (users.size() > 1) { + return nullptr; + } + auto reshape_op_2 = CheckAnfNodeIfCNodeAndInputSize(reshape_op_1->input(1), kBackendReshapeInputNum); + MS_EXCEPTION_IF_NULL(reshape_op_2); + users = manager->node_users()[reshape_op_2]; + if (users.size() > 1) { + return nullptr; + } + auto input_node = reshape_op_2->input(1); + return input_node; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/remove_reshape_pair.h b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/remove_reshape_pair.h new file mode 100644 index 0000000000..a284f4eaa9 --- /dev/null +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/remove_reshape_pair.h @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_REMOVE_RESHAPE_PAIR_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_REMOVE_RESHAPE_PAIR_H_ + +#include +#include +#include "ir/anf.h" +#include "pre_activate/common/pattern_engine.h" +#include "pre_activate/common/helper.h" +#include "pre_activate/common/optimizer.h" + +namespace mindspore { +namespace opt { +class RemoveReshapePair : public PatternProcessPass { + public: + explicit RemoveReshapePair(bool multigraph = true) : PatternProcessPass("remove_reshape_pair", multigraph) { + input_varptr_ = std::make_shared(); + } + ~RemoveReshapePair() override = default; + const BaseRef DefinePattern() const override; + const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + + private: + VarPtr input_varptr_; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_IR_FUSION_REMOVE_RESHAPE_PAIR_H_ From 2795e492ffe14fa02939a8c3c315af6b9d3dbfaf Mon Sep 17 00:00:00 2001 From: yanghaitao Date: Thu, 16 Apr 2020 15:03:41 +0800 Subject: [PATCH 011/142] TextFileDataset --- mindspore/ccsrc/dataset/api/de_pipeline.cc | 37 +- mindspore/ccsrc/dataset/api/de_pipeline.h | 5 +- .../ccsrc/dataset/api/python_bindings.cc | 15 +- .../engine/datasetops/source/CMakeLists.txt | 1 + .../engine/datasetops/source/text_file_op.cc | 459 ++++++++++++++++++ .../engine/datasetops/source/text_file_op.h | 263 ++++++++++ mindspore/dataset/__init__.py | 6 +- mindspore/dataset/engine/__init__.py | 4 +- mindspore/dataset/engine/datasets.py | 130 ++++- mindspore/dataset/engine/iterators.py | 10 +- mindspore/dataset/engine/validators.py | 22 + mindspore/dataset/transforms/nlp/__init__.py | 20 + mindspore/dataset/transforms/nlp/utils.py | 35 ++ tests/ut/cpp/dataset/CMakeLists.txt | 2 +- tests/ut/cpp/dataset/text_file_op_test.cc | 112 +++++ .../ut/data/dataset/testTextFileDataset/1.txt | 3 + .../ut/data/dataset/testTextFileDataset/2.txt | 2 + .../dataset/test_datasets_textfileop.py | 87 ++++ 18 files changed, 1175 insertions(+), 38 deletions(-) create mode 100644 mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.cc create mode 100644 mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.h create mode 100644 mindspore/dataset/transforms/nlp/__init__.py create mode 100644 mindspore/dataset/transforms/nlp/utils.py create mode 100644 tests/ut/cpp/dataset/text_file_op_test.cc create mode 100644 tests/ut/data/dataset/testTextFileDataset/1.txt create mode 100644 tests/ut/data/dataset/testTextFileDataset/2.txt create mode 100644 tests/ut/python/dataset/test_datasets_textfileop.py diff --git a/mindspore/ccsrc/dataset/api/de_pipeline.cc b/mindspore/ccsrc/dataset/api/de_pipeline.cc index 5f61c86f06..f6440710b1 100644 --- a/mindspore/ccsrc/dataset/api/de_pipeline.cc +++ b/mindspore/ccsrc/dataset/api/de_pipeline.cc @@ -28,10 +28,10 @@ #include "dataset/engine/datasetops/source/manifest_op.h" #include "dataset/engine/datasetops/source/cifar_op.h" #include "dataset/engine/datasetops/source/celeba_op.h" +#include "dataset/engine/datasetops/source/text_file_op.h" #include "mindrecord/include/shard_category.h" #include "mindrecord/include/shard_sample.h" #include "mindrecord/include/shard_shuffle.h" - #include "dataset/util/random.h" #include "dataset/util/status.h" #include "utils/log_adapter.h" @@ -61,7 +61,8 @@ static std::unordered_map g_parse_op_func_ = {{kStorage, &D {kVoc, &DEPipeline::ParseVOCOp}, {kCifar10, &DEPipeline::ParseCifar10Op}, {kCifar100, &DEPipeline::ParseCifar100Op}, - {kCelebA, &DEPipeline::ParseCelebAOp}}; + {kCelebA, &DEPipeline::ParseCelebAOp}, + {kTextFile, &DEPipeline::ParseTextFileOp}}; DEPipeline::DEPipeline() : iterator_(nullptr) { try { @@ -985,5 +986,37 @@ Status DEPipeline::ParseCelebAOp(const py::dict &args, std::shared_ptr *ptr) { + // Required arguments + std::shared_ptr builder = std::make_shared(); + if (!args["dataset_files"].is_none()) { + (void)builder->SetTextFilesList(ToStringVector(args["dataset_files"])); + } else { + RETURN_STATUS_UNEXPECTED("Error: dataset_files is missing"); + } + // Optional arguments + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "num_parallel_workers") { + (void)builder->SetNumWorkers(ToInt(value)); + } else if (key == "shuffle_files") { + (void)builder->SetShuffleFiles(ToBool(value)); + } else if (key == "num_samples") { + (void)builder->SetNumSamples(ToInt(value)); + } else if (key == "num_shards") { + (void)builder->SetNumDevices(ToInt(value)); + } else if (key == "shard_id") { + (void)builder->SetDeviceId(ToInt(value)); + } + } + } + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *ptr = op; + return Status::OK(); +} } // namespace dataset } // namespace mindspore diff --git a/mindspore/ccsrc/dataset/api/de_pipeline.h b/mindspore/ccsrc/dataset/api/de_pipeline.h index 6ff7bb091c..eadde2c191 100644 --- a/mindspore/ccsrc/dataset/api/de_pipeline.h +++ b/mindspore/ccsrc/dataset/api/de_pipeline.h @@ -58,7 +58,8 @@ enum OpName { kVoc, kCifar10, kCifar100, - kCelebA + kCelebA, + kTextFile }; // The C++ binder class that we expose to the python script. @@ -148,6 +149,8 @@ class DEPipeline { Status ParseCelebAOp(const py::dict &args, std::shared_ptr *ptr); + Status ParseTextFileOp(const py::dict &args, std::shared_ptr *ptr); + private: // Execution tree that links the dataset operators. std::shared_ptr tree_; diff --git a/mindspore/ccsrc/dataset/api/python_bindings.cc b/mindspore/ccsrc/dataset/api/python_bindings.cc index 076f2ecc36..5399e7e425 100644 --- a/mindspore/ccsrc/dataset/api/python_bindings.cc +++ b/mindspore/ccsrc/dataset/api/python_bindings.cc @@ -55,6 +55,7 @@ #include "dataset/engine/datasetops/source/sampler/weighted_random_sampler.h" #include "dataset/engine/datasetops/source/tf_reader_op.h" #include "dataset/engine/jagged_connector.h" +#include "dataset/engine/datasetops/source/text_file_op.h" #include "dataset/kernels/data/to_float16_op.h" #include "dataset/util/random.h" #include "mindrecord/include/shard_operator.h" @@ -176,6 +177,17 @@ void bindDatasetOps(py::module *m) { THROW_IF_ERROR(MnistOp::CountTotalRows(dir, numSamples, &count)); return count; }); + + (void)py::class_>(*m, "TextFileOp") + .def_static("get_num_rows", [](const py::list &files) { + int64_t count = 0; + std::vector filenames; + for (auto file : files) { + !file.is_none() ? filenames.push_back(py::str(file)) : (void)filenames.emplace_back(""); + } + THROW_IF_ERROR(TextFileOp::CountAllFileRows(filenames, &count)); + return count; + }); } void bindTensor(py::module *m) { (void)py::class_(*m, "GlobalContext") @@ -463,7 +475,8 @@ PYBIND11_MODULE(_c_dataengine, m) { .value("VOC", OpName::kVoc) .value("CIFAR10", OpName::kCifar10) .value("CIFAR100", OpName::kCifar100) - .value("CELEBA", OpName::kCelebA); + .value("CELEBA", OpName::kCelebA) + .value("TEXTFILE", OpName::kTextFile); (void)py::enum_(m, "InterpolationMode", py::arithmetic()) .value("DE_INTER_LINEAR", InterpolationMode::kLinear) diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/CMakeLists.txt b/mindspore/ccsrc/dataset/engine/datasetops/source/CMakeLists.txt index a7c0dfd725..8801205f6c 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/CMakeLists.txt +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/CMakeLists.txt @@ -18,6 +18,7 @@ add_library(engine-datasetops-source OBJECT manifest_op.cc cifar_op.cc celeba_op.cc + text_file_op.cc ) add_dependencies(engine-datasetops-source mindspore::protobuf) diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.cc new file mode 100644 index 0000000000..2b62616366 --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.cc @@ -0,0 +1,459 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include + +#include "common/utils.h" +#include "dataset/engine/datasetops/source/text_file_op.h" +#include "dataset/core/config_manager.h" +#include "dataset/util/task_manager.h" +#include "dataset/util/wait_post.h" +#include "dataset/util/random.h" +#include "dataset/engine/datasetops/source/io_block.h" +#include "dataset/engine/execution_tree.h" + +namespace mindspore { +namespace dataset { +TextFileOp::Builder::Builder() + : builder_device_id_(0), builder_num_devices_(1), builder_num_samples_(0), builder_shuffle_files_(false) { + std::shared_ptr config_manager = GlobalContext::config_manager(); + builder_num_workers_ = config_manager->num_parallel_workers(); + builder_op_connector_size_ = config_manager->op_connector_size(); + builder_rows_per_buffer_ = config_manager->rows_per_buffer(); + builder_worker_connector_size_ = config_manager->worker_connector_size(); +} + +Status TextFileOp::Builder::ValidateInputs() const { + std::string err_msg; + err_msg += builder_num_workers_ <= 0 ? "Number of parallel workers should be greate than 0\n" : ""; + err_msg += builder_device_id_ >= builder_num_devices_ || builder_num_devices_ < 1 ? "Wrong sharding configs\n" : ""; + return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); +} + +Status TextFileOp::Builder::Build(std::shared_ptr *op) { + RETURN_IF_NOT_OK(ValidateInputs()); + + // Throttle the number of workers if we have more workers than files! + if (static_cast(builder_num_workers_) > builder_text_files_list_.size()) { + builder_num_workers_ = builder_text_files_list_.size(); + MS_LOG(WARNING) << "TextFileOp operator parallelism reduced to " << builder_num_workers_ << " workers."; + } + + builder_schema_ = std::make_unique(); + RETURN_IF_NOT_OK( + builder_schema_->AddColumn(ColDescriptor("text", DataType(DataType::DE_UINT8), TensorImpl::kFlexible, 1))); + + std::shared_ptr text_file_op = std::make_shared( + builder_num_workers_, builder_rows_per_buffer_, builder_num_samples_, builder_worker_connector_size_, + std::move(builder_schema_), builder_text_files_list_, builder_op_connector_size_, builder_shuffle_files_, + builder_num_devices_, builder_device_id_); + RETURN_IF_NOT_OK(text_file_op->Init()); + *op = std::move(text_file_op); + + return Status::OK(); +} + +TextFileOp::TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size, + std::unique_ptr schema, std::vector text_files_list, + int32_t op_connector_size, bool shuffle_files, int32_t num_device, int32_t device_id) + : ParallelOp(num_workers, op_connector_size), + device_id_(device_id), + num_devices_(num_device), + rows_per_buffer_(rows_per_buffer), + num_samples_(num_samples), + text_files_list_(std::move(text_files_list)), + shuffle_files_(shuffle_files), + data_schema_(std::move(schema)), + all_num_rows_(0), + num_rows_per_shard_(0), + filename_index_(std::make_unique()), + finished_reading_dataset_(false), + load_io_block_queue_(true), + load_jagged_connector_(true) { + worker_connector_size_ = worker_connector_size; +} + +Status TextFileOp::Init() { + RETURN_IF_NOT_OK(filename_index_->insert(text_files_list_)); + + int32_t safe_queue_size = static_cast(std::ceil(text_files_list_.size() / num_workers_) + 1); + io_block_queues_.Init(num_workers_, safe_queue_size); + + for (int32_t i = 0; i < data_schema_->NumColumns(); ++i) { + col_name_map_[data_schema_->column(i).name()] = i; + } + + RETURN_IF_NOT_OK(ParallelOp::CreateWorkerConnector(worker_connector_size_)); + + jagged_buffer_connector_ = std::make_unique(num_workers_, 1, worker_connector_size_); + return Status::OK(); +} + +Status TextFileOp::Reset() { + load_jagged_connector_ = true; + load_io_block_queue_ = true; + + RETURN_IF_NOT_OK(ParallelOp::Reset()); + NotifyToFillIOBlockQueue(); + return Status::OK(); +} + +Status TextFileOp::LoadTensor(const std::string &line, std::unique_ptr *tensor_table, int64_t row) { + TensorRow tRow(1, nullptr); + (*tensor_table)->push_back(std::move(tRow)); + + std::shared_ptr tensor; + RETURN_IF_NOT_OK( + Tensor::CreateTensor(&tensor, data_schema_->column(0).tensorImpl(), + TensorShape(std::vector(1, line.size())), data_schema_->column(0).type(), + const_cast(reinterpret_cast(common::SafeCStr(line))))); + (**tensor_table)[row][0] = std::move(tensor); + return Status::OK(); +} + +Status TextFileOp::LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset, + const int32_t worker_id) { + std::ifstream handle(file); + if (!handle.is_open()) { + RETURN_STATUS_UNEXPECTED("Failed to open file " + file); + } + + int64_t rows_each_buffer = 0; + int64_t rows_total = 0; + std::string line; + std::unique_ptr cur_buffer = std::make_unique(0, DataBuffer::BufferFlags::kDeBFlagNone); + cur_buffer->set_column_name_map(col_name_map_); + std::unique_ptr tensor_table = std::make_unique(); + + while (getline(handle, line)) { + // If read to the end offset of this file, break. + if (rows_total >= end_offset) { + break; + } + // Skip line before start offset. + if (rows_total < start_offset) { + rows_total++; + continue; + } + + RETURN_IF_NOT_OK(LoadTensor(line, &tensor_table, rows_each_buffer)); + rows_each_buffer++; + rows_total++; + if (rows_each_buffer == rows_per_buffer_) { + cur_buffer->set_tensor_table(std::move(tensor_table)); + RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(cur_buffer))); + + cur_buffer = std::make_unique(0, DataBuffer::BufferFlags::kDeBFlagNone); + cur_buffer->set_column_name_map(col_name_map_); + tensor_table = std::make_unique(); + rows_each_buffer = 0; + } + } + + if (rows_each_buffer > 0) { + cur_buffer->set_tensor_table(std::move(tensor_table)); + RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(cur_buffer))); + } + + return Status::OK(); +} + +Status TextFileOp::WorkerEntry(int32_t worker_id) { + TaskManager::FindMe()->Post(); + + std::unique_ptr io_block; + RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block)); + while (!io_block->eof()) { + if (!io_block->eoe()) { + if (load_jagged_connector_) { + std::string filename; + RETURN_IF_NOT_OK(io_block->GetFilename(&filename, *filename_index_)); + int64_t start_offset = io_block->GetStartOffset(); + int64_t end_offset = io_block->GetEndOffset(); + RETURN_IF_NOT_OK(LoadFile(filename, start_offset, end_offset, worker_id)); + } + } else { + std::unique_ptr eoe_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); + RETURN_IF_NOT_OK(jagged_buffer_connector_->Add(worker_id, std::move(eoe_buffer))); + } + + RETURN_IF_NOT_OK(PopIoBlockQueue(worker_id, &io_block)); + } + return Status::OK(); +} + +// Pops an element from a queue in io_block_queues +Status TextFileOp::PopIoBlockQueue(int32_t index, std::unique_ptr *out_block) { + RETURN_IF_NOT_OK(io_block_queues_[index]->PopFront(out_block)); + + return Status::OK(); +} + +// Pushes an element to a queue in io_block_queues +Status TextFileOp::PushIoBlockQueue(int32_t index, std::unique_ptr &&io_block) { + RETURN_IF_NOT_OK(io_block_queues_[index]->Add(std::move(io_block))); + + return Status::OK(); +} + +// Pushes a control indicator onto the IOBlockQueue for each worker to consume. +// When the worker pops this control indicator, it will shut itself down gracefully. +Status TextFileOp::PostEndOfData() { + for (int i = 0; i < num_workers_; ++i) { + std::unique_ptr eof = std::make_unique(IOBlock::kDeIoBlockFlagEof); + RETURN_IF_NOT_OK(PushIoBlockQueue(i, std::move(eof))); + } + + return Status::OK(); +} + +// Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker +// pops this control indicator, it will wait until the next epoch starts and then resume execution. +Status TextFileOp::PostEndOfEpoch(int32_t queue_index) { + for (int i = 0; i < num_workers_; ++i) { + std::unique_ptr eoe = std::make_unique(IOBlock::kDeIoBlockFlagEoe); + RETURN_IF_NOT_OK(PushIoBlockQueue((queue_index + i) % num_workers_, std::move(eoe))); + } + + return Status::OK(); +} + +static void ShuffleKeys(std::vector *i_keys, uint32_t seed) { + std::mt19937 rng(seed); + std::shuffle(i_keys->begin(), i_keys->end(), rng); +} + +bool TextFileOp::NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset, + const int64_t &pre_count) { + *start_offset = 0; + *end_offset = 0; + bool push = false; + int64_t start_index = device_id_ * num_rows_per_shard_; + if (device_id_ + 1 < 0) { + MS_LOG(ERROR) << "Device id is invalid"; + return false; + } + + int64_t end_index = (static_cast(device_id_) + 1) * num_rows_per_shard_; + if (pre_count <= start_index && pre_count + filename_numrows_[file_name] > start_index) { + *start_offset = start_index - pre_count; + push = true; + if (pre_count < end_index && pre_count + filename_numrows_[file_name] >= end_index) { + *end_offset = end_index - pre_count; + } else { + *end_offset = filename_numrows_[file_name]; + } + } + + if (pre_count >= start_index && pre_count < end_index) { + *start_offset = 0; + push = true; + if (pre_count + filename_numrows_[file_name] >= end_index) { + *end_offset = end_index - pre_count; + } else { + *end_offset = filename_numrows_[file_name]; + } + } + + return push; +} + +Status TextFileOp::FillIOBlockQueue(const std::vector &i_keys) { + int32_t queue_index = 0; + int64_t pre_count = 0; + int64_t start_offset = 0; + int64_t end_offset = 0; + bool finish = false; + while (!finish) { + std::vector> file_index; + if (!i_keys.empty()) { + for (auto it = i_keys.begin(); it != i_keys.end(); ++it) { + { + if (!load_io_block_queue_) { + break; + } + } + auto file_it = filename_index_->Search(*it); + file_index.emplace_back(std::pair(file_it.value(), *it)); + } + } else { + for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { + { + if (!load_io_block_queue_) { + break; + } + } + file_index.emplace_back(std::pair(it.value(), it.key())); + } + } + for (auto file_info : file_index) { + if (NeedPushFileToBlockQueue(file_info.first, &start_offset, &end_offset, pre_count)) { + auto ioBlock = + std::make_unique(file_info.second, start_offset, end_offset, IOBlock::kDeIoBlockNone); + RETURN_IF_NOT_OK(PushIoBlockQueue(queue_index, std::move(ioBlock))); + queue_index = (queue_index + 1) % num_workers_; + } + + pre_count += filename_numrows_[file_info.first]; + } + + if (pre_count < (static_cast(device_id_) + 1) * num_rows_per_shard_) { + finish = false; + } else { + finish = true; + } + } + + RETURN_IF_NOT_OK(PostEndOfEpoch(queue_index)); + return Status::OK(); +} + +Status TextFileOp::WaitToFillIOBlockQueue() { + // must be called first if called by worker spanwed by taskgroup + TaskManager::FindMe()->Post(); + + std::vector i_keys; + if (shuffle_files_) { + for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { + i_keys.push_back(it.key()); + } + } + uint32_t seed = 0; + while (true) { + RETURN_IF_NOT_OK(io_block_queue_wait_post_.Wait()); + io_block_queue_wait_post_.Clear(); + + if (finished_reading_dataset_) { + break; + } + + if (shuffle_files_) { + ShuffleKeys(&i_keys, num_devices_ == 1 ? GetSeed() : ++seed); + } + RETURN_IF_NOT_OK(FillIOBlockQueue(i_keys)); + } + return Status::OK(); +} + +void TextFileOp::NotifyToFillIOBlockQueue() { io_block_queue_wait_post_.Set(); } + +Status TextFileOp::operator()() { + RETURN_IF_NOT_OK(CalculateNumRowsPerShard()); + + // launch one thread, responsible for filling IoBlockQueue + RETURN_IF_NOT_OK(tree_->LaunchWorkers(1, std::bind(&TextFileOp::WaitToFillIOBlockQueue, this))); + + // Read data from disk into buffers + RETURN_IF_NOT_OK( + tree_->LaunchWorkers(num_workers_, std::bind(&TextFileOp::WorkerEntry, this, std::placeholders::_1))); + + // must be called after launching workers. + TaskManager::FindMe()->Post(); + + io_block_queue_wait_post_.Register(tree_->AllTasks()); + NotifyToFillIOBlockQueue(); + while (!finished_reading_dataset_) { + int64_t buffer_id = 0; + int32_t workers_done = 0; + int64_t rows_read = 0; + load_io_block_queue_ = true; + + while (workers_done < num_workers_) { + std::unique_ptr buffer; + RETURN_IF_NOT_OK(jagged_buffer_connector_->Pop(0, &buffer)); + if (buffer->eoe()) { + workers_done++; + } else if (num_samples_ == 0 || rows_read < num_samples_) { + if ((num_samples_ > 0) && (rows_read + buffer->NumRows() > num_samples_)) { + int64_t rowsToRemove = buffer->NumRows() - (num_samples_ - rows_read); + RETURN_IF_NOT_OK(buffer->SliceOff(rowsToRemove)); + } + rows_read += buffer->NumRows(); + buffer->set_id(buffer_id++); + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(buffer))); + } else { + // end of epoch + load_jagged_connector_ = false; + load_io_block_queue_ = false; + } + } + + std::unique_ptr eoe_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eoe_buffer))); + + if (!BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat)) { + finished_reading_dataset_ = true; + NotifyToFillIOBlockQueue(); + } else { + jagged_buffer_connector_->DoReset(); + buffer_id = 0; + } + } + + std::unique_ptr eof_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOF); + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(eof_buffer))); + + RETURN_IF_NOT_OK(PostEndOfData()); + + return Status::OK(); +} + +int64_t TextFileOp::CountTotalRows(const std::string &file) { + std::ifstream handle(file); + if (!handle.is_open()) { + MS_LOG(ERROR) << "Failed to open file: " << file; + return 0; + } + + std::string line; + int64_t count = 0; + while (getline(handle, line)) { + count++; + } + + return count; +} + +Status TextFileOp::CalculateNumRowsPerShard() { + for (auto it = filename_index_->begin(); it != filename_index_->end(); ++it) { + int64_t count = CountTotalRows(it.value()); + filename_numrows_[it.value()] = count; + all_num_rows_ += count; + } + if (all_num_rows_ == 0) { + RETURN_STATUS_UNEXPECTED("Number of rows can not be zero"); + } + + num_rows_per_shard_ = static_cast(std::ceil(all_num_rows_ * 1.0 / num_devices_)); + MS_LOG(DEBUG) << "Number rows per shard is " << num_rows_per_shard_; + return Status::OK(); +} + +Status TextFileOp::CountAllFileRows(const std::vector &files, int64_t *count) { + std::shared_ptr op; + *count = 0; + RETURN_IF_NOT_OK(Builder().SetTextFilesList(files).Build(&op)); + for (auto file : files) { + *count += op->CountTotalRows(file); + } + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.h new file mode 100644 index 0000000000..49f224ffc3 --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/text_file_op.h @@ -0,0 +1,263 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_SOURCE_TEXT_FILE_OP_H_ +#define DATASET_ENGINE_DATASETOPS_SOURCE_TEXT_FILE_OP_H_ + +#include +#include +#include +#include +#include +#include + +#include "dataset/util/status.h" +#include "dataset/util/auto_index.h" +#include "dataset/engine/data_schema.h" +#include "dataset/engine/datasetops/parallel_op.h" +#include "dataset/engine/datasetops/source/io_block.h" +#include "dataset/util/queue.h" +#include "dataset/util/wait_post.h" +#include "dataset/engine/jagged_connector.h" + +namespace mindspore { +namespace dataset { +using StringIndex = AutoIndexObj; + +class TextFileOp : public ParallelOp { + public: + class Builder { + public: + // Builder constructor. Creates the builder object. + // @note No default args + // @return This is a constructor. + Builder(); + + // Default destructor + ~Builder() = default; + + // Checks if the inputs of the builder is valid. + // @return Status - the error code returned. + Status ValidateInputs() const; + + // Create the final object. + // @param op - dataset op. + // @return - the error code return. + Status Build(std::shared_ptr *op); + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetNumWorkers(int32_t num_workers) { + builder_num_workers_ = num_workers; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetOpConnectorSize(int32_t op_connector_size) { + builder_op_connector_size_ = op_connector_size; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetRowsPerBuffer(int64_t rows_per_buffer) { + builder_rows_per_buffer_ = rows_per_buffer; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetNumDevices(int64_t num_dev) { + builder_num_devices_ = num_dev; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetDeviceId(int64_t dev_id) { + builder_device_id_ = dev_id; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetTextFilesList(const std::vector &files_list) { + builder_text_files_list_ = files_list; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetShuffleFiles(bool shuffle_files) { + builder_shuffle_files_ = shuffle_files; + return *this; + } + + // Setter method. + // @return Builder - setter method returns reference to the builder. + Builder &SetNumSamples(int64_t num_samples) { + builder_num_samples_ = num_samples; + return *this; + } + + private: + int32_t builder_device_id_; + int32_t builder_num_devices_; + int32_t builder_num_workers_; + int32_t builder_op_connector_size_; + int64_t builder_rows_per_buffer_; + int64_t builder_num_samples_; + int32_t builder_worker_connector_size_; + std::vector builder_text_files_list_; + bool builder_shuffle_files_; + std::unique_ptr builder_schema_; + }; + + // Constructor of TextFileOp + // @note The builder class should be used to call this constructor. + // @param num_workers - number of worker threads reading data from tf_file files. + // @param rows_per_buffer - number of rows that a full buffer will contain. + // @param total_num_rows - number of rows to read + // @param dataset_files_list - list of filepaths for the dataset files. + // @param data_schema - the data schema object. + // @param op_connector_size - size of each queue in the connector that the child operator pulls from. + // @param columns_to_load - the names of the columns to load data from. + // @param shuffle_files - whether or not to shuffle the files before reading data. + // @param equal_rows_per_shard - whether or not to get equal rows for each process. + TextFileOp(int32_t num_workers, int64_t rows_per_buffer, int64_t num_samples, int32_t worker_connector_size, + std::unique_ptr, std::vector text_files_list, int32_t op_connector_size, + bool shuffle_files, int32_t num_devices, int32_t device_id); + + // Default destructor + ~TextFileOp() = default; + + // Instantiates the internal queues and connectors + // @return Status - the error code returned + Status Init(); + + // Class functor operator () override. + // All dataset operators operate by launching a thread (see ExecutionTree). This class functor will + // provide the master loop that drives the logic for performing the work + // @return Status - the error code returned. + Status operator()() override; + + // Overrides base class reset method. Cleans up any state info from it's previous execution + // reinitializes itself so that it can be executed again, as if it was just created. + // @return Status - the error code returned. + Status Reset() override; + + // Get total rows in files. + // @param files - all text files. + // @param count - number of rows. + // @return Status - the error coed returned. + static Status CountAllFileRows(const std::vector &files, int64_t *count); + + private: + // The entry point for when workers are launched. + // @param worker_id - the id of the worker that is executing this function. + // @return Status - the error code returned. + Status WorkerEntry(int32_t worker_id) override; + + // Parses a single row and puts the data into a tensor table. + // @param line - the content of the row. + // @param tensor_table - the tensor table to put the parsed data in. + // @param row - the id of the row filled in the tensor table. + // @return Status - the error code returned. + Status LoadTensor(const std::string &line, std::unique_ptr *tensor_table, int64_t row); + + // Reads a text file and loads the data into multiple buffers. + // @param file - the file to read. + // @param start_offset - the start offset of file. + // @param end_offset - the end offset of file. + // @param worker_id - the id of the worker that is executing this function. + // @return Status - the error code returned. + Status LoadFile(const std::string &file, const int64_t start_offset, const int64_t end_offset, + const int32_t worker_id); + + // Calculate number of rows in each shard. + // @return Status - the error code returned. + Status CalculateNumRowsPerShard(); + + // Count number of rows in each file. + // @param filename - text file name. + // @return int64_t - the total number of rows in file. + int64_t CountTotalRows(const std::string &file); + + // Notifies the thread which called FillIoBlockQueue to resume execution + void NotifyToFillIOBlockQueue(); + + // Called asynchronously by another thread. Will wait until notified to fill the IOBlockQueue. + // @return Status - the error code returned. + Status WaitToFillIOBlockQueue(); + + // Fill the IOBlockQueue. + // @para i_keys - keys of file to fill to the IOBlockQueue + // @return Status - the error code returned. + Status FillIOBlockQueue(const std::vector &i_keys); + + // Select file and push it to the block queue. + // @param file_name - File name. + // @param start_file - If file contains the first sample of data. + // @param end_file - If file contains the end sample of data. + // @param pre_count - Total rows of previous files. + // @return Status - the error code returned. + bool NeedPushFileToBlockQueue(const std::string &file_name, int64_t *start_offset, int64_t *end_offset, + const int64_t &pre_count); + + // Pops an element from a queue in IOBlockQueue. + // @param index - the index of the queue to pop from. + // @param out_block - the popped element. + // @return Status - the error code returned. + Status PopIoBlockQueue(int32_t index, std::unique_ptr *out_block); + + // Pushes an element to a queue in IOBlockQueue. + // @param index - the index of the queue to push to. + // @param io_block - the element to push onto the queue. + // @return Status - the error code returned. + Status PushIoBlockQueue(int32_t index, std::unique_ptr &&io_block); + + // Pushes a control indicator onto the IOBlockQueue for each worker to consume. + // When the worker pops this control indicator, it will shut itself down gracefully. + // @return Status - the error code returned. + Status PostEndOfData(); + + // Pushes a control indicator onto the IOBlockQueue for each worker to consume. When the worker + // pops this control indicator, it will wait until the next epoch starts and then resume execution. + // @return Status - the error code returned. + Status PostEndOfEpoch(int32_t queue_index); + + int32_t device_id_; + int32_t num_devices_; + int64_t rows_per_buffer_; + int64_t num_samples_; + std::vector text_files_list_; + bool shuffle_files_; + std::unique_ptr data_schema_; + int64_t all_num_rows_; + int64_t num_rows_per_shard_; + std::map filename_numrows_; + std::unique_ptr filename_index_; + QueueList> io_block_queues_; + WaitPost io_block_queue_wait_post_; + bool finished_reading_dataset_; + bool load_io_block_queue_; + bool load_jagged_connector_; + std::unordered_map col_name_map_; + std::unique_ptr jagged_buffer_connector_; +}; +} // namespace dataset +} // namespace mindspore +#endif // DATASET_ENGINE_DATASETOPS_SOURCE_TEXT_FILE_OP_H_ diff --git a/mindspore/dataset/__init__.py b/mindspore/dataset/__init__.py index 479c66045f..2a30b616ad 100644 --- a/mindspore/dataset/__init__.py +++ b/mindspore/dataset/__init__.py @@ -20,8 +20,8 @@ can also create samplers with this module to sample data. from .core.configuration import config from .engine.datasets import StorageDataset, TFRecordDataset, ImageFolderDatasetV2, MnistDataset, MindDataset, \ - GeneratorDataset, ManifestDataset, Cifar10Dataset, Cifar100Dataset, VOCDataset, CelebADataset, Schema, \ - Shuffle, zip + GeneratorDataset, ManifestDataset, Cifar10Dataset, Cifar100Dataset, VOCDataset, CelebADataset, TextFileDataset, \ + Schema, Shuffle, zip from .engine.samplers import DistributedSampler, PKSampler, RandomSampler, SequentialSampler, SubsetRandomSampler, \ WeightedRandomSampler from .engine.serializer_deserializer import serialize, deserialize, show @@ -29,5 +29,5 @@ from .engine.serializer_deserializer import serialize, deserialize, show __all__ = ["config", "ImageFolderDatasetV2", "MnistDataset", "StorageDataset", "MindDataset", "GeneratorDataset", "TFRecordDataset", "ManifestDataset", "Cifar10Dataset", "Cifar100Dataset", "CelebADataset", - "VOCDataset", "Schema", "DistributedSampler", "PKSampler", "RandomSampler", + "VOCDataset", "TextFileDataset", "Schema", "DistributedSampler", "PKSampler", "RandomSampler", "SequentialSampler", "SubsetRandomSampler", "WeightedRandomSampler", "zip"] diff --git a/mindspore/dataset/engine/__init__.py b/mindspore/dataset/engine/__init__.py index 720b56b96d..86d2971332 100644 --- a/mindspore/dataset/engine/__init__.py +++ b/mindspore/dataset/engine/__init__.py @@ -33,5 +33,5 @@ __all__ = ["config", "ConfigurationManager", "zip", "StorageDataset", "ImageFolderDatasetV2", "MnistDataset", "MindDataset", "GeneratorDataset", "TFRecordDataset", "ManifestDataset", "Cifar10Dataset", "Cifar100Dataset", "CelebADataset", - "VOCDataset", "Schema", "DistributedSampler", "PKSampler", "RandomSampler", - "SequentialSampler", "SubsetRandomSampler", "WeightedRandomSampler"] + "VOCDataset", "TextFileDataset", "Schema", "DistributedSampler", "PKSampler", + "RandomSampler", "SequentialSampler", "SubsetRandomSampler", "WeightedRandomSampler"] diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 8de56a6dff..ca717643c9 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -29,7 +29,7 @@ from importlib import import_module import numpy as np from mindspore._c_dataengine import DataType, TFReaderOp, ImageFolderOp, CifarOp, MnistOp, ManifestOp, \ - MindRecordOp, CBatchInfo + MindRecordOp, TextFileOp, CBatchInfo from mindspore._c_expression import typing from mindspore import log as logger @@ -38,7 +38,7 @@ from .iterators import DictIterator, TupleIterator from .validators import check, check_batch, check_shuffle, check_map, check_repeat, check_skip, check_zip, check_rename, \ check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \ check_tfrecorddataset, check_vocdataset, check_celebadataset, check_minddataset, check_generatordataset, \ - check_zip_dataset, check_add_column + check_zip_dataset, check_add_column, check_textfiledataset from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist try: @@ -888,6 +888,29 @@ class SourceDataset(Dataset): # No need for __init__ since it is the same as the super's init + @staticmethod + def _find_files(patterns): + """ + Utility function to search for files with the given glob patterns. + + Args: + patterns (str or list[str]): string or list of patterns to be searched. + + Returns: + List, files. + """ + + def flat(lists): + return list(np.array(lists).flatten()) + + if not isinstance(patterns, list): + patterns = [patterns] + + file_list = flat([glob.glob(file, recursive=True) for file in patterns]) + if file_list: # not empty + return file_list + raise ValueError("The list of path names matching the patterns is empty.") + class DatasetOp(Dataset): """ @@ -2126,30 +2149,6 @@ class TFRecordDataset(SourceDataset): >>> # 3) get all rows from dataset_files with schema file "./schema.json": >>> tfdataset = ds.TFRecordDataset(dataset_files=dataset_files, schema="./schema.json") """ - - @staticmethod - def _find_files(patterns): - """ - Utility function to search for files with the given glob patterns. - - Args: - patterns (str or list[str]): string or list of patterns to be searched. - - Returns: - List, files. - """ - - def flat(lists): - return list(np.array(lists).flatten()) - - if not isinstance(patterns, list): - patterns = [patterns] - - file_list = flat([glob.glob(file, recursive=True) for file in patterns]) - if file_list: # not empty - return file_list - raise ValueError("The list of path names matching the patterns is empty.") - @check_tfrecorddataset def __init__(self, dataset_files, schema=None, columns_list=None, num_samples=None, num_parallel_workers=None, shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None, shard_equal_rows=False): @@ -2952,3 +2951,82 @@ class CelebADataset(SourceDataset): args["num_shards"] = self.num_shards args["shard_id"] = self.shard_id return args + +class TextFileDataset(SourceDataset): + """ + A source dataset that reads and parses datasets stored on disk in text format. + The generated dataset has one columns ['text']. + + Args: + dataset_files (str or list[str]): String or list of files to be read or glob strings to search for a pattern of + files. The list will be sorted in a lexicographical order. + num_samples (int, optional): number of samples(rows) to read (default=None, reads the full dataset). + num_parallel_workers (int, optional): number of workers to read the data + (default=None, number set in the config). + shuffle (bool, Shuffle level, optional): perform reshuffling of the data every epoch (default=Shuffle.GLOBAL). + If shuffle is False, no shuffling will be performed; + If shuffle is True, the behavior is the same as setting shuffle to be Shuffle.GLOBAL + Otherwise, there are two levels of shuffling: + + - Shuffle.GLOBAL: Shuffle both the files and samples. + + - Shuffle.FILES: Shuffle files only. + + num_shards (int, optional): Number of shards that the dataset should be divided into (default=None). + shard_id (int, optional): The shard ID within num_shards (default=None). This + argument should be specified only when num_shards is also specified. + Examples: + >>> import mindspore.dataset as ds + >>> dataset_files = ["/path/to/1", "/path/to/2"] # contains 1 or multiple text files + >>> dataset = ds.TextFileDataset(dataset_files=dataset_files) + """ + + @check_textfiledataset + def __init__(self, dataset_files, num_samples=None, num_parallel_workers=None, + shuffle=Shuffle.GLOBAL, num_shards=None, shard_id=None): + super().__init__(num_parallel_workers) + self.dataset_files = self._find_files(dataset_files) + self.dataset_files.sort() + self.num_samples = num_samples + + if not isinstance(shuffle, (bool, Shuffle)): + raise TypeError("shuffle should be of boolean or enum 'Shuffle'.") + if not isinstance(shuffle, Shuffle): + if shuffle: + self.shuffle_level = Shuffle.GLOBAL + self.shuffle_files = True + else: + self.shuffle_level = None + self.shuffle_files = False + else: + self.shuffle_level = shuffle + self.shuffle_files = True + + self.num_shards = num_shards + self.shard_id = shard_id + + def get_args(self): + args = super().get_args() + args["dataset_files"] = self.dataset_files + args["num_samples"] = self.num_samples + if self.shuffle_files is not None: + args["shuffle_files"] = self.shuffle_files + args["shuffle"] = self.shuffle_level + args["num_shards"] = self.num_shards + args["shard_id"] = self.shard_id + return args + + def get_dataset_size(self): + """ + Get the number of batches in an epoch. + + Return: + Number, number of batches. + """ + if self._dataset_size is None: + num_rows = TextFileOp.get_num_rows(self.dataset_files) + num_rows = get_num_rows(num_rows, self.num_shards) + if self.num_samples is None: + return num_rows + return min(self.num_samples, num_rows) + return self._dataset_size diff --git a/mindspore/dataset/engine/iterators.py b/mindspore/dataset/engine/iterators.py index 2bb130f303..a74d69b9c7 100644 --- a/mindspore/dataset/engine/iterators.py +++ b/mindspore/dataset/engine/iterators.py @@ -48,12 +48,16 @@ def alter_tree(node): def _alter_node(node): """Performing some alteration to a dataset node. A common alteration is to insert a node.""" - if isinstance(node, de.TFRecordDataset) and node.shuffle_level == de.Shuffle.GLOBAL: + if isinstance(node, (de.TFRecordDataset, de.TextFileDataset)) and node.shuffle_level == de.Shuffle.GLOBAL: # Remove the connection between the parent's node to the current node because we are inserting a node. if node.output: node.output.pop() # Perform a fast scan for average rows per file - avg_rows_per_file = node.get_dataset_size(True) // len(node.dataset_files) + if isinstance(node, de.TFRecordDataset): + avg_rows_per_file = node.get_dataset_size(True) // len(node.dataset_files) + else: + avg_rows_per_file = node.get_dataset_size() // len(node.dataset_files) + # Shuffle between 4 files with a minimum size of 10000 rows new_shuffle = node.shuffle(max(avg_rows_per_file * 4, 10000)) return new_shuffle @@ -157,6 +161,8 @@ class Iterator: op_type = OpName.CIFAR100 elif isinstance(dataset, de.CelebADataset): op_type = OpName.CELEBA + elif isinstance(dataset, de.TextFileDataset): + op_type = OpName.TEXTFILE else: raise ValueError("Unsupported DatasetOp") diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index b74e913202..a340eb5aff 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -849,3 +849,25 @@ def check_add_column(method): return method(*args, **kwargs) return new_method + + +def check_textfiledataset(method): + """A wrapper that wrap a parameter checker to the original Dataset(TextFileDataset).""" + @wraps(method) + def new_method(*args, **kwargs): + param_dict = make_param_dict(method, args, kwargs) + + nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] + + # check dataset_files; required argument + dataset_files = param_dict.get('dataset_files') + if dataset_files is None: + raise ValueError("dataset_files is not provided.") + if not isinstance(dataset_files, (str, list)): + raise TypeError("dataset_files should be of type str or a list of strings.") + + check_param_type(nreq_param_int, param_dict, int) + + return method(*args, **kwargs) + + return new_method diff --git a/mindspore/dataset/transforms/nlp/__init__.py b/mindspore/dataset/transforms/nlp/__init__.py new file mode 100644 index 0000000000..01d425e2eb --- /dev/null +++ b/mindspore/dataset/transforms/nlp/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This module is to support nlp augmentations. It includes two parts: +c_transforms and py_transforms. C_transforms is a high performance +image augmentation module which is developed with c++ opencv. Py_transforms +provide more kinds of image augmentations which is developed with python PIL. +""" +from .utils import as_text diff --git a/mindspore/dataset/transforms/nlp/utils.py b/mindspore/dataset/transforms/nlp/utils.py new file mode 100644 index 0000000000..adcc7cc71d --- /dev/null +++ b/mindspore/dataset/transforms/nlp/utils.py @@ -0,0 +1,35 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Some basic function for nlp +""" +import numpy as np + +def as_text(array, encoding='utf8'): + """ + Convert data of array to unicode. + + Args: + array (numpy array): Data of array should be ASCII values of each character after converted. + encoding (string): Indicating the charset for decoding. + Returns: + A 'str' object. + + """ + + if not isinstance(array, np.ndarray): + raise ValueError('input should be a numpy array') + + byte_array = bytearray(list(array)) + return byte_array.decode(encoding) diff --git a/tests/ut/cpp/dataset/CMakeLists.txt b/tests/ut/cpp/dataset/CMakeLists.txt index ae9c46e62c..b05f12eee1 100644 --- a/tests/ut/cpp/dataset/CMakeLists.txt +++ b/tests/ut/cpp/dataset/CMakeLists.txt @@ -65,7 +65,7 @@ SET(DE_UT_SRCS cifar_op_test.cc celeba_op_test.cc take_op_test.cc - ) + text_file_op_test.cc) add_executable(de_ut_tests ${DE_UT_SRCS}) diff --git a/tests/ut/cpp/dataset/text_file_op_test.cc b/tests/ut/cpp/dataset/text_file_op_test.cc new file mode 100644 index 0000000000..7887eda955 --- /dev/null +++ b/tests/ut/cpp/dataset/text_file_op_test.cc @@ -0,0 +1,112 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include + +#include "dataset/core/client.h" +#include "common/common.h" +#include "common/utils.h" +#include "gtest/gtest.h" +#include "utils/log_adapter.h" +#include "dataset/engine/datasetops/source/text_file_op.h" +#include "dataset/util/status.h" + +namespace common = mindspore::common; + +using namespace mindspore::dataset; +using mindspore::MsLogLevel::INFO; +using mindspore::ExceptionType::NoExceptionType; +using mindspore::LogStream; + +class MindDataTestTextFileOp : public UT::DatasetOpTesting { + +}; + +TEST_F(MindDataTestTextFileOp, TestTextFileBasic) { + // Start with an empty execution tree + auto tree = std::make_shared(); + + std::string dataset_path; + dataset_path = datasets_root_path_ + "/testTextFileDataset/1.txt"; + + std::shared_ptr op; + TextFileOp::Builder builder; + builder.SetTextFilesList({dataset_path}) + .SetRowsPerBuffer(16) + .SetNumWorkers(16) + .SetOpConnectorSize(2); + + Status rc = builder.Build(&op); + ASSERT_TRUE(rc.IsOk()); + + rc = tree->AssociateNode(op); + ASSERT_TRUE(rc.IsOk()); + + rc = tree->AssignRoot(op); + ASSERT_TRUE(rc.IsOk()); + + MS_LOG(INFO) << "Launching tree and begin iteration."; + rc = tree->Prepare(); + ASSERT_TRUE(rc.IsOk()); + + rc = tree->Launch(); + ASSERT_TRUE(rc.IsOk()); + + // Start the loop of reading tensors from our pipeline + DatasetIterator di(tree); + TensorRow tensor_list; + rc = di.FetchNextTensorRow(&tensor_list); + ASSERT_TRUE(rc.IsOk()); + + int row_count = 0; + while (!tensor_list.empty()) { + // Display the tensor by calling the printer on it + for (int i = 0; i < tensor_list.size(); i++) { + std::ostringstream ss; + ss << "(" << tensor_list[i] << "): " << *tensor_list[i] << std::endl; + MS_LOG(INFO) << "Tensor print: " << ss.str() << "."; + } + + rc = di.FetchNextTensorRow(&tensor_list); + ASSERT_TRUE(rc.IsOk()); + row_count++; + } + + ASSERT_EQ(row_count, 3); +} + +TEST_F(MindDataTestTextFileOp, TestTotalRows) { + std::string tf_file1 = datasets_root_path_ + "/testTextFileDataset/1.txt"; + std::string tf_file2 = datasets_root_path_ + "/testTextFileDataset/2.txt"; + std::vector files; + files.push_back(tf_file1); + int64_t total_rows = 0; + TextFileOp::CountAllFileRows(files, &total_rows); + ASSERT_EQ(total_rows, 3); + files.clear(); + + files.push_back(tf_file2); + TextFileOp::CountAllFileRows(files, &total_rows); + ASSERT_EQ(total_rows, 2); + files.clear(); + + files.push_back(tf_file1); + files.push_back(tf_file2); + TextFileOp::CountAllFileRows(files, &total_rows); + ASSERT_EQ(total_rows, 5); + files.clear(); +} diff --git a/tests/ut/data/dataset/testTextFileDataset/1.txt b/tests/ut/data/dataset/testTextFileDataset/1.txt new file mode 100644 index 0000000000..9d911eacc0 --- /dev/null +++ b/tests/ut/data/dataset/testTextFileDataset/1.txt @@ -0,0 +1,3 @@ +This is a text file. +Be happy every day. +Good luck to everyone. diff --git a/tests/ut/data/dataset/testTextFileDataset/2.txt b/tests/ut/data/dataset/testTextFileDataset/2.txt new file mode 100644 index 0000000000..7382722eb8 --- /dev/null +++ b/tests/ut/data/dataset/testTextFileDataset/2.txt @@ -0,0 +1,2 @@ +Another file. +End of file. diff --git a/tests/ut/python/dataset/test_datasets_textfileop.py b/tests/ut/python/dataset/test_datasets_textfileop.py new file mode 100644 index 0000000000..720fcdcce0 --- /dev/null +++ b/tests/ut/python/dataset/test_datasets_textfileop.py @@ -0,0 +1,87 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import mindspore.dataset as ds +from mindspore import log as logger +import mindspore.dataset.transforms.nlp.utils as nlp + +DATA_FILE = "../data/dataset/testTextFileDataset/1.txt" +DATA_ALL_FILE = "../data/dataset/testTextFileDataset/*" + +def test_textline_dataset_one_file(): + data = ds.TextFileDataset(DATA_FILE) + count = 0 + for i in data.create_dict_iterator(): + logger.info("{}".format(i["text"])) + count += 1 + assert(count == 3) + +def test_textline_dataset_all_file(): + data = ds.TextFileDataset(DATA_ALL_FILE) + count = 0 + for i in data.create_dict_iterator(): + logger.info("{}".format(i["text"])) + count += 1 + assert(count == 5) + +def test_textline_dataset_totext(): + data = ds.TextFileDataset(DATA_ALL_FILE, shuffle=False) + count = 0 + line = ["This is a text file.", "Another file.", "Be happy every day.", "End of file.", "Good luck to everyone."] + for i in data.create_dict_iterator(): + str = nlp.as_text(i["text"]) + assert(str == line[count]) + count += 1 + assert(count == 5) + +def test_textline_dataset_num_samples(): + data = ds.TextFileDataset(DATA_FILE, num_samples=2) + count = 0 + for i in data.create_dict_iterator(): + count += 1 + assert(count == 2) + +def test_textline_dataset_distribution(): + data = ds.TextFileDataset(DATA_ALL_FILE, num_shards=2, shard_id=1) + count = 0 + for i in data.create_dict_iterator(): + count += 1 + assert(count == 3) + +def test_textline_dataset_repeat(): + data = ds.TextFileDataset(DATA_FILE, shuffle=False) + data = data.repeat(3) + count = 0 + line = ["This is a text file.", "Be happy every day.", "Good luck to everyone.", + "This is a text file.", "Be happy every day.", "Good luck to everyone.", + "This is a text file.", "Be happy every day.", "Good luck to everyone."] + for i in data.create_dict_iterator(): + str = nlp.as_text(i["text"]) + assert(str == line[count]) + count += 1 + assert(count == 9) + +def test_textline_dataset_get_datasetsize(): + data = ds.TextFileDataset(DATA_FILE) + size = data.get_dataset_size() + assert(size == 3) + +if __name__ == "__main__": + test_textline_dataset_one_file() + test_textline_dataset_all_file() + test_textline_dataset_totext() + test_textline_dataset_num_samples() + test_textline_dataset_distribution() + test_textline_dataset_repeat() + test_textline_dataset_get_datasetsize() From bc2df2c913f83e54864b4e3d3637e2c47f615c6d Mon Sep 17 00:00:00 2001 From: YuJianfeng Date: Sat, 18 Apr 2020 16:48:51 +0800 Subject: [PATCH 012/142] Fix inputs size and attr for AddN fission pass --- .../ascend/ir_fission/addn_fission.cc | 16 ++++++++++------ mindspore/ccsrc/utils/utils.h | 2 +- .../pre_activate/addn_fission_test.py | 11 +++-------- 3 files changed, 14 insertions(+), 15 deletions(-) diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/addn_fission.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fission/addn_fission.cc index f6eb6aca64..b9a86f7bcb 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/addn_fission.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fission/addn_fission.cc @@ -34,6 +34,8 @@ AnfNodePtr CreateNewAddn(const FuncGraphPtr &func_graph, const CNodePtr &origin_ new_addn->set_scope(origin_addn_cnode->scope()); new_addn->set_abstract(origin_addn_cnode->abstract()); AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToInt(offset)), new_addn); + std::vector dyn_input_sizes{SizeToInt(offset)}; + AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), new_addn); return new_addn; } } // namespace @@ -55,22 +57,24 @@ const AnfNodePtr AddnFission::Process(const FuncGraphPtr &func_graph, const AnfN } CNodePtr new_cnode = cnode; while (origin_input_size > inputs_divisor_) { + MS_EXCEPTION_IF_NULL(new_cnode); std::vector base_addn_inputs{NewValueNode(std::make_shared(prim::kPrimAddN->name()))}; size_t cur_input_index = 1; - // Divide the inputs of addn by 63. - while (origin_input_size - cur_input_index + 1 > inputs_divisor_) { + // Divide the inputs of addn by inputs_divisor_. + while (origin_input_size - cur_input_index + 1 >= inputs_divisor_) { base_addn_inputs.push_back(CreateNewAddn(func_graph, new_cnode, cur_input_index, inputs_divisor_)); cur_input_index += inputs_divisor_; } - base_addn_inputs.push_back( - CreateNewAddn(func_graph, new_cnode, cur_input_index, origin_input_size - cur_input_index + 1)); - + for (size_t i = cur_input_index; i <= origin_input_size; i++) { + base_addn_inputs.push_back(new_cnode->input(i)); + } CNodePtr base_addn = func_graph->NewCNode(base_addn_inputs); MS_EXCEPTION_IF_NULL(base_addn); - MS_EXCEPTION_IF_NULL(new_cnode); base_addn->set_scope(new_cnode->scope()); base_addn->set_abstract(new_cnode->abstract()); AnfAlgo::SetNodeAttr(kAttrN, MakeValue(SizeToInt(base_addn_inputs.size() - 1)), base_addn); + std::vector dyn_input_sizes{SizeToInt(base_addn_inputs.size() - 1)}; + AnfAlgo::SetNodeAttr(kAttrDynInputSizes, MakeValue(dyn_input_sizes), base_addn); new_cnode = base_addn; origin_input_size = base_addn->inputs().size() - 1; } diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index 10ef4abf62..eac901b74d 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -149,7 +149,7 @@ constexpr auto kAttrDynInputSizes = "dyn_input_sizes"; constexpr auto kAttrSrcFormat = "src_format"; constexpr auto kAttrOutputUsedNum = "output_used_num"; constexpr auto kAttrHasBias = "has_bias"; -constexpr auto kAttrN = "N"; +constexpr auto kAttrN = "n"; constexpr auto kAttrLabelForInsertStreamActive = "label_for_insert_stream_active"; // attr value diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/addn_fission_test.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/addn_fission_test.py index c120ac3e68..76d7e73a80 100644 --- a/tests/ut/cpp/python_input/gtest_input/pre_activate/addn_fission_test.py +++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/addn_fission_test.py @@ -45,13 +45,10 @@ def test_addn_fission(tag): b = addn((input2, input3)) c = addn((input4, input5)) d = addn((input6, input7)) - e = addn((input8,)) f = addn((a, b)) g = addn((c, d)) - h = addn((e,)) i = addn((f, g)) - j = addn((h,)) - return addn((i, j)) + return addn((i, input8)) @fns def after_divided_by_3(input0, input1, input2, input3, input4, input5, input6, input7, input8): @@ -64,14 +61,12 @@ def test_addn_fission(tag): def after_divided_by_4(input0, input1, input2, input3, input4, input5, input6, input7, input8): a = addn((input0, input1, input2, input3)) b = addn((input4, input5, input6, input7)) - c = addn((input8,)) - return addn((a, b, c)) + return addn((a, b, input8)) @fns def after_divided_by_8(input0, input1, input2, input3, input4, input5, input6, input7, input8): a = addn((input0, input1, input2, input3, input4, input5, input6, input7)) - b = addn((input8,)) - return addn((a, b)) + return addn((a, input8)) @fns def after_divided_by_9(input0, input1, input2, input3, input4, input5, input6, input7, input8): From 7e23a1a475be056af0f1051f782712d641e6b6e9 Mon Sep 17 00:00:00 2001 From: fary86 Date: Mon, 20 Apr 2020 14:54:11 +0800 Subject: [PATCH 013/142] Fix issues of save_graphs_path, Type/Value error message and log file mode --- mindspore/ccsrc/utils/log_adapter.cc | 6 +++- mindspore/context.py | 28 ++++++++++++++++++- tests/ut/cpp/operator/composite_test.cc | 4 +-- tests/ut/python/pynative_mode/test_backend.py | 18 ++++++++++-- tests/ut/python/pynative_mode/test_context.py | 19 +++++++++++-- 5 files changed, 67 insertions(+), 8 deletions(-) diff --git a/mindspore/ccsrc/utils/log_adapter.cc b/mindspore/ccsrc/utils/log_adapter.cc index 704ab24d52..4c197a0bdf 100644 --- a/mindspore/ccsrc/utils/log_adapter.cc +++ b/mindspore/ccsrc/utils/log_adapter.cc @@ -179,7 +179,7 @@ void LogWriter::operator^(const LogStream &stream) const { std::ostringstream oss; oss << location_.file_ << ":" << location_.line_ << " " << location_.func_ << "] "; - if (exception_type_ != NoExceptionType) { + if (exception_type_ != NoExceptionType && exception_type_ != TypeError && exception_type_ != ValueError) { oss << ExceptionTypeToString(exception_type_) << " "; } oss << msg.str(); @@ -242,6 +242,10 @@ void mindspore_log_init(void) { if (mindspore::GetEnv("GLOG_v").empty()) { FLAGS_v = mindspore::WARNING; } + // set default log file mode to 0640 + if (mindspore::GetEnv("GLOG_logfile_mode").empty()) { + FLAGS_logfile_mode = 0640; + } // default print log to screen if (mindspore::GetEnv("GLOG_logtostderr").empty()) { FLAGS_logtostderr = true; diff --git a/mindspore/context.py b/mindspore/context.py index 2938b87119..fae7f7f762 100644 --- a/mindspore/context.py +++ b/mindspore/context.py @@ -16,6 +16,7 @@ The context of mindspore, used to configure the current execution environment, including execution mode, execution backend and other feature switchs. """ +import os import threading from collections import namedtuple from types import FunctionType @@ -33,6 +34,31 @@ GRAPH_MODE = 0 PYNATIVE_MODE = 1 +def _make_directory(path: str): + """Make directory.""" + real_path = None + if path is None or not isinstance(path, str) or path.strip() == "": + raise ValueError(f"Input path `{path}` is invaild type") + + # convert the relative paths + path = os.path.realpath(path) + logger.debug("The absolute path is %r", path) + + # check whether the path is already existed and has written permissions + if os.path.exists(path): + real_path = path + else: + # All exceptions need to be caught because create directory maybe have some limit(permissions) + logger.debug("The directory(%s) doesn't exist, will create it", path) + try: + os.makedirs(path) + real_path = path + except PermissionError as e: + logger.error(f"No write permission on the directory `{path}, error = {e}") + raise ValueError(f"No write permission on the directory `{path}`.") + return real_path + + class _ThreadLocalInfo(threading.local): """ Thread local Info used for store thread local attributes. @@ -173,7 +199,7 @@ class _Context: @save_graphs_path.setter def save_graphs_path(self, save_graphs_path): - self._context_handle.set_save_graphs_path(save_graphs_path) + self._context_handle.set_save_graphs_path(_make_directory(save_graphs_path)) @property def device_target(self): diff --git a/tests/ut/cpp/operator/composite_test.cc b/tests/ut/cpp/operator/composite_test.cc index d9dd9e5e99..2c4b9b7146 100644 --- a/tests/ut/cpp/operator/composite_test.cc +++ b/tests/ut/cpp/operator/composite_test.cc @@ -128,8 +128,8 @@ TEST_F(TestComposite, test_TupleSlice_arg_one_number) { trace::ClearTraceStack(); engine_->Run(tupleSliceGraphPtr, args_spec_list); FAIL() << "Excepted exception :Args type is wrong"; - } catch (std::runtime_error const &err) { - ASSERT_TRUE(std::string(err.what()).find("TypeError") != std::string::npos); + } catch (pybind11::type_error const &err) { + ASSERT_TRUE(true); } catch (...) { FAIL() << "Excepted exception :Args type is wrong"; } diff --git a/tests/ut/python/pynative_mode/test_backend.py b/tests/ut/python/pynative_mode/test_backend.py index 937f7b24ff..7258b69486 100644 --- a/tests/ut/python/pynative_mode/test_backend.py +++ b/tests/ut/python/pynative_mode/test_backend.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================ """ test_backend """ +import os import numpy as np import pytest from mindspore.ops import operations as P @@ -51,10 +52,11 @@ def test_vm_backend(): def test_vm_set_context(): """ test_vm_set_context """ - context.set_context(save_graphs=True, save_graphs_path="/home/mindspore", mode=context.GRAPH_MODE) + context.set_context(save_graphs=True, save_graphs_path="mindspore_ir_path", mode=context.GRAPH_MODE) assert context.get_context("save_graphs") assert context.get_context("mode") == context.GRAPH_MODE - assert context.get_context("save_graphs_path") == "/home/mindspore" + assert os.path.exists("mindspore_ir_path") + assert context.get_context("save_graphs_path").find("mindspore_ir_path") > 0 context.set_context(mode=context.PYNATIVE_MODE) @args_type_check(v_str=str, v_int=int, v_tuple=tuple) @@ -74,3 +76,15 @@ def test_args_type_check(): with pytest.raises(TypeError): check_input("name", 100, "age") check_input("name", 100, (10, 10)) + + +def teardown_module(): + dirs = ['mindspore_ir_path'] + for item in dirs: + item_name = './' + item + if not os.path.exists(item_name): + continue + if os.path.isdir(item_name): + os.rmdir(item_name) + elif os.path.isfile(item_name): + os.remove(item_name) diff --git a/tests/ut/python/pynative_mode/test_context.py b/tests/ut/python/pynative_mode/test_context.py index 450bf60b90..2425b53f42 100644 --- a/tests/ut/python/pynative_mode/test_context.py +++ b/tests/ut/python/pynative_mode/test_context.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================ """ test_context """ +import os import pytest from mindspore import context # pylint: disable=W0212 @@ -74,11 +75,12 @@ def test_dump_target(): def test_set_context(): """ test_set_context """ context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", - device_id=0, save_graphs=True, save_graphs_path="/mindspore") + device_id=0, save_graphs=True, save_graphs_path="mindspore_ir_path") assert context.get_context("device_id") == 0 assert context.get_context("device_target") == "Ascend" assert context.get_context("save_graphs") - assert context.get_context("save_graphs_path") == "/mindspore" + assert os.path.exists("mindspore_ir_path") + assert context.get_context("save_graphs_path").find("mindspore_ir_path") > 0 assert context.get_context("mode") == context.GRAPH_MODE context.set_context(mode=context.PYNATIVE_MODE) @@ -87,3 +89,16 @@ def test_set_context(): with pytest.raises(ValueError): context.set_context(modex="ge") + + +def teardown_module(): + dirs = ['mindspore_ir_path'] + for item in dirs: + item_name = './' + item + if not os.path.exists(item_name): + continue + if os.path.isdir(item_name): + os.rmdir(item_name) + elif os.path.isfile(item_name): + os.remove(item_name) + From c6b2b0df1ed74f6dcc4bd42a7a39cd17270b4629 Mon Sep 17 00:00:00 2001 From: lvliang Date: Sat, 18 Apr 2020 17:03:08 +0800 Subject: [PATCH 014/142] pynative-support-reducemean --- mindspore/ccsrc/pynative/base.h | 3 ++- mindspore/ops/operations/math_ops.py | 7 ++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/mindspore/ccsrc/pynative/base.h b/mindspore/ccsrc/pynative/base.h index 7405f621cb..d8675adc9c 100644 --- a/mindspore/ccsrc/pynative/base.h +++ b/mindspore/ccsrc/pynative/base.h @@ -21,6 +21,7 @@ #include #include #include +#include #include #include @@ -59,7 +60,7 @@ struct OpExecInfo { using OpExecInfoPtr = std::shared_ptr; OpExecInfoPtr GenerateOpExecInfo(const py::args& args); -const std::unordered_set ignore_infer_prim = {"partial"}; +const std::set ignore_infer_prim = {"partial", "make_ref"}; } // namespace pynative } // namespace mindspore diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py index 98665dd27a..a3df6b7fba 100644 --- a/mindspore/ops/operations/math_ops.py +++ b/mindspore/ops/operations/math_ops.py @@ -24,7 +24,7 @@ from ..._checkparam import Rel from ...common import dtype as mstype from ...common.tensor import Tensor from .._utils import _get_broadcast_shape -from ..primitive import PrimitiveWithInfer, prim_attr_register +from ..primitive import PrimitiveWithInfer, prim_attr_register, _run_op def _infer_shape_reduce(x, axis, keep_dims, prim_name): @@ -225,6 +225,11 @@ class _Reduce(PrimitiveWithInfer): validator.check_value_type('keep_dims', keep_dims, [bool], self.name) self.init_prim_io_names(inputs=['input_x', 'axis'], outputs=['y']) + def __call__(self, x, axis=()): + args = [x, axis] + output = _run_op(self, self.name, args) + return output + def do_infer(self, input_x, axis, valid_dtype=mstype.number_type): axis_v = axis['value'] input_shp = input_x['shape'] From 4f0034353e4af3b2caf252f8c3ca2ada7f324df8 Mon Sep 17 00:00:00 2001 From: yanghaoran Date: Mon, 20 Apr 2020 17:01:10 +0800 Subject: [PATCH 015/142] add realpath to acquire real cmake dir despite usage of symbolic links --- cmake/dependency_graphengine.cmake | 2 +- graphengine | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cmake/dependency_graphengine.cmake b/cmake/dependency_graphengine.cmake index 2a90cc1458..533f9f8246 100644 --- a/cmake/dependency_graphengine.cmake +++ b/cmake/dependency_graphengine.cmake @@ -64,7 +64,7 @@ set(_ge_tmp_CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS}) string(REPLACE " -Wall" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") string(REPLACE " -Werror" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") # force __FILE__ to show relative path of file, from source directory -set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__FILE__='\"$(subst ${CMAKE_SOURCE_DIR}/,,$(abspath $<))\"' -Wno-builtin-macro-redefined") +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__FILE__='\"$(subst $(realpath ${CMAKE_SOURCE_DIR})/,,$(abspath $<))\"' -Wno-builtin-macro-redefined") add_subdirectory(${GE_SOURCE_DIR}/src/common/graph) if(ENABLE_D) add_subdirectory(${GE_SOURCE_DIR}/src/ge/common) diff --git a/graphengine b/graphengine index 0c33e9d125..43f5d24337 160000 --- a/graphengine +++ b/graphengine @@ -1 +1 @@ -Subproject commit 0c33e9d12562953ca4bd6c03cb77da2c2da74acd +Subproject commit 43f5d24337bf785251eefae2d810c7d5684194d6 From 4b2f546730bf90fa3f39097064399f40a28566d7 Mon Sep 17 00:00:00 2001 From: lupengcheng Date: Mon, 20 Apr 2020 17:14:47 +0800 Subject: [PATCH 016/142] =?UTF-8?q?=E5=9B=9E=E9=80=80=20'Pull=20Request=20?= =?UTF-8?q?!263=20:=20optimize=20cmake=20for=20tvm=20'?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- cmake/external_libs/dmlc_core.cmake | 2 +- cmake/external_libs/tvm_gpu.cmake | 18 ++--- cmake/package.cmake | 12 +-- cmake/utils.cmake | 20 ++--- mindspore/ccsrc/CMakeLists.txt | 109 +++++++++++++++++++++++++++- 5 files changed, 122 insertions(+), 39 deletions(-) diff --git a/cmake/external_libs/dmlc_core.cmake b/cmake/external_libs/dmlc_core.cmake index e07df83fd6..386a52429d 100644 --- a/cmake/external_libs/dmlc_core.cmake +++ b/cmake/external_libs/dmlc_core.cmake @@ -1,4 +1,4 @@ -mindspore_add_pkg(dmlc-core +mindspore_add_pkg(dmlc_core VER 0.3 HEAD_ONLY ./ URL https://github.com/dmlc/dmlc-core/archive/808f485387f9a03f78fa9f1159f387d0d91b7a28.zip diff --git a/cmake/external_libs/tvm_gpu.cmake b/cmake/external_libs/tvm_gpu.cmake index 2edec52ee1..57a045cb03 100644 --- a/cmake/external_libs/tvm_gpu.cmake +++ b/cmake/external_libs/tvm_gpu.cmake @@ -1,16 +1,8 @@ -set(incubator_tvm_gpu_CFLAGS "-pipe -Wall -fPIC -fstack-protector-all -D_FORTIFY_SOURCE=2 -O2") -set(incubator_tvm_gpu_CXXFLAGS "-std=c++11 -pipe -Wall -fPIC -fstack-protector-all -D_FORTIFY_SOURCE=2 -O2") -set(USE_CUDA "ON") +set(incubator_tvm_gpu_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2") +set(incubator_tvm_gpu_CFLAGS "-D_FORTIFY_SOURCE=2 -O2") mindspore_add_pkg(incubator_tvm_gpu VER 0.6.0 - LIBS tvm + HEAD_ONLY ./ URL https://github.com/apache/incubator-tvm/archive/v0.6.0.tar.gz - MD5 9cbbd32545a776023acabbba270449fe - SUBMODULES ${dlpack_DIRPATH} ${dmlc-core_DIRPATH} ${rang_DIRPATH} - SOURCEMODULES topi/python/topi python/tvm - PATCHES ${CMAKE_SOURCE_DIR}/third_party/patch/incubator-tvm/find_library.patch - ${CMAKE_SOURCE_DIR}/third_party/patch/incubator-tvm/include.patch - ${CMAKE_SOURCE_DIR}/third_party/patch/incubator-tvm/src_pass.patch - CMAKE_OPTION -DBUILD_TESTING=OFF -DCMAKE_POSITION_INDEPENDENT_CODE=ON -DBUILD_SHARED_LIBS=ON) -include_directories(${incubator_tvm_gpu_INC}) -add_library(mindspore::tvm ALIAS incubator_tvm_gpu::tvm) + MD5 9cbbd32545a776023acabbba270449fe) + diff --git a/cmake/package.cmake b/cmake/package.cmake index d35ce0463b..531dff29ca 100644 --- a/cmake/package.cmake +++ b/cmake/package.cmake @@ -191,17 +191,11 @@ if (ENABLE_GPU) DESTINATION ${INSTALL_PY_DIR}/../ COMPONENT mindspore ) - if (EXISTS ${incubator_tvm_gpu_ROOT}) - file(GLOB_RECURSE GLOG_LIB_LIST ${incubator_tvm_gpu_LIBPATH}/lib*) - install( - FILES ${GLOG_LIB_LIST} - DESTINATION ${INSTALL_LIB_DIR} - COMPONENT mindspore - ) + if (EXISTS ${CMAKE_BINARY_DIR}/incubator-tvm) install( DIRECTORY - ${incubator_tvm_gpu_ROOT}/topi/python/topi - ${incubator_tvm_gpu_ROOT}/python/tvm + ${CMAKE_BINARY_DIR}/incubator-tvm/topi/python/topi + ${CMAKE_BINARY_DIR}/incubator-tvm/python/tvm DESTINATION ${INSTALL_PY_DIR}/../_akg COMPONENT mindspore ) diff --git a/cmake/utils.cmake b/cmake/utils.cmake index 501522a44b..894a0de1b8 100644 --- a/cmake/utils.cmake +++ b/cmake/utils.cmake @@ -206,7 +206,7 @@ function(mindspore_add_pkg pkg_name ) set(options ) set(oneValueArgs URL MD5 GIT_REPOSITORY GIT_TAG VER EXE DIR HEAD_ONLY CMAKE_PATH RELEASE LIB_PATH) - set(multiValueArgs CMAKE_OPTION LIBS PRE_CONFIGURE_COMMAND CONFIGURE_COMMAND BUILD_OPTION INSTALL_INCS INSTALL_LIBS PATCHES SUBMODULES SOURCEMODULES) + set(multiValueArgs CMAKE_OPTION LIBS PRE_CONFIGURE_COMMAND CONFIGURE_COMMAND BUILD_OPTION INSTALL_INCS INSTALL_LIBS PATCHES) cmake_parse_arguments(PKG "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN} ) if (NOT PKG_LIB_PATH) @@ -270,21 +270,11 @@ function(mindspore_add_pkg pkg_name ) endif () if (NOT PKG_DIR) - if (PKG_GIT_REPOSITORY) - __download_pkg_with_git(${pkg_name} ${PKG_GIT_REPOSITORY} ${PKG_GIT_TAG} ${PKG_MD5}) - else() + if (PKG_GIT_REPOSITORY) + __download_pkg_with_git(${pkg_name} ${PKG_GIT_REPOSITORY} ${PKG_GIT_TAG} ${PKG_MD5}) + else() __download_pkg(${pkg_name} ${PKG_URL} ${PKG_MD5}) - endif() - foreach(_SUBMODULE_FILE ${PKG_SUBMODULES}) - STRING( REGEX REPLACE "(.+)_(.+)" "\\1" _SUBMODEPATH ${_SUBMODULE_FILE}) - STRING( REGEX REPLACE "(.+)/(.+)" "\\2" _SUBMODENAME ${_SUBMODEPATH}) - file(GLOB ${pkg_name}_INSTALL_SUBMODULE ${_SUBMODULE_FILE}/*) - file(COPY ${${pkg_name}_INSTALL_SUBMODULE} DESTINATION ${${pkg_name}_SOURCE_DIR}/3rdparty/${_SUBMODENAME}) - endforeach (_SUBMODULE_FILE) - foreach(_SOURCE_DIR ${PKG_SOURCEMODULES}) - file(GLOB ${pkg_name}_INSTALL_SOURCE ${${pkg_name}_SOURCE_DIR}/${_SOURCE_DIR}/*) - file(COPY ${${pkg_name}_INSTALL_SOURCE} DESTINATION ${${pkg_name}_BASE_DIR}/${_SOURCE_DIR}/) - endforeach (_SUBMODULE_FILE) + endif() else() set(${pkg_name}_SOURCE_DIR ${PKG_DIR}) endif () diff --git a/mindspore/ccsrc/CMakeLists.txt b/mindspore/ccsrc/CMakeLists.txt index 8c33b9051c..9b615b0dad 100644 --- a/mindspore/ccsrc/CMakeLists.txt +++ b/mindspore/ccsrc/CMakeLists.txt @@ -395,7 +395,114 @@ if(USE_GLOG) endif() if(ENABLE_GPU) - target_link_libraries(_c_expression PRIVATE mindspore::tvm) + execute_process(COMMAND bash ${CMAKE_SOURCE_DIR}/third_party/apply_patches.sh + ${CMAKE_BINARY_DIR} + ${dlpack_DIRPATH} + ${dmlc_core_DIRPATH} + ${rang_DIRPATH} + ${incubator_tvm_gpu_DIRPATH}) + set(TVM_DIR "${CMAKE_BINARY_DIR}/incubator-tvm") + # Utility functions + include(${TVM_DIR}/cmake/util/Util.cmake) + include(${TVM_DIR}/cmake/util/FindCUDA.cmake) + + # include directories + include_directories(AFTER "${TVM_DIR}/include") + include_directories(AFTER "${TVM_DIR}/src") + include_directories(AFTER "${TVM_DIR}") + include_directories(AFTER "${TVM_DIR}/src/schedule") + + include_directories(AFTER "${TVM_DIR}/3rdparty/dmlc-core/include") + include_directories(AFTER "${TVM_DIR}/3rdparty/dlpack/include") + include_directories(AFTER "${TVM_DIR}/3rdparty/compiler-rt") + include_directories(AFTER "${TVM_DIR}/3rdparty/rang/include") + + # lib contain dlopen and dlclose + set(TVM_RUNTIME_LINKER_LIBS ${CMAKE_DL_LIBS}) + + # add source group + file(GLOB_RECURSE GROUP_SOURCE "${TVM_DIR}/src/*.cc" "src/*.cc") + file(GLOB_RECURSE GROUP_INCLUDE "${TVM_DIR}/src/*.h" + "${TVM_DIR}/include/*.h" "src/*.h" "include/*.h") + assign_source_group("Source" ${GROUP_SOURCE}) + assign_source_group("Include" ${GROUP_INCLUDE}) + + file(GLOB COMPILER_SRCS + "pre_activate/gpu/*.cc" + ${TVM_DIR}/src/api/*.cc + ${TVM_DIR}/src/arithmetic/*.cc + ${TVM_DIR}/src/autotvm/*.cc + ${TVM_DIR}/src/codegen/*.cc + ${TVM_DIR}/src/lang/*.cc + ${TVM_DIR}/src/pass/*.cc + ${TVM_DIR}/src/op/*.cc + ${TVM_DIR}/src/node/*.cc + ${TVM_DIR}/src/schedule/*.cc + ${TVM_DIR}/src/runtime/*.cc + ${TVM_DIR}/src/runtime/vm/*.cc + ${TVM_DIR}/src/runtime/vm/profiler/*.cc + ${TVM_DIR}/src/codegen/stackvm/*.cc) + + file(GLOB_RECURSE RELAY_SRCS ${TVM_DIR}/src/relay/*.cc) + list(APPEND COMPILER_SRCS ${RELAY_SRCS}) + + file(GLOB DATATYPE_SRCS ${TVM_DIR}/src/codegen/datatype/*.cc) + list(APPEND COMPILER_SRCS ${DATATYPE_SRCS}) + + file(GLOB COMPILER_VERILOG_SRCS ${TVM_DIR}/src/codegen/verilog/*.cc) + list(APPEND COMPILER_SRCS ${COMPILER_VERILOG_SRCS}) + + file(GLOB TOPI_SRCS ${TVM_DIR}/topi/src/*.cc) + + file(GLOB RUNTIME_SRCS + ${TVM_DIR}/src/runtime/*.cc + ${TVM_DIR}/src/runtime/vm/*.cc + ${TVM_DIR}/src/runtime/stub/*.cc + ${TVM_DIR}/src/runtime/stackvm/*.cc) + + + file(GLOB COMPILER_OFF_SRCS + ${TVM_DIR}/src/codegen/opt/build_*_off.cc) + set(USE_CUDA "OFF") + if(ENABLE_GPU) + list(REMOVE_ITEM COMPILER_OFF_SRCS + ${TVM_DIR}/src/codegen/opt/build_cuda_off.cc) + set(USE_CUDA "ON") + endif() + list(APPEND COMPILER_SRCS ${COMPILER_OFF_SRCS}) + # Module rules + include(${TVM_DIR}/cmake/modules/CUDA.cmake) + + set(CMAKE_C_FLAGS_AKG -pipe -Wall -fPIC -fstack-protector-all) + set(CMAKE_C_FLAGS_AKG ${CMAKE_C_FLAGS_AKG} -Wl,-z,relro,-z,now,-z,noexecstack) + + set(CMAKE_CXX_FLAGS_AKG -std=c++11 -pipe -Wall -fPIC -fstack-protector-all) + set(CMAKE_CXX_FLAGS_AKG ${CMAKE_CXX_FLAGS_AKG} -Wl,-z,relro,-z,now,-z,noexecstack) + + if("${CMAKE_BUILD_TYPE}" STREQUAL "Debug") + message("-- Build in Debug mode") + set(CMAKE_C_FLAGS_AKG ${CMAKE_C_FLAGS_AKG} -O0 -g -rdynamic) + set(CMAKE_CXX_FLAGS_AKG ${CMAKE_CXX_FLAGS_AKG} -O0 -g -rdynamic) + else() + message("-- Build in Release mode") + set(CMAKE_C_FLAGS_AKG ${CMAKE_C_FLAGS_AKG} -O2 -Werror) + set(CMAKE_CXX_FLAGS_AKG ${CMAKE_CXX_FLAGS_AKG} -O2 -Werror) + endif() + if(CMAKE_CXX_COMPILER_ID MATCHES "GNU" AND CMAKE_CXX_COMPILER_VERSION + VERSION_GREATER 7.0) + set(CMAKE_CXX_FLAGS_AKG ${CMAKE_CXX_FLAGS_AKG} -faligned-new) + endif() + + add_library(akg OBJECT ${COMPILER_SRCS} ${RUNTIME_SRCS} ${TOPI_SRCS}) + + target_link_libraries(akg ${TVM_LINKER_LIBS} ${TVM_RUNTIME_LINKER_LIBS}) + target_compile_options(akg PRIVATE + $<$:${CMAKE_C_FLAGS_AKG}> + $<$:${CMAKE_CXX_FLAGS_AKG}>) + target_include_directories(akg PRIVATE "${TVM_DIR}/topi/include") + + add_dependencies(_c_expression akg) + target_link_libraries(_c_expression PRIVATE akg) endif() if(ENABLE_DUMP_PROTO) From c000fb2f34c6604e38606bb6b61fe94404350016 Mon Sep 17 00:00:00 2001 From: VectorSL Date: Tue, 14 Apr 2020 19:47:24 +0800 Subject: [PATCH 017/142] gpu add float_status kernel --- .../kernel/gpu/cuda_impl/float_status_impl.cu | 138 ++++++++++++++++++ .../gpu/cuda_impl/float_status_impl.cuh | 28 ++++ .../gpu/math/float_status_gpu_kernel.cc | 38 +++++ .../kernel/gpu/math/float_status_gpu_kernel.h | 130 +++++++++++++++++ tests/st/ops/gpu/test_float_status_op.py | 118 +++++++++++++++ 5 files changed, 452 insertions(+) create mode 100644 mindspore/ccsrc/kernel/gpu/cuda_impl/float_status_impl.cu create mode 100644 mindspore/ccsrc/kernel/gpu/cuda_impl/float_status_impl.cuh create mode 100644 mindspore/ccsrc/kernel/gpu/math/float_status_gpu_kernel.cc create mode 100644 mindspore/ccsrc/kernel/gpu/math/float_status_gpu_kernel.h create mode 100644 tests/st/ops/gpu/test_float_status_op.py diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/float_status_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/float_status_impl.cu new file mode 100644 index 0000000000..c2fd5ecd70 --- /dev/null +++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/float_status_impl.cu @@ -0,0 +1,138 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "include/cuda_runtime.h" +#include "kernel/gpu/cuda_impl/float_status_impl.cuh" + +template +__global__ void IsNan(const size_t size, const T* input, bool* out) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { + if (isnan(input[pos])) { + out[pos] = true; + } else { + out[pos] = false; + } + } + return; +} +template <> +__global__ void IsNan(const size_t size, const half* input, bool* out) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { + if (__hisnan(input[pos])) { + out[pos] = true; + } else { + out[pos] = false; + } + } + return; +} + +template +__global__ void IsInf(const size_t size, const T* input, bool* out) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { + if (isinf(input[pos]) != 0) { + out[pos] = true; + } else { + out[pos] = false; + } + } + return; +} +template <> +__global__ void IsInf(const size_t size, const half* input, bool* out) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { + if (__hisinf(input[pos]) != 0) { + out[pos] = true; + } else { + out[pos] = false; + } + } + return; +} + +template +__global__ void IsFinite(const size_t size, const T* input, bool* out) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { + if (isinf(input[pos]) == 0 && !isnan(input[pos])) { + out[pos] = true; + } else { + out[pos] = false; + } + } + return; +} +template <> +__global__ void IsFinite(const size_t size, const half* input, bool* out) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { + if (__hisinf(input[pos]) == 0 && !__hisnan(input[pos])) { + out[pos] = true; + } else { + out[pos] = false; + } + } + return; +} + +template +__global__ void FloatStatus(const size_t size, const T* input, T* out) { + out[0] = 0; + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { + if (isinf(input[pos]) != 0 || isnan(input[pos])) { + out[0] = 1; + } + } + return; +} +template <> +__global__ void FloatStatus(const size_t size, const half* input, half* out) { + out[0] = 0; + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { + if (__hisinf(input[pos]) != 0 || __hisnan(input[pos])) { + out[0] = 1; + } + } + return; +} + +template +void CalFloatStatus(const size_t size, const T* input, T* output, cudaStream_t cuda_stream) { + FloatStatus<<>>(size, input, output); + return; +} +template +void CalIsNan(const size_t size, const T* input, bool* output, cudaStream_t cuda_stream) { + IsNan<<>>(size, input, output); + return; +} +template +void CalIsInf(const size_t size, const T* input, bool* output, cudaStream_t cuda_stream) { + IsInf<<>>(size, input, output); + return; +} +template +void CalIsFinite(const size_t size, const T* input, bool* output, cudaStream_t cuda_stream) { + IsFinite<<>>(size, input, output); + return; +} + +template void CalFloatStatus(const size_t size, const float* input, float* output, cudaStream_t cuda_stream); +template void CalFloatStatus(const size_t size, const half* input, half* output, cudaStream_t cuda_stream); +template void CalIsInf(const size_t size, const float* input, bool* output, cudaStream_t cuda_stream); +template void CalIsInf(const size_t size, const half* input, bool* output, cudaStream_t cuda_stream); +template void CalIsNan(const size_t size, const float* input, bool* output, cudaStream_t cuda_stream); +template void CalIsNan(const size_t size, const half* input, bool* output, cudaStream_t cuda_stream); +template void CalIsFinite(const size_t size, const float* input, bool* output, cudaStream_t cuda_stream); +template void CalIsFinite(const size_t size, const half* input, bool* output, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/float_status_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/float_status_impl.cuh new file mode 100644 index 0000000000..da488ff937 --- /dev/null +++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/float_status_impl.cuh @@ -0,0 +1,28 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_FLOATSTATUS_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_FLOATSTATUS_H_ +#include "device/gpu/cuda_common.h" +template +void CalFloatStatus(const size_t size, const T *input, T *output, cudaStream_t stream); +template +void CalIsNan(const size_t size, const T *input, bool *output, cudaStream_t stream); +template +void CalIsInf(const size_t size, const T *input, bool *output, cudaStream_t stream); +template +void CalIsFinite(const size_t size, const T *input, bool *output, cudaStream_t stream); +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_FLOATSTATUS_H_ diff --git a/mindspore/ccsrc/kernel/gpu/math/float_status_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/math/float_status_gpu_kernel.cc new file mode 100644 index 0000000000..374644eaf5 --- /dev/null +++ b/mindspore/ccsrc/kernel/gpu/math/float_status_gpu_kernel.cc @@ -0,0 +1,38 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "kernel/gpu/math/float_status_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(FloatStatus, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + FloatStatusGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(FloatStatus, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + FloatStatusGpuKernel, half) +MS_REG_GPU_KERNEL_ONE(IsInf, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), + FloatStatusGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(IsInf, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool), + FloatStatusGpuKernel, half) +MS_REG_GPU_KERNEL_ONE(IsNan, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), + FloatStatusGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(IsNan, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool), + FloatStatusGpuKernel, half) +MS_REG_GPU_KERNEL_ONE(IsFinite, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), + FloatStatusGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(IsFinite, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool), + FloatStatusGpuKernel, half) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/math/float_status_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/math/float_status_gpu_kernel.h new file mode 100644 index 0000000000..bdd93d5d54 --- /dev/null +++ b/mindspore/ccsrc/kernel/gpu/math/float_status_gpu_kernel.h @@ -0,0 +1,130 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_FLOAT_STATUS_GPU_KERNEL_H +#define MINDSPORE_CCSRC_KERNEL_GPU_FLOAT_STATUS_GPU_KERNEL_H + +#include +#include +#include +#include +#include "kernel/gpu/gpu_kernel.h" +#include "kernel/gpu/gpu_kernel_factory.h" +#include "kernel/gpu/cuda_impl/float_status_impl.cuh" + +namespace mindspore { +namespace kernel { +enum Optype { OP_STATUS = 0, OP_INF, OP_NAN, OP_FINITE, OP_INVALID = 255 }; +static const std::map kOpTypeMap = { + {"FloatStatus", OP_STATUS}, {"IsInf", OP_INF}, {"IsNan", OP_NAN}, {"IsFinite", OP_FINITE}}; +template +class FloatStatusGpuKernel : public GpuKernel { + public: + FloatStatusGpuKernel() : kernel_name_(OP_INVALID), input_size_(0), output_size_(0) {} + ~FloatStatusGpuKernel() override = default; + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, uintptr_t stream_ptr) override { + T *input = GetDeviceAddress(inputs, 0); + + switch (kernel_name_) { + case OP_STATUS: { + T *output = GetDeviceAddress(outputs, 0); + CalFloatStatus(input_size_ / sizeof(T), input, output, reinterpret_cast(stream_ptr)); + break; + } + case OP_INF: { + bool *output = GetDeviceAddress(outputs, 0); + CalIsInf(input_size_ / sizeof(T), input, output, reinterpret_cast(stream_ptr)); + break; + } + case OP_NAN: { + bool *output = GetDeviceAddress(outputs, 0); + CalIsNan(input_size_ / sizeof(T), input, output, reinterpret_cast(stream_ptr)); + break; + } + case OP_FINITE: { + bool *output = GetDeviceAddress(outputs, 0); + CalIsFinite(input_size_ / sizeof(T), input, output, reinterpret_cast(stream_ptr)); + break; + } + default: { + MS_LOG(EXCEPTION) << "FloatStatus type " << kernel_name_ << " is not supported."; + } + } + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + if (!CheckParam(kernel_node)) { + return false; + } + auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + input_size_ = sizeof(T); + for (size_t x : shape) { + input_size_ = input_size_ * x; + } + auto kernel_name = AnfAlgo::GetCNodeName(kernel_node); + auto iter = kOpTypeMap.find(kernel_name); + if (iter == kOpTypeMap.end()) { + MS_LOG(EXCEPTION) << "FloatStatus kernel " << kernel_name << " is not supported."; + } else { + kernel_name_ = iter->second; + } + if (kernel_name_ == OP_STATUS) { + output_size_ = sizeof(T); + } else { + output_size_ = input_size_ / sizeof(T) * sizeof(bool); + } + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(input_size_); + output_size_list_.push_back(output_size_); + } + + private: + bool CheckParam(const CNodePtr &kernel_node) { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 1) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but FloatStatusGpuKernel needs 1 output."; + return false; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(ERROR) << "Output number is " << output_num << ", but FloatStatusGpuKernel needs 1 output."; + return false; + } + return true; + } + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; + Optype kernel_name_; + size_t input_size_; + size_t output_size_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_FLOAT_STATUS_GPU_KERNEL_H diff --git a/tests/st/ops/gpu/test_float_status_op.py b/tests/st/ops/gpu/test_float_status_op.py new file mode 100644 index 0000000000..09fc90feaa --- /dev/null +++ b/tests/st/ops/gpu/test_float_status_op.py @@ -0,0 +1,118 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import pytest +from mindspore import Tensor +from mindspore.ops import operations as P +import mindspore.nn as nn +import numpy as np +import mindspore.context as context + +class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.status = P.FloatStatus() + + def construct(self, x): + return self.status(x) + +class Netnan(nn.Cell): + def __init__(self): + super(Netnan, self).__init__() + self.isnan = P.IsNan() + + def construct(self, x): + return self.isnan(x) + +class Netinf(nn.Cell): + def __init__(self): + super(Netinf, self).__init__() + self.isinf = P.IsInf() + + def construct(self, x): + return self.isinf(x) + +class Netfinite(nn.Cell): + def __init__(self): + super(Netfinite, self).__init__() + self.isfinite = P.IsFinite() + + def construct(self, x): + return self.isfinite(x) + +context.set_context(mode=context.GRAPH_MODE, device_target="GPU") +x1 = np.array([[1.2, 2, np.nan, 88]]).astype(np.float32) +x2 = np.array([[np.inf, 1, 88.0, 0]]).astype(np.float32) +x3 = np.array([[1, 2], [3, 4], [5.0, 88.0]]).astype(np.float32) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_status(): + ms_status = Net(); + output1 = ms_status(Tensor(x1)) + output2 = ms_status(Tensor(x2)) + output3 = ms_status(Tensor(x3)) + expect1 = 1 + expect2 = 1 + expect3 = 0 + assert output1.asnumpy()[0] == expect1 + assert output2.asnumpy()[0] == expect2 + assert output3.asnumpy()[0] == expect3 + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_nan(): + ms_isnan = Netnan(); + output1 = ms_isnan(Tensor(x1)) + output2 = ms_isnan(Tensor(x2)) + output3 = ms_isnan(Tensor(x3)) + expect1 = [[False, False, True, False]] + expect2 = [[False, False, False, False]] + expect3 = [[False, False], [False, False], [False, False]] + assert (output1.asnumpy() == expect1).all() + assert (output2.asnumpy() == expect2).all() + assert (output3.asnumpy() == expect3).all() + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_inf(): + ms_isinf = Netinf(); + output1 = ms_isinf(Tensor(x1)) + output2 = ms_isinf(Tensor(x2)) + output3 = ms_isinf(Tensor(x3)) + expect1 = [[False, False, False, False]] + expect2 = [[True, False, False, False]] + expect3 = [[False, False], [False, False], [False, False]] + assert (output1.asnumpy() == expect1).all() + assert (output2.asnumpy() == expect2).all() + assert (output3.asnumpy() == expect3).all() + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_finite(): + ms_isfinite = Netfinite(); + output1 = ms_isfinite(Tensor(x1)) + output2 = ms_isfinite(Tensor(x2)) + output3 = ms_isfinite(Tensor(x3)) + expect1 = [[True, True, False, True]] + expect2 = [[False, True, True, True]] + expect3 = [[True, True], [True, True], [True, True]] + assert (output1.asnumpy() == expect1).all() + assert (output2.asnumpy() == expect2).all() + assert (output3.asnumpy() == expect3).all() From c5cfb09e6683c7145fe2534d14efdabbc0274a70 Mon Sep 17 00:00:00 2001 From: ms_yan <6576637+ms_yan@user.noreply.gitee.com> Date: Mon, 20 Apr 2020 19:45:46 +0800 Subject: [PATCH 018/142] Repair some MS_LOG problem --- mindspore/ccsrc/dataset/engine/datasetops/take_op.cc | 10 ++++++---- tests/ut/cpp/dataset/take_op_test.cc | 6 +++--- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/mindspore/ccsrc/dataset/engine/datasetops/take_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/take_op.cc index d9625b6c26..5d7df58153 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/take_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/take_op.cc @@ -67,7 +67,7 @@ Status TakeOp::GetNextBuffer(std::unique_ptr *p_buffer, int32_t work bool last_repeat = !BitTest(op_ctrl_flags_, kDeOpRepeated) || BitTest(op_ctrl_flags_, kDeOpLastRepeat); if (take_count_ == max_takes_) { if (state_ == OpState::kDeOpRunning) { - MS_LOG(INFO) << "meet max count and push-back eoe buffer."; + MS_LOG(DEBUG) << "Meet max count and push-back eoe buffer."; auto eoe_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOE); *p_buffer = std::move(eoe_buffer); state_ = OpState::kDeOpIdle; @@ -80,11 +80,13 @@ Status TakeOp::GetNextBuffer(std::unique_ptr *p_buffer, int32_t work RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true)); } } - } else { - MS_LOG(INFO) << "meet max count and push-back eof buffer."; + } else if (state_ == OpState::kDeOpIdle) { + MS_LOG(DEBUG) << "Meet max count and push-back eof buffer."; auto eof_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOF); *p_buffer = std::move(eof_buffer); take_count_ = 0; + } else { + MS_LOG(WARNING) << "Invalid OpState: " << state_; } return Status::OK(); } @@ -116,7 +118,7 @@ Status TakeOp::FillBuffer(std::unique_ptr *buffer, std::unique_ptr new_tensor_table = std::make_unique(); while (take_count_ < max_takes_) { TensorRow new_row; diff --git a/tests/ut/cpp/dataset/take_op_test.cc b/tests/ut/cpp/dataset/take_op_test.cc index 7f8508de20..b7be066d6c 100644 --- a/tests/ut/cpp/dataset/take_op_test.cc +++ b/tests/ut/cpp/dataset/take_op_test.cc @@ -69,7 +69,7 @@ TEST_F(MindDataTestTakeOp, TestTakeProject) { rc = my_tree->AssignRoot(my_take_op); ASSERT_TRUE(rc.IsOk()); - MS_LOG(INFO) << "Launching tree and begin iteration."; + MS_LOG(DEBUG) << "Launching tree and begin iteration."; rc = my_tree->Prepare(); ASSERT_TRUE(rc.IsOk()); @@ -85,13 +85,13 @@ TEST_F(MindDataTestTakeOp, TestTakeProject) { int row_count = 0; while (!tensor_list.empty()) { - MS_LOG(INFO) << "Row display for row #: " << row_count << "."; + MS_LOG(DEBUG) << "Row display for row #: " << row_count << "."; // Display the tensor by calling the printer on it for (int i = 0; i < tensor_list.size(); i++) { std::ostringstream ss; ss << "(" << tensor_list[i] << "): " << *tensor_list[i] << std::endl; - MS_LOG(INFO) << "Tensor print: " << ss.str() << "."; + MS_LOG(DEBUG) << "Tensor print: " << ss.str() << "."; } rc = di.FetchNextTensorRow(&tensor_list); From 7c233a57fa285842054612eb83d612adb7c05e96 Mon Sep 17 00:00:00 2001 From: buxue Date: Mon, 20 Apr 2020 17:40:46 +0800 Subject: [PATCH 019/142] support python func print and != for list with none --- mindspore/_extends/parse/resources.py | 1 + mindspore/_extends/parse/trope.py | 4 +- .../composite/multitype_ops/not_equal_impl.py | 37 +++++++++++++++++-- mindspore/ops/functional.py | 2 +- mindspore/ops/operations/_grad_ops.py | 1 + .../ut/python/pipeline/parse/test_operator.py | 6 ++- tests/vm_impl/nn_ops_vm_impl.py | 2 - 7 files changed, 43 insertions(+), 10 deletions(-) diff --git a/mindspore/_extends/parse/resources.py b/mindspore/_extends/parse/resources.py index c2c2716697..7178cd2634 100644 --- a/mindspore/_extends/parse/resources.py +++ b/mindspore/_extends/parse/resources.py @@ -114,6 +114,7 @@ convert_object_map = { T.map: C.HyperMap(), T.partial: F.partial, T.zip: C.zip_operation, + T.print: F.print_, # custom define operation T.iter: M.ms_iter, diff --git a/mindspore/_extends/parse/trope.py b/mindspore/_extends/parse/trope.py index 9f8f67fba5..7b40adcd16 100644 --- a/mindspore/_extends/parse/trope.py +++ b/mindspore/_extends/parse/trope.py @@ -27,7 +27,7 @@ from operator import ( # noqa # support system function call from builtins import ( # noqa - bool, getattr, setattr, len, iter, next, pow, range, map, zip + bool, getattr, setattr, len, iter, next, pow, range, map, zip, print ) # support functools @@ -44,7 +44,7 @@ __all__ = ['add', 'sub', 'mul', 'truediv', 'floordiv', 'mod', 'eq', 'ne', 'lt', 'not_', 'and_', 'or_', 'xor', 'lshift', 'rshift', 'invert', 'is_', 'is_not', 'contains', 'matmul', 'getitem', 'setitem', 'bool', 'getattr', 'setattr', 'len', 'iter', 'next', 'pow', 'range', 'map', 'zip', - 'partial', + 'partial', 'print', 'exp', 'log', 'sin', 'cos', 'tan'] diff --git a/mindspore/ops/composite/multitype_ops/not_equal_impl.py b/mindspore/ops/composite/multitype_ops/not_equal_impl.py index de099a2b8f..7196f370cb 100644 --- a/mindspore/ops/composite/multitype_ops/not_equal_impl.py +++ b/mindspore/ops/composite/multitype_ops/not_equal_impl.py @@ -132,7 +132,7 @@ def _none_not_equal_scalar(x, y): @not_equal.register("Tuple", "Tuple") -def _euqal_tuple(x, y): +def _not_euqal_tuple(x, y): """ Determine if two tuples are not equal by element. @@ -147,7 +147,7 @@ def _euqal_tuple(x, y): @not_equal.register("List", "List") -def _euqal_list(x, y): +def _not_euqal_list(x, y): """ Determine if two lists are not equal by element. @@ -162,7 +162,7 @@ def _euqal_list(x, y): @not_equal.register("Tuple", "None") -def _tuple_euqal_none(x, y): +def _tuple_not_euqal_none(x, y): """ Determine if tuple element not equals none element. @@ -190,6 +190,7 @@ def _none_not_equal_tuple(x, y): """ return True + @not_equal.register("Tensor", "Number") @not_equal.register("Number", "Tensor") @not_equal.register("Tensor", "Tensor") @@ -235,3 +236,33 @@ def _none_not_equal_tensor(x, y): bool, return True. """ return True + + +@not_equal.register("List", "None") +def _list_not_equal_none(x, y): + """ + Determine if list not equal none. + + Args: + x (list): The first input which is a list. + y (none): The second input which is none. + + Returns: + bool, return true. + """ + return True + + +@not_equal.register("None", "List") +def _none_not_equal_list(x, y): + """ + Determine if none not equal list. + + Args: + x (none): The first input which is none. + y (list): The second input which is a list. + + Returns: + bool, return true. + """ + return True diff --git a/mindspore/ops/functional.py b/mindspore/ops/functional.py index 0ed750beb1..d94ef3a11c 100644 --- a/mindspore/ops/functional.py +++ b/mindspore/ops/functional.py @@ -66,7 +66,7 @@ scalar_to_array = P.ScalarToArray() scalar_to_tensor = P.ScalarToTensor() tuple_to_array = P.TupleToArray() scalar_cast = P.ScalarCast() - +print_ = P.Print() tuple_setitem = Primitive('tuple_setitem') tuple_getitem = Primitive('tuple_getitem') diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py index 48d1a2a89c..9670ddd86c 100644 --- a/mindspore/ops/operations/_grad_ops.py +++ b/mindspore/ops/operations/_grad_ops.py @@ -108,6 +108,7 @@ class BinaryCrossEntropyGrad(PrimitiveWithInfer): validator.check_two_types_same('x_type', x_type, 'weight_type', weight_type) return x_type + class ConcatOffset(PrimitiveWithInfer): """primitive for computing Concat's gradient.""" diff --git a/tests/ut/python/pipeline/parse/test_operator.py b/tests/ut/python/pipeline/parse/test_operator.py index a3c5f7e422..6ae02fa96b 100644 --- a/tests/ut/python/pipeline/parse/test_operator.py +++ b/tests/ut/python/pipeline/parse/test_operator.py @@ -160,8 +160,10 @@ def test_ops(): ret_floor = p // q + q // p ret = ret_pow + ret_mod + ret_floor if self.int > self.float: - if self.str_a + self.str_b == "helloworld": - return ret + if [1, 2, 3] != None: + if self.str_a + self.str_b == "helloworld": + print("hello world") + return ret return x net = OpsNet(9, 2) diff --git a/tests/vm_impl/nn_ops_vm_impl.py b/tests/vm_impl/nn_ops_vm_impl.py index fc1fa95024..8794acbbd2 100644 --- a/tests/vm_impl/nn_ops_vm_impl.py +++ b/tests/vm_impl/nn_ops_vm_impl.py @@ -151,8 +151,6 @@ def vm_impl_max_pool_grad_with_argmax(self): """Generate vm_impl function for MaxPoolGradWithArgmax""" def vm_impl(x, dout, argmax): - print("buxue") - print(argmax) x = x.asnumpy() dout = dout.asnumpy() arg_max = argmax.asnumpy() From 7b99a1cb2a8cffa4ee7dd4f0c6c42551b38a3929 Mon Sep 17 00:00:00 2001 From: "wangnan39@huawei.com" Date: Mon, 20 Apr 2020 20:10:21 +0800 Subject: [PATCH 020/142] fix bug in model predict and eval --- mindspore/train/model.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/mindspore/train/model.py b/mindspore/train/model.py index 46e4f421f7..3391cc7f3b 100755 --- a/mindspore/train/model.py +++ b/mindspore/train/model.py @@ -108,6 +108,7 @@ class Model: self._train_network = self._build_train_network() self._build_eval_network(metrics, eval_network, eval_indexes) + self._build_predict_network() def _check_kwargs(self, kwargs): for arg in kwargs: @@ -153,6 +154,12 @@ class Model: self._eval_network = nn.WithEvalCell(self._network, self._loss_fn) self._eval_indexes = [0, 1, 2] + def _build_predict_network(self): + """Build the network for prediction.""" + self._predict_network = self._network + if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): + self._predict_network = _VirtualDatasetCell(self._network) + def _clear_metrics(self): """Clear metrics local values.""" for metric in self._metric_fns.values(): @@ -466,6 +473,7 @@ class Model: dataset_helper = DatasetHelper(valid_dataset, dataset_sink_mode=False) for next_element in dataset_helper: + cb_params.cur_step_num += 1 list_callback.step_begin(run_context) outputs = self._eval_network(*next_element) cb_params.net_outputs = outputs @@ -543,12 +551,9 @@ class Model: >>> model = Model(Net()) >>> model.predict(input_data) """ - if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL): - self._network = _VirtualDatasetCell(self._network) - - self._network.set_train(False) + self._predict_network.set_train(False) check_input_data(*predict_data, data_class=Tensor) - result = self._network(*predict_data) + result = self._predict_network(*predict_data) check_output_data(result) return result From 4f8fa79f33f1db91a1d68db0b8c3cf5cf456ac26 Mon Sep 17 00:00:00 2001 From: buxue Date: Mon, 20 Apr 2020 20:52:05 +0800 Subject: [PATCH 021/142] fix attribute mapping when docking open source operators --- mindspore/ops/_op_impl/tbe/conv2d_backprop_filter.py | 2 +- mindspore/ops/_op_impl/tbe/conv2d_backprop_input.py | 2 +- mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_filter.py | 2 +- mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_input.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mindspore/ops/_op_impl/tbe/conv2d_backprop_filter.py b/mindspore/ops/_op_impl/tbe/conv2d_backprop_filter.py index e32e99d888..04b55bb2a3 100644 --- a/mindspore/ops/_op_impl/tbe/conv2d_backprop_filter.py +++ b/mindspore/ops/_op_impl/tbe/conv2d_backprop_filter.py @@ -25,7 +25,7 @@ conv2d_backprop_filter_op_info = TBERegOp("Conv2DBackpropFilter") \ .partial_flag(True) \ .attr("filter_sizes", "required", "listInt", "all") \ .attr("stride", "required", "listInt", "all") \ - .attr("pad_mode", "required", "str", "all") \ + .attr("pad_list", "required", "listInt", "all") \ .attr("dilation", "required", "listInt", "all") \ .input(0, "out_backprop", False, "required", "all") \ .input(1, "x", False, "required", "all") \ diff --git a/mindspore/ops/_op_impl/tbe/conv2d_backprop_input.py b/mindspore/ops/_op_impl/tbe/conv2d_backprop_input.py index 2c1dd6aea2..7756cb3ae4 100644 --- a/mindspore/ops/_op_impl/tbe/conv2d_backprop_input.py +++ b/mindspore/ops/_op_impl/tbe/conv2d_backprop_input.py @@ -25,7 +25,7 @@ conv2d_backprop_input_op_info = TBERegOp("Conv2DBackpropInput") \ .partial_flag(True) \ .attr("input_sizes", "required", "listInt", "all") \ .attr("stride", "required", "listInt", "all") \ - .attr("pad_mode", "required", "str", "all") \ + .attr("pad_list", "required", "listInt", "all") \ .attr("dilation", "required", "listInt", "all") \ .input(0, "out_backprop", False, "required", "all") \ .input(1, "filter", False, "required", "all") \ diff --git a/mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_filter.py b/mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_filter.py index c19a311009..f4d8069b12 100644 --- a/mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_filter.py +++ b/mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_filter.py @@ -26,7 +26,7 @@ depthwise_conv2d_backprop_filter_op_info = TBERegOp("DepthwiseConv2dNativeBackpr .attr("filter_size", "required", "listInt", "all") \ .attr("stride", "required", "listInt", "all") \ .attr("dilation", "required", "listInt", "all") \ - .attr("pads", "required", "str", "all") \ + .attr("pads", "required", "listInt", "all") \ .attr("data_format", "required", "str", "all") \ .input(0, "input", False, "required", "all") \ .input(1, "out_backprop", False, "required", "all") \ diff --git a/mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_input.py b/mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_input.py index 9e671f18e2..61c1406b32 100644 --- a/mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_input.py +++ b/mindspore/ops/_op_impl/tbe/depthwise_conv2d_backprop_input.py @@ -26,7 +26,7 @@ depthwise_conv2d_backprop_input_op_info = TBERegOp("DepthwiseConv2dNativeBackpro .attr("input_size", "required", "listInt", "all") \ .attr("stride", "required", "listInt", "all") \ .attr("dilation", "required", "listInt", "all") \ - .attr("pads", "required", "str", "all") \ + .attr("pads", "required", "listInt", "all") \ .attr("data_format", "required", "str", "all") \ .input(0, "filter", False, "required", "all") \ .input(1, "out_backprop", False, "required", "all") \ From 9bc2134cb70cb4f72290f6919829ce4b86de3d58 Mon Sep 17 00:00:00 2001 From: Peilin Wang Date: Thu, 16 Apr 2020 18:16:47 -0400 Subject: [PATCH 022/142] added checking of first row crc to find invalid tfrecord files addressed code review comments. added check in python layer to exclude directories and to raise an error if a pattern does not match any file fixed clang format fixed cppcheck fixed cppcheck (used std::accumulate and std::copy_if). regenerated tfrecord file to contain correct header, it was a dummy header before fixed cppcheck: added const reference for string parameter for lambdas, fixed clang format: whitespace adjustments more clang whitespace fixes... changed print to logger.info --- .../engine/datasetops/source/tf_reader_op.cc | 57 ++++++++++++++++-- mindspore/dataset/engine/datasets.py | 17 ++++-- tests/ut/cpp/dataset/tfReader_op_test.cc | 34 +++++++++++ .../dataset/testTFBert5Rows/5TFDatas.data | Bin 3865 -> 3865 bytes .../dataset/testTFBert5Rows1/5TFDatas.data | Bin 3865 -> 3865 bytes .../dataset/testTFBert5Rows2/5TFDatas.data | Bin 3865 -> 3865 bytes .../testTFTestAllTypes/invalidFile.txt | 1 + tests/ut/python/dataset/test_tfreader_op.py | 29 ++++++++- 8 files changed, 127 insertions(+), 11 deletions(-) create mode 100644 tests/ut/data/dataset/testTFTestAllTypes/invalidFile.txt diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc index 0764d7e0ad..a72be1f703 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc @@ -42,6 +42,7 @@ #include "dataset/util/status.h" #include "dataset/util/task_manager.h" #include "dataset/util/wait_post.h" +#include "utils/system/crc32c.h" namespace mindspore { namespace dataset { @@ -56,15 +57,58 @@ TFReaderOp::Builder::Builder() builder_data_schema_ = std::make_unique(); } +bool ValidateFirstRowCrc(const std::string &filename) { + std::ifstream reader; + reader.open(filename); + if (!reader) { + return false; + } + + // read data + int64_t record_length = 0; + (void)reader.read(reinterpret_cast(&record_length), static_cast(sizeof(int64_t))); + + // read crc from file + uint32_t masked_crc = 0; + (void)reader.read(reinterpret_cast(&masked_crc), static_cast(sizeof(uint32_t))); + + // generate crc from data + uint32_t generated_crc = + system::Crc32c::GetMaskCrc32cValue(reinterpret_cast(&record_length), sizeof(int64_t)); + + return masked_crc == generated_crc; +} + Status TFReaderOp::Builder::ValidateInputs() const { std::string err_msg; - err_msg += builder_num_workers_ <= 0 ? "Number of parallel workers is smaller or equal to 0\n" : ""; - if (!builder_equal_rows_per_shard_) { - err_msg += builder_dataset_files_list_.size() < static_cast(builder_num_devices_) - ? "No enough tf_file files provided\n" - : ""; + + if (builder_num_workers_ <= 0) { + err_msg += "Number of parallel workers is smaller or equal to 0\n"; + } + + if (!builder_equal_rows_per_shard_ && + builder_dataset_files_list_.size() < static_cast(builder_num_devices_)) { + err_msg += "Not enough tfrecord files provided\n"; + } + + if (builder_device_id_ >= builder_num_devices_ || builder_num_devices_ < 1) { + err_msg += "Wrong sharding configs\n"; } - err_msg += builder_device_id_ >= builder_num_devices_ || builder_num_devices_ < 1 ? "Wrong sharding configs\n" : ""; + + std::vector invalid_files(builder_dataset_files_list_.size()); + auto it = std::copy_if(builder_dataset_files_list_.begin(), builder_dataset_files_list_.end(), invalid_files.begin(), + [](const std::string &filename) { return !ValidateFirstRowCrc(filename); }); + invalid_files.resize(std::distance(invalid_files.begin(), it)); + + if (!invalid_files.empty()) { + err_msg += "The following files either cannot be opened, or are not valid tfrecord files:\n"; + + std::string accumulated_filenames = std::accumulate( + invalid_files.begin(), invalid_files.end(), std::string(""), + [](const std::string &accumulated, const std::string &next) { return accumulated + " " + next + "\n"; }); + err_msg += accumulated_filenames; + } + return err_msg.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, err_msg); } @@ -523,6 +567,7 @@ Status TFReaderOp::LoadFile(const std::string &filename, const int64_t start_off RETURN_IF_NOT_OK(LoadExample(&tf_file, &new_tensor_table, rows_read)); rows_read++; } + // ignore crc footer (void)reader.ignore(static_cast(sizeof(int32_t))); rows_total++; diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index ca717643c9..593a5c39a0 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -900,13 +900,22 @@ class SourceDataset(Dataset): List, files. """ - def flat(lists): - return list(np.array(lists).flatten()) - if not isinstance(patterns, list): patterns = [patterns] - file_list = flat([glob.glob(file, recursive=True) for file in patterns]) + file_list = [] + unmatched_patterns = [] + for pattern in patterns: + matches = [match for match in glob.glob(pattern, recursive=True) if os.path.isfile(match)] + + if matches: + file_list.extend(matches) + else: + unmatched_patterns.append(pattern) + + if unmatched_patterns: + raise ValueError("The following patterns did not match any files: ", unmatched_patterns) + if file_list: # not empty return file_list raise ValueError("The list of path names matching the patterns is empty.") diff --git a/tests/ut/cpp/dataset/tfReader_op_test.cc b/tests/ut/cpp/dataset/tfReader_op_test.cc index 5fb1f4e909..9b312296d8 100644 --- a/tests/ut/cpp/dataset/tfReader_op_test.cc +++ b/tests/ut/cpp/dataset/tfReader_op_test.cc @@ -697,3 +697,37 @@ TEST_F(MindDataTestTFReaderOp, TestTotalRowsBasic) { TFReaderOp::CountTotalRows(&total_rows, filenames, 729, true); ASSERT_EQ(total_rows, 60); } + +TEST_F(MindDataTestTFReaderOp, TestTFReaderInvalidFiles) { + // Start with an empty execution tree + auto my_tree = std::make_shared(); + + std::string valid_file = datasets_root_path_ + "/testTFTestAllTypes/test.data"; + std::string schema_file = datasets_root_path_ + "/testTFTestAllTypes/datasetSchema.json"; + std::string invalid_file = datasets_root_path_ + "/testTFTestAllTypes/invalidFile.txt"; + std::string nonexistent_file = "this/file/doesnt/exist"; + + std::shared_ptr my_tfreader_op; + TFReaderOp::Builder builder; + builder.SetDatasetFilesList({invalid_file, valid_file, schema_file}) + .SetRowsPerBuffer(16) + .SetNumWorkers(16); + + std::unique_ptr schema = std::make_unique(); + schema->LoadSchemaFile(schema_file, {}); + builder.SetDataSchema(std::move(schema)); + + Status rc = builder.Build(&my_tfreader_op); + ASSERT_TRUE(!rc.IsOk()); + + builder.SetDatasetFilesList({invalid_file, valid_file, schema_file, nonexistent_file}) + .SetRowsPerBuffer(16) + .SetNumWorkers(16); + + schema = std::make_unique(); + schema->LoadSchemaFile(schema_file, {}); + builder.SetDataSchema(std::move(schema)); + + rc = builder.Build(&my_tfreader_op); + ASSERT_TRUE(!rc.IsOk()); +} diff --git a/tests/ut/data/dataset/testTFBert5Rows/5TFDatas.data b/tests/ut/data/dataset/testTFBert5Rows/5TFDatas.data index c5b5440cffe15e2630d4e6866942fcc4dd5d69d7..f3bb23af5112f2f78c0b70118e0843e1d7769e12 100644 GIT binary patch delta 146 zcmbO!H&c%1E)xS7#4Qfd-U#Fg%v-Lw;TKep$1(TK=ATTzn1odxEiTyz6%^#Umm?7p^NS}YiwUBM{TxOZgk2zSGh5ygk#wH7ujFK8ktyRWD_(2wQ48B delta 146 zcmbO!H&c%1E)xS7WB|cNAWt9-2!0_6ZvM&ii%A$Hx)Di~YaQ!kc|M!XPHaUGNfl;< vIa~|bCeLM-+5DJ;l^Lv47Flu@NYZ`s3ucMQ4y-DZSj09@;b~-I0oecmAy_8? diff --git a/tests/ut/data/dataset/testTFBert5Rows1/5TFDatas.data b/tests/ut/data/dataset/testTFBert5Rows1/5TFDatas.data index c5b5440cffe15e2630d4e6866942fcc4dd5d69d7..f3bb23af5112f2f78c0b70118e0843e1d7769e12 100644 GIT binary patch delta 146 zcmbO!H&c%1E)xS7#4Qfd-U#Fg%v-Lw;TKep$1(TK=ATTzn1odxEiTyz6%^#Umm?7p^NS}YiwUBM{TxOZgk2zSGh5ygk#wH7ujFK8ktyRWD_(2wQ48B delta 146 zcmbO!H&c%1E)xS7WB|cNAWt9-2!0_6ZvM&ii%A$Hx)Di~YaQ!kc|M!XPHaUGNfl;< vIa~|bCeLM-+5DJ;l^Lv47Flu@NYZ`s3ucMQ4y-DZSj09@;b~-I0oecmAy_8? diff --git a/tests/ut/data/dataset/testTFBert5Rows2/5TFDatas.data b/tests/ut/data/dataset/testTFBert5Rows2/5TFDatas.data index c5b5440cffe15e2630d4e6866942fcc4dd5d69d7..f3bb23af5112f2f78c0b70118e0843e1d7769e12 100644 GIT binary patch delta 146 zcmbO!H&c%1E)xS7#4Qfd-U#Fg%v-Lw;TKep$1(TK=ATTzn1odxEiTyz6%^#Umm?7p^NS}YiwUBM{TxOZgk2zSGh5ygk#wH7ujFK8ktyRWD_(2wQ48B delta 146 zcmbO!H&c%1E)xS7WB|cNAWt9-2!0_6ZvM&ii%A$Hx)Di~YaQ!kc|M!XPHaUGNfl;< vIa~|bCeLM-+5DJ;l^Lv47Flu@NYZ`s3ucMQ4y-DZSj09@;b~-I0oecmAy_8? diff --git a/tests/ut/data/dataset/testTFTestAllTypes/invalidFile.txt b/tests/ut/data/dataset/testTFTestAllTypes/invalidFile.txt new file mode 100644 index 0000000000..3307b71672 --- /dev/null +++ b/tests/ut/data/dataset/testTFTestAllTypes/invalidFile.txt @@ -0,0 +1 @@ +this is just a text file, not a valid tfrecord file. diff --git a/tests/ut/python/dataset/test_tfreader_op.py b/tests/ut/python/dataset/test_tfreader_op.py index 6de14df34e..3add50e1cb 100644 --- a/tests/ut/python/dataset/test_tfreader_op.py +++ b/tests/ut/python/dataset/test_tfreader_op.py @@ -32,7 +32,7 @@ def test_case_tf_shape(): ds1 = ds.TFRecordDataset(FILES, schema_file) ds1 = ds1.batch(2) for data in ds1.create_dict_iterator(): - print(data) + logger.info(data) output_shape = ds1.output_shapes() assert (len(output_shape[-1]) == 1) @@ -203,6 +203,32 @@ def test_tf_record_schema_columns_list(): a = row["col_sint32"] assert "col_sint32" in str(info.value) +def test_case_invalid_files(): + valid_file = "../data/dataset/testTFTestAllTypes/test.data" + invalid_file = "../data/dataset/testTFTestAllTypes/invalidFile.txt" + files = [invalid_file, valid_file, SCHEMA_FILE] + + data = ds.TFRecordDataset(files, SCHEMA_FILE, shuffle=ds.Shuffle.FILES) + + with pytest.raises(RuntimeError) as info: + row = data.create_dict_iterator().get_next() + assert "cannot be opened" in str(info.value) + assert "not valid tfrecord files" in str(info.value) + assert valid_file not in str(info.value) + assert invalid_file in str(info.value) + assert SCHEMA_FILE in str(info.value) + + nonexistent_file = "this/file/does/not/exist" + files = [invalid_file, valid_file, SCHEMA_FILE, nonexistent_file] + + with pytest.raises(ValueError) as info: + data = ds.TFRecordDataset(files, SCHEMA_FILE, shuffle=ds.Shuffle.FILES) + assert "did not match any files" in str(info.value) + assert valid_file not in str(info.value) + assert invalid_file not in str(info.value) + assert SCHEMA_FILE not in str(info.value) + assert nonexistent_file in str(info.value) + if __name__ == '__main__': test_case_tf_shape() test_case_tf_file() @@ -212,3 +238,4 @@ if __name__ == '__main__': test_tf_record_schema() test_tf_record_shuffle() test_tf_shard_equal_rows() + test_case_invalid_files() From c705ea5e5be06f5f30a0c546b207b9c6798ee5d6 Mon Sep 17 00:00:00 2001 From: xulei2020 <“xulei83@huawei.com”> Date: Sat, 18 Apr 2020 15:19:06 +0800 Subject: [PATCH 023/142] add filterOp code --- mindspore/ccsrc/dataset/api/de_pipeline.cc | 39 +- mindspore/ccsrc/dataset/api/de_pipeline.h | 4 +- mindspore/ccsrc/dataset/core/client.h | 1 + mindspore/ccsrc/dataset/core/tensor.cc | 2 +- .../dataset/engine/datasetops/CMakeLists.txt | 1 + .../dataset/engine/datasetops/filter_op.cc | 273 ++++++++++ .../dataset/engine/datasetops/filter_op.h | 180 +++++++ mindspore/dataset/engine/datasets.py | 66 ++- mindspore/dataset/engine/iterators.py | 2 + mindspore/dataset/engine/validators.py | 20 + tests/ut/cpp/dataset/CMakeLists.txt | 2 + tests/ut/cpp/dataset/filter_op_test.cc | 53 ++ tests/ut/cpp/dataset/tensor_test.cc | 10 + tests/ut/data/dataset/declient_filter.cfg | 3 + tests/ut/python/dataset/test_filterop.py | 504 ++++++++++++++++++ tests/ut/python/dataset/test_iterator.py | 4 +- 16 files changed, 1156 insertions(+), 8 deletions(-) create mode 100644 mindspore/ccsrc/dataset/engine/datasetops/filter_op.cc create mode 100644 mindspore/ccsrc/dataset/engine/datasetops/filter_op.h create mode 100644 tests/ut/cpp/dataset/filter_op_test.cc create mode 100644 tests/ut/data/dataset/declient_filter.cfg create mode 100644 tests/ut/python/dataset/test_filterop.py diff --git a/mindspore/ccsrc/dataset/api/de_pipeline.cc b/mindspore/ccsrc/dataset/api/de_pipeline.cc index f6440710b1..a02d995147 100644 --- a/mindspore/ccsrc/dataset/api/de_pipeline.cc +++ b/mindspore/ccsrc/dataset/api/de_pipeline.cc @@ -29,6 +29,7 @@ #include "dataset/engine/datasetops/source/cifar_op.h" #include "dataset/engine/datasetops/source/celeba_op.h" #include "dataset/engine/datasetops/source/text_file_op.h" +#include "dataset/engine/datasetops/filter_op.h" #include "mindrecord/include/shard_category.h" #include "mindrecord/include/shard_sample.h" #include "mindrecord/include/shard_shuffle.h" @@ -45,6 +46,7 @@ static std::unordered_map g_parse_op_func_ = {{kStorage, &D {kShuffle, &DEPipeline::ParseShuffleOp}, {kMindrecord, &DEPipeline::ParseMindRecordOp}, {kMap, &DEPipeline::ParseMapOp}, + {kFilter, &DEPipeline::ParseFilterOp}, {kBatch, &DEPipeline::ParseBatchOp}, {kRepeat, &DEPipeline::ParseRepeatOp}, {kSkip, &DEPipeline::ParseSkipOp}, @@ -502,6 +504,41 @@ Status DEPipeline::ParseMapOp(const py::dict &args, std::shared_ptr * return Status::OK(); } +Status DEPipeline::ParseFilterOp(const py::dict &args, std::shared_ptr *ptr) { + std::shared_ptr builder = std::make_shared(); + + if (args["predicate"].is_none()) { + RETURN_STATUS_UNEXPECTED("Error: 'predicate' is not set. \n"); + } + + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "num_parallel_workers") { + (void)builder->SetNumWorkers(ToInt(value)); + } else if (key == "predicate") { + py::handle op = args["predicate"]; + if (!py::isinstance(op)) { + RETURN_STATUS_UNEXPECTED("Error: predicate is not recognised (not pyfunc)."); + } + py::function predicate_func = op.cast(); + (void)builder->SetPredicateFunc(std::move(predicate_func)); + } else if (key == "input_columns") { + std::vector in_col_names = ToStringVector(args["input_columns"]); + (void)builder->SetInColNames(in_col_names); + } else { + RETURN_STATUS_UNEXPECTED("Error: Unhandled key: " + key); + } + } + } + + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *ptr = op; + return Status::OK(); +} + Status DEPipeline::ParseRepeatOp(const py::dict &args, std::shared_ptr *ptr) { if (args["count"].is_none()) { std::string err_msg = "Error: count is invalid or not set."; @@ -671,8 +708,6 @@ Status DEPipeline::ParseZipOp(const py::dict &args, std::shared_ptr * return Status::OK(); } -DsOpPtr DEPipeline::ParseFilterOp(const py::dict &args) const { return DsOpPtr(); } - Status DEPipeline::ParseTFReaderOp(const py::dict &args, std::shared_ptr *ptr) { // Required arguments std::shared_ptr builder = std::make_shared(); diff --git a/mindspore/ccsrc/dataset/api/de_pipeline.h b/mindspore/ccsrc/dataset/api/de_pipeline.h index eadde2c191..25919afe58 100644 --- a/mindspore/ccsrc/dataset/api/de_pipeline.h +++ b/mindspore/ccsrc/dataset/api/de_pipeline.h @@ -107,6 +107,8 @@ class DEPipeline { Status ParseMapOp(const py::dict &args, std::shared_ptr *ptr); + Status ParseFilterOp(const py::dict &args, std::shared_ptr *ptr); + Status ParseRepeatOp(const py::dict &args, std::shared_ptr *ptr); Status ParseSkipOp(const py::dict &args, std::shared_ptr *ptr); @@ -121,8 +123,6 @@ class DEPipeline { Status ParseZipOp(const py::dict &args, std::shared_ptr *ptr); - DsOpPtr ParseFilterOp(const py::dict &args) const; - Status ParseDeviceQueueOp(const py::dict &args, std::shared_ptr *ptr); Status ParseTFReaderOp(const py::dict &args, std::shared_ptr *ptr); diff --git a/mindspore/ccsrc/dataset/core/client.h b/mindspore/ccsrc/dataset/core/client.h index b865c54260..15064dee6b 100644 --- a/mindspore/ccsrc/dataset/core/client.h +++ b/mindspore/ccsrc/dataset/core/client.h @@ -31,6 +31,7 @@ #include "dataset/engine/datasetops/map_op.h" #include "dataset/engine/datasetops/project_op.h" #include "dataset/engine/datasetops/rename_op.h" +#include "dataset/engine/datasetops/filter_op.h" #include "dataset/engine/datasetops/repeat_op.h" #include "dataset/engine/datasetops/skip_op.h" #include "dataset/engine/datasetops/shuffle_op.h" diff --git a/mindspore/ccsrc/dataset/core/tensor.cc b/mindspore/ccsrc/dataset/core/tensor.cc index a566d51f5c..3f41f27726 100644 --- a/mindspore/ccsrc/dataset/core/tensor.cc +++ b/mindspore/ccsrc/dataset/core/tensor.cc @@ -240,7 +240,7 @@ void Tensor::PrintItemAt(const std::vector &index, std::ostream &out) c DS_ASSERT(data_); switch (type_.value()) { - CASE_PRINT_HEX(DataType::DE_BOOL, uint8_t); + CASE_PRINT_HEX(DataType::DE_BOOL, bool); CASE_PRINT_HEX(DataType::DE_INT8, int8_t); diff --git a/mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt b/mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt index 655a739ada..7de62d9d11 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt +++ b/mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt @@ -14,5 +14,6 @@ add_library(engine-datasetops OBJECT take_op.cc shuffle_op.cc zip_op.cc + filter_op.cc ) diff --git a/mindspore/ccsrc/dataset/engine/datasetops/filter_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/filter_op.cc new file mode 100644 index 0000000000..22b1155fc9 --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/datasetops/filter_op.cc @@ -0,0 +1,273 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "dataset/engine/datasetops/filter_op.h" +#include +#include +#include +#include +#include +#include "dataset/core/config_manager.h" +#include "dataset/core/constants.h" +#include "dataset/core/global_context.h" +#include "dataset/core/tensor.h" +#include "dataset/engine/data_buffer.h" +#include "dataset/engine/db_connector.h" +#include "dataset/engine/execution_tree.h" +#include "dataset/kernels/tensor_op.h" +#include "utils/log_adapter.h" +#include "dataset/util/task_manager.h" + +namespace mindspore { +namespace dataset { + +Status FilterOp::Builder::SanityCheck() { + std::string err; + err += builder_op_connector_size_ <= 0 ? "connector size <= 0\n" : ""; + err += builder_num_workers_ <= 0 ? "filter num_parallel_workers <= 0\n" : ""; + return err.empty() ? Status::OK() : Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, common::SafeCStr(err)); +} + +FilterOp::Builder::Builder() { + std::shared_ptr cfg = GlobalContext::config_manager(); + builder_num_workers_ = cfg->num_parallel_workers(); + builder_op_connector_size_ = cfg->op_connector_size(); +} + +Status FilterOp::Builder::Build(std::shared_ptr *ptr) { + RETURN_IF_NOT_OK(SanityCheck()); + *ptr = std::make_shared(std::move(build_in_col_names_), builder_num_workers_, builder_op_connector_size_, + builder_predicate_func_); + return Status::OK(); +} + +FilterOp::FilterOp(const std::vector &in_col_names, int32_t num_workers, int32_t op_queue_size, + py::function predicate_func) + : ParallelOp(num_workers, op_queue_size), predicate_func_(std::move(predicate_func)), in_columns_(in_col_names) {} + +Status FilterOp::operator()() { + // The operator class just starts off threads by calling the tree_ function. + RETURN_UNEXPECTED_IF_NULL(tree_); + // Synchronize with TaskManager. + TaskManager::FindMe()->Post(); + filter_queues_.Init(num_workers_, oc_queue_size_); + RETURN_IF_NOT_OK(filter_queues_.Register(tree_->AllTasks())); + RETURN_IF_NOT_OK(tree_->LaunchWorkers(num_workers_, std::bind(&FilterOp::WorkerEntry, this, std::placeholders::_1))); + RETURN_IF_NOT_OK(Collector()); + return Status::OK(); +} + +Status FilterOp::EofReceived(int32_t) { return Status::OK(); } + +Status FilterOp::EoeReceived(int32_t) { return Status::OK(); } + +// Validating if each of the input_columns exists in the DataBuffer. +Status FilterOp::ValidateInColumns(const std::unordered_map &col_name_id_map, + std::vector *input_columns) { + for (const auto &inCol : *input_columns) { + bool found = col_name_id_map.find(inCol) != col_name_id_map.end() ? true : false; + if (!found) { + std::string err_msg = "input column name: " + inCol + " doesn't exist in the dataset columns."; + RETURN_STATUS_UNEXPECTED(err_msg); + } + } + return Status::OK(); +} + +// A print method typically used for debugging. +void FilterOp::Print(std::ostream &out, bool show_all) const { + // Call base class printer first. + ParallelOp::Print(out, show_all); + + // Then display our own stuff. + out << "\nFilterOp:"; + out << "\n Input column names:"; + for (size_t i = 0; i < in_columns_.size(); i++) { + out << " " << in_columns_[i]; + } +} + +Status FilterOp::WorkerEntry(int32_t worker_id) { + // Handshake with TaskManager that thread creation is successful. + TaskManager::FindMe()->Post(); + std::unique_ptr in_buffer; + bool worker_stop = false; + while (worker_stop == false) { + // Getting a databuffer to work on. + RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&in_buffer, worker_id)); + if (in_buffer->eoe()) { + filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterEoe)); + continue; + } else if (in_buffer->eof()) { + filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterEof)); + worker_stop = true; + continue; + } + + // Thread local variables to avoid lock. When in_columns_ is empty and workers will write + // the name of the first column into input_columns (thread local) instead of in_columns_ (thread global). + std::vector input_columns = in_columns_; + // Indices of the columns to process. + std::vector to_process_indices; + + RETURN_IF_NOT_OK(WorkerEntryInit(in_buffer.get(), &to_process_indices, &input_columns)); + + // if the databuffer was all filtered, it is marked as kFilterEmpty. + // if the databuffer was partially filtered, it is marked as kFilterPartial. + // if the databuffer was not filtered, it is marked as kFilterFull. + int32_t num_rows = in_buffer->NumRows(); + std::unique_ptr new_tensor_table; + RETURN_IF_NOT_OK(WorkerCompute(in_buffer.get(), to_process_indices, &new_tensor_table)); + + if (new_tensor_table->empty()) { + RETURN_IF_NOT_OK( + filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterEmpty))); + } else if (new_tensor_table->size() == num_rows) { + in_buffer->set_tensor_table(std::move(new_tensor_table)); + RETURN_IF_NOT_OK( + filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterFull))); + } else { // kFilterPartial + in_buffer->set_tensor_table(std::move(new_tensor_table)); + RETURN_IF_NOT_OK( + filter_queues_[worker_id]->EmplaceBack(std::make_pair(std::move(in_buffer), filterCtrl::kFilterPartial))); + } + } + return Status::OK(); +} + +Status FilterOp::WorkerCompute(DataBuffer *in_buffer, const std::vector &to_proess_indices, + std::unique_ptr *out) { + *out = std::make_unique(); + int32_t num_rows = in_buffer->NumRows(); + for (int32_t i = 0; i < num_rows; i++) { + TensorRow to_process; + TensorRow cur_row; + RETURN_IF_NOT_OK(in_buffer->PopRow(&cur_row)); + + (void)std::transform(to_proess_indices.begin(), to_proess_indices.end(), std::back_inserter(to_process), + [&cur_row](const size_t &it) -> std::shared_ptr { return cur_row[it]; }); + bool predicate = true; + RETURN_IF_NOT_OK(InvokePredicateFunc(to_process, &predicate)); + if (predicate) { + (*out)->push_back(std::move(cur_row)); + } + } + return Status::OK(); +} + +// if the filtered DataBuffer is written directly to out_connector_, +// the thread fetching data will block in a queue. +// Collector function will reorder the DataBuffer in order. +// for example in two work queues: +// int filter_queues_: +// queue1: DB(data1 kFilterEmpty) DB(eoe) DB(data4) DB(eof) +// queue2: DB(data2) DB(data3 kFilterEmpty) DB(eoe) +// after reorder in out_connector_: +// queue1: DB(data2) DB(data4) DB(eof) +// queue2: DB(eoe) DB(eoe) +Status FilterOp::Collector() { + bool collector_stop = false; + uint64_t task_id_cnt = 0; + uint64_t out_id_cnt = 0; + std::pair, filterCtrl> in_pair; + while (collector_stop == false) { + uint32_t w_id = task_id_cnt % num_workers_; + RETURN_IF_NOT_OK(filter_queues_[w_id]->PopFront(&in_pair)); + if (in_pair.second == filterCtrl::kFilterFull || in_pair.second == filterCtrl::kFilterPartial || + in_pair.second == filterCtrl::kFilterEoe) { + uint32_t out_task_id = out_id_cnt % num_workers_; + RETURN_IF_NOT_OK(out_connector_->Add(static_cast(out_task_id), std::move(in_pair.first))); + out_id_cnt++; + task_id_cnt++; + } else if (in_pair.second == filterCtrl::kFilterEof) { + uint32_t out_task_id = out_id_cnt % num_workers_; + RETURN_IF_NOT_OK(out_connector_->Add(static_cast(out_task_id), std::move(in_pair.first))); + collector_stop = true; + } else { // kFilterEmpty + task_id_cnt++; + } + } + return Status::OK(); +} + +// initialize some internal data structure used by WorkerEntry(). +Status FilterOp::WorkerEntryInit(const DataBuffer *in_buf, std::vector *to_process_indices, + std::vector *input_columns) { + int32_t num_rows = in_buf->NumRows(); + int32_t num_cols = in_buf->NumCols(); + if (num_rows == 0 || num_cols == 0) { + RETURN_STATUS_UNEXPECTED("FilterOp is getting an empty DataBuffer."); + } + std::unordered_map col_name_id_map = in_buf->column_name_map(); + // Check if there is invalid column name in the inColumns. + RETURN_IF_NOT_OK(ValidateInColumns(col_name_id_map, input_columns)); + + if (input_columns->empty()) { + MS_LOG(INFO) << "Input columns in filter operator is empty, will apply to the all column in the current table."; + // sort the input colunms by column index. + std::vector> sort_vec(col_name_id_map.begin(), col_name_id_map.end()); + std::sort(sort_vec.begin(), sort_vec.end(), + [](const std::pair &a, const std::pair &b) { + return a.second < b.second; + }); + + (void)std::transform(sort_vec.begin(), sort_vec.end(), std::back_inserter(*input_columns), + [](const auto &it) -> std::string { return it.first; }); + } + + // initialize to_process_indices. + (void)std::transform(input_columns->begin(), input_columns->end(), std::back_inserter(*to_process_indices), + [&col_name_id_map](const auto &it) -> size_t { return col_name_id_map[it]; }); + + return Status::OK(); +} + +Status FilterOp::CheckInput(const TensorRow &input) const { + for (auto &item : input) { + if (item == nullptr) { + RETURN_STATUS_UNEXPECTED("input is null."); + } + } + return Status::OK(); +} + +Status FilterOp::InvokePredicateFunc(const TensorRow &input, bool *out_predicate) { + RETURN_IF_NOT_OK(CheckInput(input)); + // Acquire Python GIL. + py::gil_scoped_acquire gil_acquire; + if (Py_IsInitialized() == 0) { + return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); + } + try { + // Transform input tensor vector into numpy array vector. + py::tuple input_args(input.size()); + for (size_t i = 0; i < input.size(); i++) { + py::array new_data; + RETURN_IF_NOT_OK(input.at(i)->GetDataAsNumpy(&new_data)); + input_args[i] = new_data; + } + // Invoke python function. + py::object ret_py_obj = predicate_func_(*input_args); + *out_predicate = ret_py_obj.cast(); + } catch (const py::error_already_set &e) { + std::stringstream ss; + ss << e.what() << std::endl; + ss << "The type of the return value of python predicate function is not bool, or can not be convert to bool."; + return Status(StatusCode::kPyFuncException, ss.str()); + } + return Status(StatusCode::kOK, "FilterOp predicate func call succeed"); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/filter_op.h b/mindspore/ccsrc/dataset/engine/datasetops/filter_op.h new file mode 100644 index 0000000000..50697d398f --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/datasetops/filter_op.h @@ -0,0 +1,180 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_FILTER_OP_H_ +#define DATASET_ENGINE_DATASETOPS_FILTER_OP_H_ + +#include +#include +#include +#include +#include +#include +#include "dataset/engine/datasetops/parallel_op.h" +#include "dataset/kernels/tensor_op.h" +#include "dataset/util/queue.h" + +namespace mindspore { +namespace dataset { + +class FilterOp : public ParallelOp { + public: + // The nested builder class inside of the FilterOp is used to help manage all of + // the arguments for constructing it. Use the builder by setting each argument + // with the provided set methods, and then finally call the build method to execute + // the actual construction. + class Builder { + public: + // Builder constructor. Creates the builder object. + // @note No default args. + // @return This is a constructor. + Builder(); + + // Default destructor + ~Builder() = default; + + // Setter method. + // @return Builder setter method returns reference to the builder. + Builder &SetPredicateFunc(py::function func) { + builder_predicate_func_ = std::move(func); + return *this; + } + + // Setter method. + // @return Builder setter method returns reference to the builder. + Builder &SetInColNames(const std::vector &in_col_names) { + build_in_col_names_ = in_col_names; + return *this; + } + + // Setter method. + // @return Builder setter method returns reference to the builder. + Builder &SetNumWorkers(int32_t num_workers) { + builder_num_workers_ = num_workers; + return *this; + } + + // Setter method. + // @return Builder setter method returns reference to the builder. + Builder &SetOpConnectorSize(int32_t connector_size) { + builder_op_connector_size_ = connector_size; + return *this; + } + + // The builder "build" method creates the final object. + // @param ptr The shared_ptr to the new FilterOp object. + // @return Status. + Status Build(std::shared_ptr *ptr); + + private: + // Sanity check for builder class args. + // @return Status - The error code return. + Status SanityCheck(); + std::vector build_in_col_names_; + py::function builder_predicate_func_; + int32_t builder_num_workers_; + int32_t builder_op_connector_size_; + }; + + enum filterCtrl : int8_t { kFilterEmpty = 0, kFilterPartial = 1, kFilterFull = 2, kFilterEoe = 3, kFilterEof = 4 }; + + // Constructor of FilterOp + // @note The builder class should be used to call it. + // @param in_col_names A list of input column names,when it is empty the predicate will be + // applied all columns in the dataset. + // @param num_workers The number of worker threads. + // @param op_connector_size The size of each queue in the connector. + // @param predicate_func python callable which returns a boolean value. + FilterOp(const std::vector &in_col_names, int32_t num_workers, int32_t op_queue_size, + py::function predicate_func); + + // Class functor operator () override. + // All dataset ops operate by launching a thread (see ExecutionTree),This class functor will + // provide the master loop that drives the logic for performing the work. + // @return Status The error code return + Status operator()() override; + + // @param int32_t workerId. + // @return Status - The error code return. + Status EofReceived(int32_t) override; + + // @param int32_t workerId. + // @return Status - The error code return. + Status EoeReceived(int32_t) override; + + // A print method typically used for debugging. + // @param out The output stream to write output to. + // @param show_all A bool to control if you want to show all info or just a summary. + void Print(std::ostream &out, bool show_all) const override; + + private: + // predicate_func python callable which returns a boolean value. + py::function predicate_func_; + + // Variable to store the column name that will feed to predicate function. + std::vector in_columns_; + + // Internal queue for filter. + QueueList, filterCtrl>> filter_queues_; + + // Private function for worker/thread to loop continuously. It comprises the main + // logic of FilterOp, getting the data from previous Op, validating user specified column names, + // applying predicate to each of the data, filter the data when predicate result is false. + // @param worker_id The id assigned to this thread/worker upon creation. + // @return Status The error code return. + Status WorkerEntry(int32_t worker_id) override; // In: workerId assigned by tree_ + + // Filter the data by predicate function . + // @param in_buffer input data buffer. + // @param to_proess_indices Indices of columns to be processed. + // @param out data buffer that are filtered by predicate. + // @return Status The error code return. + Status WorkerCompute(DataBuffer *in_buffer, const std::vector &to_proess_indices, + std::unique_ptr *out); + + // Collector databuffer. + // @return Status The error code return. + Status Collector(); + + // @param input tensor vector. + // @return Status - The error code return. + Status CheckInput(const TensorRow &input) const; + + // Invoke python func. + // @param input tensor vector. + // @param the result of predicate. + // @return Status - The error code return. + Status InvokePredicateFunc(const TensorRow &input, bool *out_predicate); + + // Private function for validating if each of the user specified input column names + // exist in the DataBuffer. + // @param col_name_id_map The column name to index mapping obtained from DataBuffer. + // @param input_columns The vector of input column names used in the current thread. + // @return Status The error code return. + Status ValidateInColumns(const std::unordered_map &col_name_id_map, + std::vector *input_columns); + + // Private function that initialize some internal data structure used by WorkerEntry(). + // @param in_buf A raw pointer to the DataBuffer. A raw pointer is fine because this function does not manage memory + // and is not shared with other threads. + // @param[out] to_process_indices Indices of columns that will feed to predicate. + // @param input_columns The vector of input column names used in the current thread. + Status WorkerEntryInit(const DataBuffer *in_buf, std::vector *to_process_indices, + std::vector *input_columns); +}; + +} // namespace dataset +} // namespace mindspore +#endif diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index ca717643c9..89842df015 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -35,7 +35,7 @@ from mindspore._c_expression import typing from mindspore import log as logger from . import samplers from .iterators import DictIterator, TupleIterator -from .validators import check, check_batch, check_shuffle, check_map, check_repeat, check_skip, check_zip, check_rename, \ +from .validators import check, check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, check_rename, \ check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \ check_tfrecorddataset, check_vocdataset, check_celebadataset, check_minddataset, check_generatordataset, \ check_zip_dataset, check_add_column, check_textfiledataset @@ -385,6 +385,32 @@ class Dataset: """ return MapDataset(self, input_columns, operations, output_columns, columns_order, num_parallel_workers) + @check_filter + def filter(self, predicate, input_columns=None, num_parallel_workers=1): + """ + Filter dataset by predicate. + + Note: + If input_columns not provided or empty, all columns will be used. + + Args: + predicate: python callable which returns a boolean value. + input_columns: (list[str]): List of names of the input columns, when + default=None, the predicate will be applied on all columns in the dataset. + num_parallel_workers (int, optional): Number of workers to process the Dataset + in parallel (default=None). + + Returns: + FilterDataset, dataset filter. + + Examples: + >>> import mindspore.dataset as ds + >>> # generator data(0 ~ 63) + >>> # filter the data that greater than or equal to 11 + >>> dataset_f = dataset.filter(predicate=lambda data: data < 11, input_columns = ["data"]) + """ + return FilterDataset(self, predicate, input_columns, num_parallel_workers) + @check_repeat def repeat(self, count=None): """ @@ -1105,6 +1131,44 @@ class MapDataset(DatasetOp): return self.input[0].get_dataset_size() +class FilterDataset(DatasetOp): + """ + The result of applying filter predicate to the input Dataset. + + Args: + input_dataset: Input Dataset to be mapped. + predicate: python callable which returns a boolean value. + input_columns: (list[str]): List of names of the input columns, when + default=None, the predicate will be applied all columns in the dataset. + num_parallel_workers (int, optional): Number of workers to process the Dataset + in parallel (default=None). + """ + + def __init__(self, input_dataset, predicate, input_columns=None, num_parallel_workers=None): + super().__init__(num_parallel_workers) + self.predicate = lambda *args: bool(predicate(*args)) + self.input.append(input_dataset) + input_dataset.output.append(self) + if input_columns is not None and not isinstance(input_columns, list): + input_columns = [input_columns] + self.input_columns = input_columns + + def get_args(self): + args = super().get_args() + args["predicate"] = self.predicate + args["input_columns"] = self.input_columns + return args + + def get_dataset_size(self): + """ + Get the number of batches in an epoch. + the size cannot be determined before we run the pipeline + Return: + 0 + """ + return 0 + + class RepeatDataset(DatasetOp): """ The result of applying Repeat operator to the input Dataset. diff --git a/mindspore/dataset/engine/iterators.py b/mindspore/dataset/engine/iterators.py index a74d69b9c7..6af6c7dba8 100644 --- a/mindspore/dataset/engine/iterators.py +++ b/mindspore/dataset/engine/iterators.py @@ -129,6 +129,8 @@ class Iterator: op_type = OpName.ZIP elif isinstance(dataset, de.MapDataset): op_type = OpName.MAP + elif isinstance(dataset, de.FilterDataset): + op_type = OpName.FILTER elif isinstance(dataset, de.RepeatDataset): op_type = OpName.REPEAT elif isinstance(dataset, de.SkipDataset): diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index a340eb5aff..324cbde03a 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -695,6 +695,26 @@ def check_map(method): return new_method +def check_filter(method): + """"check the input arguments of filter.""" + @wraps(method) + def new_method(*args, **kwargs): + param_dict = make_param_dict(method, args, kwargs) + predicate = param_dict.get("predicate") + if not callable(predicate): + raise ValueError("Predicate should be a python function or a callable python object.") + + nreq_param_int = ['num_parallel_workers'] + check_param_type(nreq_param_int, param_dict, int) + param_name = "input_columns" + param = param_dict.get(param_name) + if param is not None: + check_columns(param, param_name) + return method(*args, **kwargs) + + return new_method + + def check_repeat(method): """check the input arguments of repeat.""" @wraps(method) diff --git a/tests/ut/cpp/dataset/CMakeLists.txt b/tests/ut/cpp/dataset/CMakeLists.txt index b05f12eee1..2224565c30 100644 --- a/tests/ut/cpp/dataset/CMakeLists.txt +++ b/tests/ut/cpp/dataset/CMakeLists.txt @@ -66,6 +66,8 @@ SET(DE_UT_SRCS celeba_op_test.cc take_op_test.cc text_file_op_test.cc) + filter_op_test.cc + ) add_executable(de_ut_tests ${DE_UT_SRCS}) diff --git a/tests/ut/cpp/dataset/filter_op_test.cc b/tests/ut/cpp/dataset/filter_op_test.cc new file mode 100644 index 0000000000..45ee714337 --- /dev/null +++ b/tests/ut/cpp/dataset/filter_op_test.cc @@ -0,0 +1,53 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "dataset/util/circular_pool.h" +#include "dataset/core/client.h" +#include "common/common.h" +#include "gtest/gtest.h" +#include "utils/log_adapter.h" + +using namespace mindspore::dataset; +namespace de = mindspore::dataset; + +using mindspore::MsLogLevel::INFO; +using mindspore::ExceptionType::NoExceptionType; +using mindspore::LogStream; + +class MindDataTestfilter_op : public UT::DatasetOpTesting { + +}; + + +std::shared_ptr Filter() { + Status rc; + std::shared_ptr op; + rc = de::FilterOp::Builder().Build(&op); + EXPECT_TRUE(rc.IsOk()); + return op; +} + +TEST_F(MindDataTestfilter_op, Testfilter_opFuntions) { + MS_LOG(INFO) << "Doing MindDataTest filter_op."; + auto my_tree = std::make_shared(); + + std::shared_ptr parent_op = Filter(); + + std::shared_ptr leaf_op = Filter(); + my_tree->AssociateNode(parent_op); + my_tree->AssociateNode(leaf_op); + ASSERT_NE(parent_op, nullptr); + ASSERT_NE(leaf_op, nullptr); +} diff --git a/tests/ut/cpp/dataset/tensor_test.cc b/tests/ut/cpp/dataset/tensor_test.cc index 7437b3d942..494d4b2329 100644 --- a/tests/ut/cpp/dataset/tensor_test.cc +++ b/tests/ut/cpp/dataset/tensor_test.cc @@ -158,6 +158,16 @@ TEST_F(MindDataTestTensorDE, InsertTensor) { ASSERT_EQ(*t == *t6, true); } +// Test the bug of Tensor::ToString will exec failed for Tensor which store bool values +TEST_F(MindDataTestTensorDE, BoolTensor) { + std::shared_ptr t = std::make_shared(TensorShape({2}), + DataType(DataType::DE_BOOL)); + t->SetItemAt({0}, true); + t->SetItemAt({1}, true); + std::string out = t->ToString(); + ASSERT_TRUE(out.find("Template type and Tensor type are not compatible") == std::string::npos); +} + TEST_F(MindDataTestTensorDE, GetItemAt) { std::shared_ptr t = std::make_shared(TensorShape({2, 2}), DataType(DataType::DE_UINT8)); t->Fill(254); diff --git a/tests/ut/data/dataset/declient_filter.cfg b/tests/ut/data/dataset/declient_filter.cfg new file mode 100644 index 0000000000..89e1199f5a --- /dev/null +++ b/tests/ut/data/dataset/declient_filter.cfg @@ -0,0 +1,3 @@ +{ + "rowsPerBuffer": 10, +} diff --git a/tests/ut/python/dataset/test_filterop.py b/tests/ut/python/dataset/test_filterop.py new file mode 100644 index 0000000000..90f512caa4 --- /dev/null +++ b/tests/ut/python/dataset/test_filterop.py @@ -0,0 +1,504 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import numpy as np +import mindspore.dataset as ds +import mindspore.dataset.transforms.vision.c_transforms as cde +import mindspore.dataset.transforms.c_transforms as C +import mindspore.common.dtype as mstype +from mindspore import log as logger + +DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] +SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json" +# test for predicate +def test_diff_predicate_func(): + def test_filter(predicate_func): + transforms = [ + cde.Decode(), + cde.Resize([64, 64]) + ] + type_cast_op = C.TypeCast(mstype.int32) + dataset = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image", "label"], shuffle=False) + dataset = dataset.map(input_columns=["image"], operations=transforms, num_parallel_workers=1) + dataset = dataset.filter(input_columns=["image", "label"], predicate=predicate_func, num_parallel_workers=4) + + num_iter = 0 + label_list = [] + for data in dataset.create_dict_iterator(): + num_iter += 1 + ori_img = data["image"] + label = data["label"] + label_list.append(label) + assert num_iter == 1 + assert label_list[0] == 3 + + test_filter(lambda image, label: label == 3) + test_filter(lambda image, label: label[0] == 3) + test_filter(lambda image, label: label == [3]) + test_filter(lambda image, label: label == np.array([3])) + test_filter(lambda image, label: label == np.array(3)) + +def filter_func_ge(data): + if data > 10: + return False + return True + + +def generator_1d(): + for i in range(64): + yield (np.array(i),) + +# test with GeneratorDataset +def test_filter_by_generator_with_no(): + dataset = ds.GeneratorDataset(generator_1d, ["data"]) + dataset_f = dataset.filter(predicate=lambda data: data < 11, num_parallel_workers=4) + num_iter = 0 + expected_rs = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + for item in dataset_f.create_dict_iterator(): + assert item["data"] == expected_rs[num_iter] + num_iter += 1 + +# test with repeatOp before +def test_filter_by_generator_with_repeat(): + dataset = ds.GeneratorDataset(generator_1d, ["data"]) + dataset_r = dataset.repeat(4) + dataset_f = dataset_r.filter(predicate=filter_func_ge, num_parallel_workers=4) + num_iter = 0 + ret_data = [] + expected_rs = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + for item in dataset_f.create_dict_iterator(): + num_iter += 1 + ret_data.append(item["data"]) + assert num_iter == 44 + for i in range(4): + for ii in range(len(expected_rs)): + index = i * len(expected_rs) + ii + assert ret_data[index] == expected_rs[ii] + +# test with repeatOp after +def test_filter_by_generator_with_repeat_after(): + dataset = ds.GeneratorDataset(generator_1d, ["data"]) + dataset_f = dataset.filter(predicate=filter_func_ge, num_parallel_workers=4) + dataset_r = dataset_f.repeat(4) + num_iter = 0 + ret_data = [] + expected_rs = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + for item in dataset_r.create_dict_iterator(): + num_iter += 1 + ret_data.append(item["data"]) + assert num_iter == 44 + for i in range(4): + for ii in range(len(expected_rs)): + index = i * len(expected_rs) + ii + assert ret_data[index] == expected_rs[ii] + +def filter_func_batch(data): + if data[0] > 8: + return False + return True + +def filter_func_batch_after(data): + if data > 20: + return False + return True + +# test with batchOp before +def test_filter_by_generator_with_batch(): + dataset = ds.GeneratorDataset(generator_1d, ["data"]) + dataset_b = dataset.batch(4) + dataset_f = dataset_b.filter(predicate=filter_func_batch, num_parallel_workers=4) + num_iter = 0 + ret_data = [] + for item in dataset_f.create_dict_iterator(): + num_iter += 1 + ret_data.append(item["data"]) + assert num_iter == 3 + assert ret_data[0][0] == 0 + assert ret_data[1][0] == 4 + assert ret_data[2][0] == 8 + +# test with batchOp after +def test_filter_by_generator_with_batch_after(): + dataset = ds.GeneratorDataset(generator_1d, ["data"]) + dataset_f = dataset.filter(predicate=filter_func_batch_after, num_parallel_workers=4) + dataset_b = dataset_f.batch(4) + num_iter = 0 + ret_data = [] + for item in dataset_b.create_dict_iterator(): + num_iter += 1 + ret_data.append(item["data"]) + assert num_iter == 6 + assert ret_data[0][0] == 0 + assert ret_data[1][0] == 4 + assert ret_data[5][0] == 20 + + +def filter_func_shuffle(data): + if data > 20: + return False + return True + +# test with batchOp before +def test_filter_by_generator_with_shuffle(): + dataset = ds.GeneratorDataset(generator_1d, ["data"]) + dataset_s = dataset.shuffle(4) + dataset_f = dataset_s.filter(predicate=filter_func_shuffle, num_parallel_workers=4) + num_iter = 0 + for item in dataset_f.create_dict_iterator(): + num_iter += 1 + assert num_iter == 21 + + +def filter_func_shuffle_after(data): + if data > 20: + return False + return True + +# test with batchOp after +def test_filter_by_generator_with_shuffle_after(): + dataset = ds.GeneratorDataset(generator_1d, ["data"]) + dataset_f = dataset.filter(predicate=filter_func_shuffle_after, num_parallel_workers=4) + dataset_s = dataset_f.shuffle(4) + num_iter = 0 + for item in dataset_s.create_dict_iterator(): + num_iter += 1 + assert num_iter == 21 + + +def generator_1d_zip1(): + for i in range(64): + yield (np.array(i),) + + +def generator_1d_zip2(): + for i in range(64): + yield (np.array(i+100),) + + +def filter_func_zip(data1, data2): + if data1 > 20: + return False + return True + +def filter_func_zip_after(data1): + if data1 > 20: + return False + return True + +# test with zipOp before +def test_filter_by_generator_with_zip(): + dataset1 = ds.GeneratorDataset(generator_1d_zip1, ["data1"]) + dataset2 = ds.GeneratorDataset(generator_1d_zip2, ["data2"]) + dataz = ds.zip((dataset1, dataset2)) + dataset_f = dataz.filter(predicate=filter_func_zip, num_parallel_workers=1) + num_iter = 0 + ret_data = [] + for item in dataset_f.create_dict_iterator(): + num_iter += 1 + ret_data.append({"data1": item["data1"], "data2":item["data2"]}) + assert num_iter == 21 + assert ret_data[0]["data1"] == 0 + assert ret_data[0]["data2"] == 100 + assert ret_data[5]["data1"] == 5 + assert ret_data[5]["data2"] == 105 + + +# test with zipOp after +def test_filter_by_generator_with_zip_after(): + dataset1 = ds.GeneratorDataset(generator_1d_zip1, ["data1"]) + dataset2 = ds.GeneratorDataset(generator_1d_zip1, ["data2"]) + dt1 = dataset1.filter(predicate=filter_func_zip_after, num_parallel_workers=4) + dt2 = dataset2.filter(predicate=filter_func_zip_after, num_parallel_workers=4) + dataz = ds.zip((dt1, dt2)) + num_iter = 0 + ret_data = [] + for item in dataz.create_dict_iterator(): + num_iter += 1 + ret_data.append({"data1": item["data1"], "data2":item["data2"]}) + assert num_iter == 21 + assert ret_data[0]["data1"] == 0 + assert ret_data[0]["data2"] == 0 + assert ret_data[5]["data1"] == 5 + assert ret_data[5]["data2"] == 5 + + +def filter_func_map(col1, col2): + if col1[0] > 8: + return True + return False + + +def filter_func_map_part(col1): + if col1 < 3: + return True + else: + return False + + +def filter_func_map_all(col1, col2): + return True + +def generator_mc(maxid=20): + for i in range(maxid): + yield (np.array([i]), np.array([[i, i + 1], [i + 2, i + 3]])) + + +def func_map(data_col1, data_col2): + return (data_col1, data_col2) + + +def func_map_part(data_col1): + return (data_col1) + +# test with map +def test_filter_by_generator_with_map_all_col(): + dataset = ds.GeneratorDataset(generator_mc(12), ["col1", "col2"]) + dataset_map = dataset.map( input_columns=["col1"], output_columns=["col1"] , operations=func_map_part) + # dataset_map = dataset.map( operations=func_map_part) + dataset_f = dataset_map.filter(input_columns=["col1"], predicate=filter_func_map_part, num_parallel_workers=1) + num_iter = 0 + ret_data = [] + for item in dataset_f.create_dict_iterator(): + num_iter += 1 + ret_data.append(item["col1"]) + assert num_iter == 3 + assert ret_data[0] == 0 + assert ret_data[1] == 1 + +# test with map +def test_filter_by_generator_with_map_part_col(): + dataset = ds.GeneratorDataset(generator_mc(12), ["col1", "col2"]) + dataset_map = dataset.map( input_columns=["col1"], output_columns=["out1"] , operations=func_map_part) + + dataset_f = dataset_map.filter(input_columns=["out1", "col2"], predicate=filter_func_map, num_parallel_workers=4) + num_iter = 0 + ret_data = [] + for item in dataset_f.create_dict_iterator(): + num_iter += 1 + print(item) + ret_data.append(item["out1"]) + assert num_iter == 3 + assert ret_data[0] == 9 + assert ret_data[2] == 11 + + +def filter_func_rename(data): + if data> 8: + return True + return False + +# test with rename before +def test_filter_by_generator_with_rename(): + dataset = ds.GeneratorDataset(generator_1d, ["data"]) + dataset_b = dataset.rename(input_columns=["data"], output_columns=["col1"]) + dataset_f = dataset_b.filter(predicate=filter_func_rename, num_parallel_workers=4) + num_iter = 0 + ret_data = [] + for item in dataset_f.create_dict_iterator(): + num_iter += 1 + ret_data.append(item["col1"]) + assert num_iter == 55 + assert ret_data[0] == 9 + assert ret_data[54] == 63 + + +#test input_column +def filter_func_input_column1(col1, col2): + if col1[0] < 8: + return True + return False + +def filter_func_input_column2(col1): + if col1[0] < 8: + return True + return False + +def filter_func_input_column3(col1): + return True + +# test with input_columns +def test_filter_by_generator_with_input_column(): + dataset = ds.GeneratorDataset(generator_mc(64), ["col1", "col2"]) + dataset_map = dataset.map( input_columns=["col1"], output_columns=["out1"] , operations=func_map_part) + dataset_f1 = dataset_map.filter(input_columns=["out1", "col2"], predicate=filter_func_input_column1, num_parallel_workers=4) + dataset_f2 = dataset_f1.filter(input_columns=["out1"], predicate=filter_func_input_column2, num_parallel_workers=4) + dataset_f3 = dataset_f2.filter(input_columns=["col2"], predicate=filter_func_input_column3, num_parallel_workers=4) + dataset_f4 = dataset_f3.filter(predicate=filter_func_input_column1, num_parallel_workers=4) + num_iter = 0 + ret_data = [] + for item in dataset_f4.create_dict_iterator(): + num_iter += 1 + ret_data.append(item["out1"]) + assert num_iter == 8 + assert ret_data[0] == 0 + assert ret_data[7] == 7 + + +#test kFilterPartial +def generator_mc_p0(maxid=20): + for i in range(maxid): + yield (np.array([i ]), np.array([i + 100])) + +def generator_mc_p1(maxid=20): + for i in range(maxid): + yield (np.array([i + 200 ]), np.array([i + 300])) + + +def filter_func_Partial_0(col1, col2, col3, col4): + filter_data = [0,1,2,3,4, 11] + if col1[0] in filter_data: + return False + return True + +# test with row_data_buffer > 1 +def test_filter_by_generator_Partial0(): + ds.config.load('../data/dataset/declient_filter.cfg') + dataset1= ds.GeneratorDataset(source = generator_mc_p0(), column_names = ["col1", "col2"]) + dataset2 = ds.GeneratorDataset(source = generator_mc_p1(), column_names = ["col3", "col4"]) + dataset_zip = ds.zip((dataset1, dataset2)) + dataset_f1 = dataset_zip.filter(predicate=filter_func_Partial_0, num_parallel_workers=2) + ret = [] + for item in dataset_f1.create_dict_iterator(): + ret.append(item["col1"]) + assert ret[0] == 5 + assert ret[6] == 12 + +# test with row_data_buffer > 1 +def test_filter_by_generator_Partial1(): + ds.config.load('../data/dataset/declient_filter.cfg') + dataset1= ds.GeneratorDataset(source = generator_mc_p0(), column_names = ["col1", "col2"]) + dataset2 = ds.GeneratorDataset(source = generator_mc_p1(), column_names = ["col3", "col4"]) + dataset_zip = ds.zip((dataset1, dataset2)) + dataset_f1 = dataset_zip.filter(predicate=filter_func_Partial_0, num_parallel_workers=2) + dataset_map = dataset_f1.map( input_columns=["col1"], output_columns=["out1"] , operations=lambda x1: x1 + 400) + ret = [] + for item in dataset_map.create_dict_iterator(): + ret.append(item["out1"]) + assert ret[0] == 405 + assert ret[6] == 412 + +# test with row_data_buffer > 1 +def test_filter_by_generator_Partial2(): + ds.config.load('../data/dataset/declient_filter.cfg') + dataset1= ds.GeneratorDataset(source = generator_mc_p0(), column_names = ["col1", "col2"]) + dataset2 = ds.GeneratorDataset(source = generator_mc_p1(), column_names = ["col3", "col4"]) + + dataset1f = dataset1.filter( input_columns= ["col1"], predicate=lambda x: x not in [3,7,9], num_parallel_workers=2) + dataset2f = dataset2.filter( input_columns= ["col3"], predicate=lambda x: x not in [203,207,209], num_parallel_workers=2) + dataset_zip = ds.zip((dataset1f, dataset2f)) + dataset_map = dataset_zip.map( input_columns=["col1", "col3"], output_columns=["out1", "out3"] , operations=lambda x1,x3: (x1 + 400, x3+500)) + ret1 = [] + ret3 = [] + for item in dataset_map.create_dict_iterator(): + ret1.append(item["out1"]) + ret3.append(item["out3"]) + assert ret1[0] == 400 + assert ret1[6] == 408 + assert ret3[0] == 700 + assert ret3[6] == 708 + + +def filter_func_Partial(col1, col2): + if col1[0] % 3 == 0: + return True + return False + +def generator_big(maxid=20): + for i in range(maxid): + yield (np.array([i]), np.array([[i, i + 1], [i + 2, i + 3]])) + +# test with row_data_buffer > 1 +def test_filter_by_generator_Partial(): + ds.config.load('../data/dataset/declient_filter.cfg') + dataset = ds.GeneratorDataset(source = generator_mc(99), column_names = ["col1", "col2"]) + dataset_s = dataset.shuffle(4) + dataset_f1 = dataset_s.filter(input_columns=["col1", "col2"], predicate=filter_func_Partial, num_parallel_workers=1) + + for item in dataset_f1.create_dict_iterator(): + assert item["col1"] % 3 == 0 + + +def filter_func_cifar(col1, col2): + if col2 % 3 == 0: + return True + return False + +# test with cifar10 +def test_filte_case_dataset_cifar10(): + DATA_DIR_10 = "../data/dataset/testCifar10Data" + ds.config.load('../data/dataset/declient_filter.cfg') + dataset_c = ds.Cifar10Dataset(dataset_dir = DATA_DIR_10, num_samples = 100000, shuffle=False) + dataset_f1 = dataset_c.filter(input_columns=["image", "label"], predicate=filter_func_cifar, num_parallel_workers=1) + num_iter = 0 + for item in dataset_f1.create_dict_iterator(): + # in this example, each dictionary has keys "image" and "label" + assert item["label"] % 3 == 0 + +# column id sort + +def generator_sort1(maxid=20): + for i in range(maxid): + yield (np.array([i]), np.array([i + 100]), np.array([i + 200])) + +def generator_sort2(maxid=20): + for i in range(maxid): + yield (np.array([i + 300]), np.array([i + 400]), np.array([i + 500])) + + +def filter_func_part_sort(col1, col2, col3, col4, col5, col6): + return True + +def filter_func_map_sort(col1, col2, col3): + return (col1, col2, col3) + +def test_filter_by_generator_with_map_all_sort(): + dataset1 = ds.GeneratorDataset(generator_sort1(10), ["col1", "col2", "col3"]) + dataset2 = ds.GeneratorDataset(generator_sort2(10), ["col4 ", "col5", "col6"]) + + dataz = ds.zip((dataset1, dataset2)) + dataset_f = dataz.filter(predicate=filter_func_part_sort, num_parallel_workers=1) + num_iter = 0 + ret_data = [] + for item in dataset_f.create_dict_iterator(): + num_iter += 1 + ret_data.append(item) + + assert num_iter == 10 + assert ret_data[0]["col1"] == 0 + assert ret_data[9]["col6"] == 509 + + + +if __name__ == '__main__': + test_diff_predicate_func() + test_filte_case_dataset_cifar10() + test_filter_by_generator_Partial0() + test_filter_by_generator_Partial1() + test_filter_by_generator_Partial2() + test_filter_by_generator_with_batch() + test_filter_by_generator_with_batch_after() + test_filter_by_generator_with_input_column() + test_filter_by_generator_with_map_all_col() + test_filter_by_generator_with_map_all_sort() + test_filter_by_generator_with_map_part_col() + test_filter_by_generator_with_no() + test_filter_by_generator_with_rename() + test_filter_by_generator_with_repeat() + test_filter_by_generator_with_repeat_after() + test_filter_by_generator_with_shuffle() + test_filter_by_generator_with_shuffle_after() + test_filter_by_generator_with_zip() + test_filter_by_generator_with_zip_after() + test_filter_by_generator_Partial() diff --git a/tests/ut/python/dataset/test_iterator.py b/tests/ut/python/dataset/test_iterator.py index 102fd0eea1..7c69adf561 100644 --- a/tests/ut/python/dataset/test_iterator.py +++ b/tests/ut/python/dataset/test_iterator.py @@ -25,8 +25,8 @@ COLUMNS = ["col_1d", "col_2d", "col_3d", "col_binary", "col_float", def check(project_columns): - data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=COLUMNS) - data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=project_columns) + data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=COLUMNS, shuffle=False) + data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=project_columns, shuffle=False) for data_actual, data_expected in zip(data1.create_tuple_iterator(project_columns), data2.create_tuple_iterator()): assert len(data_actual) == len(data_expected) From c71234f383973692ffd46d1f839632ca8a99e4b6 Mon Sep 17 00:00:00 2001 From: ch-l Date: Thu, 16 Apr 2020 22:30:04 +0200 Subject: [PATCH 024/142] improve rec-prog str generator --- .../rec_core/rec_generate_strategy.cc | 215 ++++++++++-------- .../rec_core/rec_generate_strategy.h | 45 ++-- .../ccsrc/parallel/step_auto_parallel.cc | 6 +- 3 files changed, 144 insertions(+), 122 deletions(-) diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc index 60f3003a42..b2c34127a1 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc +++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.cc @@ -27,44 +27,27 @@ namespace mindspore { namespace parallel { -void GenerateStrategy(const std::shared_ptr graph, std::vector> ops, - const std::shared_ptr> ops_nodes_list, - const std::shared_ptr> index_list, - const std::shared_ptr>> eli_list) { - MaskNoSupportedOps(graph); +void GenerateStrategy(std::shared_ptr graph, bool mask_special_ops, + const std::vector> &ops) { + MS_EXCEPTION_IF_NULL(graph); + if (mask_special_ops) { + MaskSpecialOps(graph); + } for (size_t iter_ops = 0; iter_ops < ops.size(); iter_ops++) { - auto type = ops[iter_ops]->type(); - size_t iter_nodes = index_list->at(ops_nodes_list->at(iter_ops)); std::vector> stra; - iter_nodes = IterNodes(ops_nodes_list, index_list, eli_list, iter_ops, iter_nodes); for (size_t iter_op_inputs = 0; iter_op_inputs < ops[iter_ops]->inputs_tensor_info().size(); iter_op_inputs++) { - std::vector s = PrepareStrategy(graph, ops, type, iter_ops, iter_nodes, iter_op_inputs); - stra.push_back(s); + stra.push_back(PrepareStrategy(graph, ops, iter_ops, iter_op_inputs)); } StrategyPtr sp = std::make_shared(0, stra); ops[iter_ops]->SetSelectedStrategyAndCost(sp, ops[iter_ops]->selected_cost()); } } -size_t IterNodes(const std::shared_ptr> ops_nodes_list, - const std::shared_ptr> index_list, - const std::shared_ptr>> eli_list, const size_t iter_ops, - size_t iter_nodes) { - if (iter_nodes > SIZE_MAX / 2) { - for (size_t iter_eli = 0; iter_eli < eli_list->size(); iter_eli++) { - if (eli_list->at(iter_eli)[0] == ops_nodes_list->at(iter_ops)) { - iter_nodes = index_list->at(eli_list->at(iter_eli)[1]); - break; - } - } - } - return iter_nodes; -} - -void PrepareMatMul(const std::shared_ptr graph, const std::vector> &ops, - const size_t iter_ops, const size_t iter_nodes, const size_t iter_op_inputs, - std::vector s) { - auto attrs = ops[iter_ops]->attrs(); +std::vector PrepareMatMul(const std::shared_ptr &graph, + const std::vector> &ops, const size_t iter_nodes, + const size_t iter_op_inputs) { + std::vector s; + auto attrs = ops[iter_nodes]->attrs(); bool transpose_a = attrs[TRANSPOSE_A]->cast()->value(); bool transpose_b = attrs[TRANSPOSE_B]->cast()->value(); if (transpose_a && (iter_op_inputs == 0)) { @@ -77,10 +60,12 @@ void PrepareMatMul(const std::shared_ptr graph, const std::vector(1.0 / graph->nodes[iter_nodes].apply.arguments[iter_op_inputs].tensor_str.str_h)); s.push_back(static_cast(1.0 / graph->nodes[iter_nodes].apply.arguments[iter_op_inputs].tensor_str.str_w)); } + return s; } -void PrepareConv2D(const std::shared_ptr graph, const size_t iter_nodes, size_t iter_op_inputs, - std::vector s) { +std::vector PrepareConv2D(const std::shared_ptr &graph, const size_t iter_nodes, + size_t iter_op_inputs) { + std::vector s; if (iter_op_inputs == 0) { s.push_back(static_cast(1.0 / graph->nodes[iter_nodes].apply.arguments[0].tensor_str.str_n)); s.push_back(static_cast(1.0 / graph->nodes[iter_nodes].apply.arguments[0].tensor_str.str_c)); @@ -92,20 +77,24 @@ void PrepareConv2D(const std::shared_ptr graph, const size_t iter_nodes, s.push_back(static_cast(1.0 / graph->nodes[iter_nodes].apply.arguments[1].tensor_str.str_h)); s.push_back(static_cast(1.0 / graph->nodes[iter_nodes].apply.arguments[1].tensor_str.str_w)); } + return s; } -void PrepareBiasAdd(const std::shared_ptr graph, const size_t iter_nodes, const size_t iter_op_inputs, - std::vector s) { +std::vector PrepareBiasAdd(const std::shared_ptr &graph, const size_t iter_nodes, + const size_t iter_op_inputs) { + std::vector s; if (iter_op_inputs == 0) { s.push_back(static_cast(1.0 / graph->nodes[iter_nodes].apply.arguments[iter_op_inputs].tensor_str.str_h)); s.push_back(static_cast(1.0 / graph->nodes[iter_nodes].apply.arguments[iter_op_inputs].tensor_str.str_w)); } else { s.push_back(static_cast(1.0 / graph->nodes[iter_nodes].apply.arguments[iter_op_inputs].tensor_str.str_w)); } + return s; } -void PrepareBN(const std::shared_ptr graph, const size_t iter_nodes, const size_t iter_op_inputs, - std::vector s) { +std::vector PrepareBN(const std::shared_ptr &graph, const size_t iter_nodes, + const size_t iter_op_inputs) { + std::vector s; if (iter_op_inputs == 0) { s.push_back(static_cast(1.0 / graph->nodes[iter_nodes].apply.arguments[0].tensor_str.str_n)); s.push_back(static_cast(1.0 / graph->nodes[iter_nodes].apply.arguments[0].tensor_str.str_c)); @@ -114,97 +103,133 @@ void PrepareBN(const std::shared_ptr graph, const size_t iter_nodes, cons } else { s.push_back(static_cast(1.0 / graph->nodes[iter_nodes].apply.arguments[1].tensor_str.str_w)); } + return s; } -void PrepareSparse(const size_t iter_op_inputs, std::vector s) { +std::vector PrepareSparse(const size_t iter_op_inputs) { + std::vector s; if (iter_op_inputs == 0) { s.push_back(g_device_manager->DeviceNum()); s.push_back(1); } else { s.push_back(g_device_manager->DeviceNum()); } + return s; +} + +std::vector MakeOriginalStrategy(const std::vector> &ops, const size_t iter_ops, + const size_t iter_op_inputs) { + std::vector s; + if (ops.empty()) { + MS_LOG(EXCEPTION) << "Failure: Operators is empty."; + } + if (iter_ops >= ops.size()) { + MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range."; + } + if (iter_op_inputs >= ops[iter_ops]->strategy()->GetInputDim().size()) + MS_LOG(EXCEPTION) << "Failure: Strategy's InputDim out of range."; + size_t input_size = ops[iter_ops]->strategy()->GetInputDim()[iter_op_inputs].size(); + for (size_t dim = 0; dim < input_size; dim++) { + s.push_back(1); + } + return s; } -void RefillOrigin(const std::vector> &ops, const size_t iter_ops, - const size_t iter_op_inputs, std::vector s) { +std::vector MakeRecSearchStrategy(const std::shared_ptr &graph, const size_t iter_ops, + const size_t iter_op_inputs) { + std::vector s; + s.push_back(static_cast(1.0 / graph->nodes[iter_ops].apply.arguments[iter_op_inputs].tensor_str.str_n)); + s.push_back(static_cast(1.0 / graph->nodes[iter_ops].apply.arguments[iter_op_inputs].tensor_str.str_c)); + s.push_back(static_cast(1.0 / graph->nodes[iter_ops].apply.arguments[iter_op_inputs].tensor_str.str_h)); + s.push_back(static_cast(1.0 / graph->nodes[iter_ops].apply.arguments[iter_op_inputs].tensor_str.str_w)); + return s; +} + +std::vector MakeDataParallelStrategy(const std::vector> &ops, + const size_t iter_ops, const size_t iter_op_inputs) { + std::vector s; + if (ops.empty()) { + MS_LOG(EXCEPTION) << "Failure: Operators is empty."; + } + if (iter_ops >= ops.size()) { + MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range."; + } StrategyPtr origin_strategy = ops[iter_ops]->strategy(); - if (iter_op_inputs == 0) { - for (size_t j = 0; j < origin_strategy->GetInputDim()[0].size(); j++) { - s.push_back(1); - } - } else { - for (size_t k = 0; k < origin_strategy->GetInputDim()[iter_op_inputs].size(); k++) { + if (iter_op_inputs >= origin_strategy->GetInputDim().size()) + MS_LOG(EXCEPTION) << "Failure: Strategy's InputDim out of range."; + size_t input_size = origin_strategy->GetInputDim()[iter_op_inputs].size(); + for (size_t dim = 0; dim < input_size; dim++) { + if (dim == 0 && input_size == 4) { + size_t max_device_num = g_device_manager->DeviceNum(); + size_t target_tensor_batch = ops[iter_ops]->outputs_tensor_info()[0].shape()[0]; + s.push_back(std::min(max_device_num, target_tensor_batch)); + } else { s.push_back(1); } } + return s; } -std::vector PrepareStrategy(const std::shared_ptr graph, - const std::vector> &ops, const std::string &type, - const size_t iter_ops, const size_t iter_nodes, const size_t iter_op_inputs) { - std::vector s; +std::vector PrepareStrategy(const std::shared_ptr &graph, + const std::vector> &ops, const size_t iter_ops, + const size_t iter_op_inputs) { + if (ops.empty()) { + MS_LOG(EXCEPTION) << "Failure: Operators is empty."; + } + if (iter_ops >= ops.size()) { + MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range."; + } + auto type = ops[iter_ops]->type(); if (type == MATMUL) { - PrepareMatMul(graph, ops, iter_ops, iter_nodes, iter_op_inputs, s); + return PrepareMatMul(graph, ops, iter_ops, iter_op_inputs); } else if ((type == MAXPOOL) || (type == SIMPLE_MEAN) || (type == TENSOR_ADD)) { - s.push_back(static_cast(1.0 / graph->nodes[iter_nodes].apply.arguments[iter_op_inputs].tensor_str.str_n)); - s.push_back(static_cast(1.0 / graph->nodes[iter_nodes].apply.arguments[iter_op_inputs].tensor_str.str_c)); - s.push_back(static_cast(1.0 / graph->nodes[iter_nodes].apply.arguments[iter_op_inputs].tensor_str.str_h)); - s.push_back(static_cast(1.0 / graph->nodes[iter_nodes].apply.arguments[iter_op_inputs].tensor_str.str_w)); + return MakeRecSearchStrategy(graph, iter_ops, iter_op_inputs); } else if (type == CONV2D) { - PrepareConv2D(graph, iter_nodes, iter_op_inputs, s); + return PrepareConv2D(graph, iter_ops, iter_op_inputs); } else if (type == BIAS_ADD) { - PrepareBiasAdd(graph, iter_nodes, iter_op_inputs, s); + return PrepareBiasAdd(graph, iter_ops, iter_op_inputs); } else if (type == RESHAPE) { - s.push_back(1); - s.push_back(1); - s.push_back(1); - s.push_back(1); + return MakeOriginalStrategy(ops, iter_ops, iter_op_inputs); } else if (type == RELU) { - s.push_back(static_cast(1.0 / graph->nodes[iter_nodes].tensor_parm.tensor_str.str_n)); - s.push_back(static_cast(1.0 / graph->nodes[iter_nodes].tensor_parm.tensor_str.str_c)); - s.push_back(static_cast(1.0 / graph->nodes[iter_nodes].tensor_parm.tensor_str.str_h)); - s.push_back(static_cast(1.0 / graph->nodes[iter_nodes].tensor_parm.tensor_str.str_w)); + return MakeRecSearchStrategy(graph, iter_ops, iter_op_inputs); } else if (type == BATCH_NORM || (type == FUSE_BATCH_NORM)) { - PrepareBN(graph, iter_nodes, iter_op_inputs, s); + return PrepareBN(graph, iter_ops, iter_op_inputs); } else if (type == SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) { - PrepareSparse(iter_op_inputs, s); + return PrepareSparse(iter_op_inputs); } else { - RefillOrigin(ops, iter_ops, iter_op_inputs, s); + return MakeDataParallelStrategy(ops, iter_ops, iter_op_inputs); } - return s; } -void MaskNoSupportedOps(const std::shared_ptr graph) { +void MaskSpecialOps(std::shared_ptr graph) { size_t iter_nodes = graph->nodes.size(); for (size_t i = 0; i < iter_nodes; i++) { - if (0 == graph->nodes[i].info) { - Graph::NodeType &node = graph->nodes[i]; + Graph::NodeType &node = graph->nodes[i]; - if (node.apply.op_type == 1) { // For Convolution - // cover input tensor strategy - node.apply.arguments[0].tensor_str.str_n = 1.0 / static_cast(g_device_manager->DeviceNum()); - node.apply.arguments[0].tensor_str.str_c = 1; - node.apply.arguments[0].tensor_str.str_h = 1; - node.apply.arguments[0].tensor_str.str_w = 1; - // cover filter tensor strategy - node.apply.arguments[1].tensor_str.str_n = 1; - node.apply.arguments[1].tensor_str.str_c = 1; - node.apply.arguments[1].tensor_str.str_h = 1; - node.apply.arguments[1].tensor_str.str_w = 1; - } else if (node.apply.op_type == 8) { // For BN - node.apply.arguments[0].tensor_str.str_n = 1.0 / static_cast(g_device_manager->DeviceNum()); - node.apply.arguments[0].tensor_str.str_c = 1; - node.apply.arguments[0].tensor_str.str_h = 1; - node.apply.arguments[0].tensor_str.str_w = 1; - // cover 1-d argument blobs - node.apply.arguments[1].tensor_str.str_w = 1; - node.apply.arguments[2].tensor_str.str_w = 1; - node.apply.arguments[3].tensor_str.str_w = 1; - node.apply.arguments[4].tensor_str.str_w = 1; - } else if (node.apply.op_type == 4 || node.apply.op_type == 9) { // For SparseSoftmaxCrossEntropyWithLogits - node.tensor_parm.tensor_str.str_h = 1.0 / static_cast(g_device_manager->DeviceNum()); - node.tensor_parm.tensor_str.str_w = 1; - } + if (node.apply.op_type == 1) { // For Convolution + // cover input tensor strategy + node.apply.arguments[0].tensor_str.str_n = 1.0 / static_cast(g_device_manager->DeviceNum()); + node.apply.arguments[0].tensor_str.str_c = 1; + node.apply.arguments[0].tensor_str.str_h = 1; + node.apply.arguments[0].tensor_str.str_w = 1; + // cover filter tensor strategy + node.apply.arguments[1].tensor_str.str_n = 1; + node.apply.arguments[1].tensor_str.str_c = 1; + node.apply.arguments[1].tensor_str.str_h = 1; + node.apply.arguments[1].tensor_str.str_w = 1; + } else if (node.apply.op_type == 8) { // For BN + node.apply.arguments[0].tensor_str.str_n = 1.0 / static_cast(g_device_manager->DeviceNum()); + node.apply.arguments[0].tensor_str.str_c = 1; + node.apply.arguments[0].tensor_str.str_h = 1; + node.apply.arguments[0].tensor_str.str_w = 1; + // cover 1-d argument blobs + node.apply.arguments[1].tensor_str.str_n = 1; + node.apply.arguments[2].tensor_str.str_c = 1; + node.apply.arguments[3].tensor_str.str_h = 1; + node.apply.arguments[4].tensor_str.str_w = 1; + } else if (node.apply.op_type == 4 || node.apply.op_type == 9) { // For SparseSoftmaxCrossEntropyWithLogits + node.tensor_parm.tensor_str.str_h = 1.0 / static_cast(g_device_manager->DeviceNum()); + node.tensor_parm.tensor_str.str_w = 1; } } } diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.h b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.h index 4abef843a8..f3274e1440 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.h +++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_generate_strategy.h @@ -27,29 +27,28 @@ namespace mindspore { namespace parallel { -void GenerateStrategy(const std::shared_ptr graph, std::vector> ops, - const std::shared_ptr> ops_nodes_list, - const std::shared_ptr> index_list, - const std::shared_ptr>> eli_list); -void PrepareMatMul(const std::shared_ptr graph, const std::vector> &ops, - const size_t iter_ops, const size_t iter_nodes, const size_t iter_op_inputs, std::vector s); -void PrepareConv2D(const std::shared_ptr graph, const size_t iter_nodes, const size_t iter_op_inputs, - std::vector s); -void PrepareBiasAdd(const std::shared_ptr graph, const size_t iter_nodes, const size_t iter_op_inputs, - std::vector s); -void PrepareBN(const std::shared_ptr graph, const size_t iter_nodes, const size_t iter_op_inputs, - std::vector s); -void PrepareSparse(const size_t iter_op_inputs, std::vector s); -void RefillOrigin(const std::vector> &ops, const size_t iter_ops, - const size_t iter_op_inputs, std::vector s); -std::vector PrepareStrategy(const std::shared_ptr graph, - const std::vector> &ops, const std::string &type, - const size_t iter_ops, const size_t iter_nodes, const size_t iter_op_inputs); -size_t IterNodes(const std::shared_ptr> ops_nodes_list, - const std::shared_ptr> index_list, - const std::shared_ptr>> eli_list, const size_t iter_ops, - size_t iter_nodes); -void MaskNoSupportedOps(const std::shared_ptr graph); +void GenerateStrategy(std::shared_ptr graph, bool mask_special_ops, + const std::vector> &ops); +std::vector PrepareMatMul(const std::shared_ptr &graph, + const std::vector> &ops, const size_t iter_nodes, + const size_t iter_op_inputs); +std::vector PrepareConv2D(const std::shared_ptr &graph, const size_t iter_nodes, + const size_t iter_op_inputs); +std::vector PrepareBiasAdd(const std::shared_ptr &graph, const size_t iter_nodes, + const size_t iter_op_inputs); +std::vector PrepareBN(const std::shared_ptr &graph, const size_t iter_nodes, + const size_t iter_op_inputs); +std::vector PrepareSparse(const size_t iter_op_inputs); +std::vector MakeOriginalStrategy(const std::vector> &ops, const size_t iter_ops, + const size_t iter_op_inputs); +std::vector MakeRecSearchStrategy(const std::shared_ptr &graph, const size_t iter_ops, + const size_t iter_op_inputs); +std::vector MakeDataParallelStrategy(const std::vector> &ops, + const size_t iter_ops, const size_t iter_op_inputs); +std::vector PrepareStrategy(const std::shared_ptr &graph, + const std::vector> &ops, const size_t iter_ops, + const size_t iter_op_inputs); +void MaskSpecialOps(std::shared_ptr graph); } // namespace parallel } // namespace mindspore #endif // PARALLEL_AUTO_PARALLEL_REC_GENERATE_STRATEGY_H_ diff --git a/mindspore/ccsrc/parallel/step_auto_parallel.cc b/mindspore/ccsrc/parallel/step_auto_parallel.cc index 1d52eac82d..e0190d7e93 100644 --- a/mindspore/ccsrc/parallel/step_auto_parallel.cc +++ b/mindspore/ccsrc/parallel/step_auto_parallel.cc @@ -931,12 +931,9 @@ Status ParallelStrategyRecSearch(const std::vector &all_nodes, const } std::shared_ptr> ops_nodes_list(new std::vector); - std::shared_ptr> index_list(new std::vector); - std::shared_ptr>> eli_list(new std::vector>); std::shared_ptr graph = ParseGraph(ops, input_tensor_names, ops_nodes_list); - graph = EliminateGraph(graph, eli_list, index_list); size_t num_device = g_device_manager->DeviceNum(); if (PartitionForAllDevices(num_device, graph) == SUCCESS) { MS_LOG(INFO) << "Partition Success With " << num_device << " devices."; @@ -945,7 +942,8 @@ Status ParallelStrategyRecSearch(const std::vector &all_nodes, const return FAILED; } - GenerateStrategy(graph, ops, ops_nodes_list, index_list, eli_list); + bool mask_special_ops = true; + GenerateStrategy(graph, mask_special_ops, ops); if (entire_costgraph->InitSelectedStrategy() == SUCCESS) { MS_LOG(INFO) << "Init selected strategy succeeded."; From 60df3691006cd39c8913059fd5f8609382d5b3ff Mon Sep 17 00:00:00 2001 From: Cathy Wong Date: Mon, 20 Apr 2020 11:31:20 -0400 Subject: [PATCH 025/142] Fixup py Normalize doc: takes input CHW --- mindspore/dataset/transforms/vision/py_transforms.py | 2 +- tests/ut/python/dataset/test_normalizeOp.py | 5 +++-- tests/ut/python/dataset/test_random_color_adjust.py | 4 ---- 3 files changed, 4 insertions(+), 7 deletions(-) diff --git a/mindspore/dataset/transforms/vision/py_transforms.py b/mindspore/dataset/transforms/vision/py_transforms.py index f5ab5d873b..8d81f8f3b0 100644 --- a/mindspore/dataset/transforms/vision/py_transforms.py +++ b/mindspore/dataset/transforms/vision/py_transforms.py @@ -220,7 +220,7 @@ class Decode: class Normalize: """ - Normalize the input Numpy image array of shape (H, W, C) with the given mean and standard deviation. + Normalize the input Numpy image array of shape (C, H, W) with the given mean and standard deviation. The values of the array need to be in range [0.0, 1.0]. diff --git a/tests/ut/python/dataset/test_normalizeOp.py b/tests/ut/python/dataset/test_normalizeOp.py index 1abee96173..c080b00105 100644 --- a/tests/ut/python/dataset/test_normalizeOp.py +++ b/tests/ut/python/dataset/test_normalizeOp.py @@ -15,7 +15,7 @@ import mindspore.dataset.transforms.vision.c_transforms as vision import numpy as np - +import matplotlib.pyplot as plt import mindspore.dataset as ds from mindspore import log as logger @@ -114,6 +114,7 @@ def test_decode_op(): # plt.subplot(131) # plt.imshow(image) # plt.title("DE image") + # plt.show() num_iter += 1 @@ -138,8 +139,8 @@ def test_decode_normalize_op(): # plt.subplot(131) # plt.imshow(image) # plt.title("DE image") + # plt.show() num_iter += 1 - break if __name__ == "__main__": diff --git a/tests/ut/python/dataset/test_random_color_adjust.py b/tests/ut/python/dataset/test_random_color_adjust.py index 57c77caf81..dcb7cd48ac 100644 --- a/tests/ut/python/dataset/test_random_color_adjust.py +++ b/tests/ut/python/dataset/test_random_color_adjust.py @@ -182,8 +182,6 @@ def test_random_color_jitter_op_saturation(): ] transform = py_vision.ComposeOp(transforms) data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) - # data2 = data2.map(input_columns=["image"], operations=decode_op) - # data2 = data2.map(input_columns=["image"], operations=c_vision.Decode()) data2 = data2.map(input_columns=["image"], operations=transform()) num_iter = 0 @@ -220,8 +218,6 @@ def test_random_color_jitter_op_hue(): # First dataset data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"], shuffle=False) decode_op = c_vision.Decode() - # channel_swap_op = c_vision.ChannelSwap() - # change_mode_op = c_vision.ChangeMode() random_jitter_op = c_vision.RandomColorAdjust((1, 1), (1, 1), (1, 1), (0.2, 0.2)) From e1b109e8b8dc062980357c25c90c99c78f429896 Mon Sep 17 00:00:00 2001 From: jiangzhiwen Date: Mon, 20 Apr 2020 15:24:42 +0800 Subject: [PATCH 026/142] optimize skip dataset op --- .../dataset/engine/datasetops/skip_op.cc | 36 ++++++++----------- tests/ut/python/dataset/test_skip.py | 18 ++++++++-- 2 files changed, 31 insertions(+), 23 deletions(-) diff --git a/mindspore/ccsrc/dataset/engine/datasetops/skip_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/skip_op.cc index 90c160b5bf..d851f2c699 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/skip_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/skip_op.cc @@ -67,9 +67,10 @@ Status SkipOp::GetNextBuffer(std::unique_ptr *p_buffer, int32_t work } std::unique_ptr buf; + RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true)); + // Drop first max_skips_ rows while (skip_count_ < max_skips_) { - RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true)); if (buf->eoe() || buf->eof()) { break; } @@ -77,31 +78,24 @@ Status SkipOp::GetNextBuffer(std::unique_ptr *p_buffer, int32_t work // Consider the rows of buffer more than 1 TensorRow drop_row; int row_num = buf->NumRows(); - for (int i = 0; i < row_num; i++) { + int drop_num = row_num + skip_count_ < max_skips_ ? row_num : max_skips_ - skip_count_; + skip_count_ += drop_num; + for (int i = 0; i < drop_num; i++) { RETURN_IF_NOT_OK(buf->PopRow(&drop_row)); - if (++skip_count_ == max_skips_) { - break; - } } - } - - // If buffer is none or the rows of buffer is 0, - // then get a buffer from child. - if (!buf || buf->NumRows() == 0) { - if (buf && buf->eof()) { - *p_buffer = std::move(buf); - return Status::OK(); + if (buf->NumRows() == 0) { + RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true)); } - RETURN_IF_NOT_OK(child_[0]->GetNextBuffer(&buf, worker_id, true)); } - // Handling eoe and eof - if (buf->eoe() || buf->eof()) { + // Handling eoe + if (buf->eoe()) { RETURN_IF_NOT_OK(EoeReceived(worker_id)); - if (state_ == OpState::kDeOpIdle) { - *p_buffer = std::move(buf); - return Status::OK(); - } + } + + // Handling eof + if (buf->eof()) { + RETURN_IF_NOT_OK(EofReceived(worker_id)); } *p_buffer = std::move(buf); @@ -125,7 +119,7 @@ Status SkipOp::operator()() { RETURN_STATUS_UNEXPECTED("Logic error. SkipOp is a // Base-class override for handling cases when an eof is received. Status SkipOp::EofReceived(int32_t worker_id) { - MS_LOG(INFO) << "Skip operator EOF received, do nothing now."; + MS_LOG(DEBUG) << "Skip operator EOF received, do nothing now."; return Status::OK(); } } // namespace dataset diff --git a/tests/ut/python/dataset/test_skip.py b/tests/ut/python/dataset/test_skip.py index bea7db4e05..59893f6ded 100644 --- a/tests/ut/python/dataset/test_skip.py +++ b/tests/ut/python/dataset/test_skip.py @@ -22,7 +22,11 @@ from mindspore import log as logger DATA_DIR_TF2 = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] SCHEMA_DIR_TF2 = "../data/dataset/test_tf_file_3_images/datasetSchema.json" + def test_tf_skip(): + """ + a simple skip operation. + """ data1 = ds.TFRecordDataset(DATA_DIR_TF2, SCHEMA_DIR_TF2, shuffle=False) resize_height, resize_width = 32, 32 @@ -37,11 +41,15 @@ def test_tf_skip(): num_iter += 1 assert num_iter == 1 + def generator_md(): - # Create a dataset with [0, 1, 2, 3, 4] + """ + create a dataset with [0, 1, 2, 3, 4] + """ for i in range(5): yield (np.array([i]), ) + def test_generator_skip(): ds1 = ds.GeneratorDataset(generator_md, ["data"]) @@ -53,6 +61,7 @@ def test_generator_skip(): buf.append(data[0][0]) assert len(buf) == 2 + def test_skip_1(): ds1 = ds.GeneratorDataset(generator_md, ["data"]) @@ -64,6 +73,7 @@ def test_skip_1(): buf.append(data[0][0]) assert len(buf) == 0 + def test_skip_2(): ds1 = ds.GeneratorDataset(generator_md, ["data"]) @@ -75,6 +85,7 @@ def test_skip_2(): buf.append(data[0][0]) assert len(buf) == 5 + def test_skip_repeat_1(): ds1 = ds.GeneratorDataset(generator_md, ["data"]) @@ -89,6 +100,7 @@ def test_skip_repeat_1(): buf.append(data[0][0]) assert len(buf) == 7 + def test_skip_repeat_2(): ds1 = ds.GeneratorDataset(generator_md, ["data"]) @@ -103,6 +115,7 @@ def test_skip_repeat_2(): buf.append(data[0][0]) assert len(buf) == 4 + def test_skip_repeat_3(): ds1 = ds.GeneratorDataset(generator_md, ["data"]) @@ -120,6 +133,7 @@ def test_skip_repeat_3(): buf.append(data[0][0]) assert len(buf) == 6 + if __name__ == "__main__": test_tf_skip() test_generator_skip() @@ -127,4 +141,4 @@ if __name__ == "__main__": test_skip_2() test_skip_repeat_1() test_skip_repeat_2() - test_skip_repeat_3() \ No newline at end of file + test_skip_repeat_3() From 0de05b39491f0655ce0149d8a6b3567809331937 Mon Sep 17 00:00:00 2001 From: candanzg Date: Sat, 18 Apr 2020 16:20:31 +0800 Subject: [PATCH 027/142] [bug] fixed bool check for cast op Signed-off-by: candanzg --- .../ccsrc/operator/composite/do_signature.cc | 13 +++++++++++++ mindspore/ops/operations/array_ops.py | 2 +- tests/ut/python/ops/test_math_ops.py | 16 ++++++++++++++++ 3 files changed, 30 insertions(+), 1 deletion(-) diff --git a/mindspore/ccsrc/operator/composite/do_signature.cc b/mindspore/ccsrc/operator/composite/do_signature.cc index a4a26377f5..70fc0f591c 100644 --- a/mindspore/ccsrc/operator/composite/do_signature.cc +++ b/mindspore/ccsrc/operator/composite/do_signature.cc @@ -137,6 +137,19 @@ void DoAutoCast(const std::vector& signature, const abstract::Abstrac if (it == dst_type.end() || it->second == i || !arg_value->isa()) { continue; } + // When scalar is of bool type, the type of tensor must also be of bool type, + // otherwise the cast operator will not be added. + auto scalar = arg_value->cast(); + auto scalar_type = scalar->BuildType(); + MS_EXCEPTION_IF_NULL(scalar_type); + if (scalar_type->type_id() == kNumberTypeBool) { + auto tensor = args_spec_list[it->second]->cast(); + auto tensor_type = tensor->element()->BuildType(); + MS_EXCEPTION_IF_NULL(tensor_type); + if (tensor_type->type_id() != kNumberTypeBool) { + continue; + } + } // get source node for cast AnfNodePtr source_node = (*op_inputs)[it->second + 1]; (*op_inputs)[i + 1] = DoCast((*op_inputs)[i + 1], source_node, graph); diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index 2e03676a4a..b4c4796d5e 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -745,7 +745,7 @@ class Fill(PrimitiveWithInfer): out = { 'value': Tensor(ret), 'shape': dims['value'], - 'dtype': x_nptype, + 'dtype': x_dtype, } return out diff --git a/tests/ut/python/ops/test_math_ops.py b/tests/ut/python/ops/test_math_ops.py index 8b7f627e81..7f8717d4e6 100755 --- a/tests/ut/python/ops/test_math_ops.py +++ b/tests/ut/python/ops/test_math_ops.py @@ -30,6 +30,7 @@ from ....mindspore_test_framework.pipeline.forward.compile_forward \ import pipeline_for_compile_forward_ge_graph_for_case_by_case_config from ....mindspore_test_framework.pipeline.forward.verify_exception \ import pipeline_for_verify_exception_for_case_by_case_config +import pytest # pylint: disable=W0613 @@ -81,14 +82,29 @@ def test_sqrt(): assert np.all(result.asnumpy() == expect) +class PowNet(nn.Cell): + def __init__(self): + super(PowNet, self).__init__() + self.pow = P.Pow() + + def construct(self, x, y): + return self.pow(x, y) + + def test_pow(): """ test_pow """ input_tensor = Tensor(np.array([[2, 2], [3, 3]])) power = Tensor(np.array(3.0, np.int64)) + power2 = Tensor(np.array(True, np.bool)) testpow = P.Pow() expect = np.array([[8, 8], [27, 27]]) result = testpow(input_tensor, power) assert np.all(result.asnumpy() == expect) + net = PowNet() + with pytest.raises(TypeError): + net(input_tensor, True) + with pytest.raises(TypeError): + net(input_tensor, power2) def test_exp(): From f1542a90a35d71b74c80a1cf21e3d57e5be57d18 Mon Sep 17 00:00:00 2001 From: liyong Date: Tue, 14 Apr 2020 20:50:44 +0800 Subject: [PATCH 028/142] add pk sampler --- .../ccsrc/dataset/api/python_bindings.cc | 20 +- .../engine/datasetops/source/mindrecord_op.cc | 5 +- .../engine/datasetops/source/mindrecord_op.h | 3 +- .../mindrecord/include/common/shard_utils.h | 2 + .../ccsrc/mindrecord/include/shard_category.h | 24 ++- .../ccsrc/mindrecord/include/shard_operator.h | 2 + .../mindrecord/include/shard_pk_sample.h | 49 +++++ .../ccsrc/mindrecord/include/shard_reader.h | 16 +- .../ccsrc/mindrecord/include/shard_sample.h | 3 + .../ccsrc/mindrecord/include/shard_shuffle.h | 3 +- .../ccsrc/mindrecord/include/shard_task.h | 4 +- mindspore/ccsrc/mindrecord/io/shard_reader.cc | 175 +++++++++++++++--- .../ccsrc/mindrecord/meta/shard_category.cc | 25 ++- .../ccsrc/mindrecord/meta/shard_pk_sample.cc | 46 +++++ .../ccsrc/mindrecord/meta/shard_sample.cc | 18 ++ .../ccsrc/mindrecord/meta/shard_shuffle.cc | 30 +-- mindspore/ccsrc/mindrecord/meta/shard_task.cc | 40 ++-- mindspore/dataset/engine/datasets.py | 11 +- mindspore/dataset/engine/samplers.py | 2 + .../cpp/mindrecord/ut_shard_operator_test.cc | 52 ++++++ .../testImageNetData/annotation_sampler.txt | 10 + .../dataset/test_minddataset_sampler.py | 79 ++++++-- .../ut/python/dataset/test_serdes_dataset.py | 2 +- 23 files changed, 540 insertions(+), 81 deletions(-) create mode 100644 mindspore/ccsrc/mindrecord/include/shard_pk_sample.h create mode 100644 mindspore/ccsrc/mindrecord/meta/shard_pk_sample.cc create mode 100644 tests/ut/data/mindrecord/testImageNetData/annotation_sampler.txt diff --git a/mindspore/ccsrc/dataset/api/python_bindings.cc b/mindspore/ccsrc/dataset/api/python_bindings.cc index 214ce4c153..9865396a7d 100644 --- a/mindspore/ccsrc/dataset/api/python_bindings.cc +++ b/mindspore/ccsrc/dataset/api/python_bindings.cc @@ -60,6 +60,7 @@ #include "dataset/kernels/data/to_float16_op.h" #include "dataset/util/random.h" #include "mindrecord/include/shard_operator.h" +#include "mindrecord/include/shard_pk_sample.h" #include "mindrecord/include/shard_sample.h" #include "pybind11/pybind11.h" #include "pybind11/stl.h" @@ -152,9 +153,14 @@ void bindDatasetOps(py::module *m) { }); (void)py::class_>(*m, "MindRecordOp") - .def_static("get_num_rows", [](const std::string &path) { + .def_static("get_num_rows", [](const std::string &path, const py::object &sampler) { int64_t count = 0; - THROW_IF_ERROR(MindRecordOp::CountTotalRows(path, &count)); + std::shared_ptr op; + if (py::hasattr(sampler, "_create_for_minddataset")) { + auto create = sampler.attr("_create_for_minddataset"); + op = create().cast>(); + } + THROW_IF_ERROR(MindRecordOp::CountTotalRows(path, op, &count)); return count; }); @@ -435,6 +441,16 @@ void bindSamplerOps(py::module *m) { (void)py::class_>( *m, "MindrecordSubsetRandomSampler") .def(py::init, uint32_t>(), py::arg("indices"), py::arg("seed") = GetSeed()); + (void)py::class_>( + *m, "MindrecordPkSampler") + .def(py::init([](int64_t kVal, bool shuffle) { + if (shuffle == true) { + return std::make_shared("label", kVal, std::numeric_limits::max(), + GetSeed()); + } else { + return std::make_shared("label", kVal); + } + })); (void)py::class_>(*m, "WeightedRandomSampler") .def(py::init, int64_t, bool>(), py::arg("weights"), py::arg("numSamples"), diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc index fbb772af59..72dee6f2e6 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc @@ -655,9 +655,10 @@ Status MindRecordOp::LaunchThreadAndInitOp() { return Status::OK(); } -Status MindRecordOp::CountTotalRows(const std::string dataset_path, int64_t *count) { +Status MindRecordOp::CountTotalRows(const std::string dataset_path, const std::shared_ptr &op, + int64_t *count) { std::unique_ptr shard_reader = std::make_unique(); - MSRStatus rc = shard_reader->CountTotalRows(dataset_path, count); + MSRStatus rc = shard_reader->CountTotalRows(dataset_path, op, count); if (rc == MSRStatus::FAILED) { RETURN_STATUS_UNEXPECTED("MindRecordOp count total rows failed."); } diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.h index aca5c86c2c..899919e529 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.h @@ -171,7 +171,8 @@ class MindRecordOp : public ParallelOp { int32_t num_rows() const { return num_rows_; } // Getter method - static Status CountTotalRows(const std::string dataset_path, int64_t *count); + static Status CountTotalRows(const std::string dataset_path, const std::shared_ptr &op, + int64_t *count); // Getter method int32_t rows_per_buffer() const { return rows_per_buffer_; } diff --git a/mindspore/ccsrc/mindrecord/include/common/shard_utils.h b/mindspore/ccsrc/mindrecord/include/common/shard_utils.h index d31037c8ad..3af4d7f891 100644 --- a/mindspore/ccsrc/mindrecord/include/common/shard_utils.h +++ b/mindspore/ccsrc/mindrecord/include/common/shard_utils.h @@ -72,6 +72,8 @@ enum ShardType { enum SamplerType { kCustomTopNSampler, kCustomTopPercentSampler, kSubsetRandomSampler, kPKSampler }; +enum ShuffleType { kShuffleCategory, kShuffleSample }; + const double kEpsilon = 1e-7; const int kThreadNumber = 14; diff --git a/mindspore/ccsrc/mindrecord/include/shard_category.h b/mindspore/ccsrc/mindrecord/include/shard_category.h index b8a7611540..b2fe18fbac 100644 --- a/mindspore/ccsrc/mindrecord/include/shard_category.h +++ b/mindspore/ccsrc/mindrecord/include/shard_category.h @@ -17,6 +17,8 @@ #ifndef MINDRECORD_INCLUDE_SHARD_CATEGORY_H_ #define MINDRECORD_INCLUDE_SHARD_CATEGORY_H_ +#include +#include #include #include #include @@ -26,16 +28,34 @@ namespace mindspore { namespace mindrecord { class ShardCategory : public ShardOperator { public: - explicit ShardCategory(const std::vector> &categories); + explicit ShardCategory(const std::vector> &categories, + int64_t num_elements = std::numeric_limits::max(), bool replacement = false); + + ShardCategory(const std::string &category_field, int64_t num_elements, + int64_t num_categories = std::numeric_limits::max(), bool replacement = false); ~ShardCategory() override{}; - const std::vector> &get_categories() const; + const std::vector> &get_categories() const { return categories_; } + + const std::string GetCategoryField() const { return category_field_; } + + int64_t GetNumElements() const { return num_elements_; } + + int64_t GetNumCategories() const { return num_categories_; } + + bool GetReplacement() const { return replacement_; } MSRStatus execute(ShardTask &tasks) override; + int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; + private: std::vector> categories_; + std::string category_field_; + int64_t num_elements_; + int64_t num_categories_; + bool replacement_; }; } // namespace mindrecord } // namespace mindspore diff --git a/mindspore/ccsrc/mindrecord/include/shard_operator.h b/mindspore/ccsrc/mindrecord/include/shard_operator.h index 9f302e5321..7476660a70 100644 --- a/mindspore/ccsrc/mindrecord/include/shard_operator.h +++ b/mindspore/ccsrc/mindrecord/include/shard_operator.h @@ -43,6 +43,8 @@ class ShardOperator { virtual MSRStatus execute(ShardTask &tasks) = 0; virtual MSRStatus suf_execute(ShardTask &tasks) { return SUCCESS; } + + virtual int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) { return -1; } }; } // namespace mindrecord } // namespace mindspore diff --git a/mindspore/ccsrc/mindrecord/include/shard_pk_sample.h b/mindspore/ccsrc/mindrecord/include/shard_pk_sample.h new file mode 100644 index 0000000000..df3888dad4 --- /dev/null +++ b/mindspore/ccsrc/mindrecord/include/shard_pk_sample.h @@ -0,0 +1,49 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDRECORD_INCLUDE_SHARD_PK_SAMPLE_H_ +#define MINDRECORD_INCLUDE_SHARD_PK_SAMPLE_H_ + +#include +#include +#include +#include +#include "mindrecord/include/shard_operator.h" +#include "mindrecord/include/shard_shuffle.h" +#include "mindrecord/include/shard_category.h" + +namespace mindspore { +namespace mindrecord { +class ShardPkSample : public ShardCategory { + public: + ShardPkSample(const std::string &category_field, int64_t num_elements); + + ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories); + + ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories, uint32_t seed); + + ~ShardPkSample() override{}; + + MSRStatus suf_execute(ShardTask &tasks) override; + + private: + bool shuffle_; + std::shared_ptr shuffle_op_; +}; +} // namespace mindrecord +} // namespace mindspore + +#endif // MINDRECORD_INCLUDE_SHARD_PK_SAMPLE_H_ diff --git a/mindspore/ccsrc/mindrecord/include/shard_reader.h b/mindspore/ccsrc/mindrecord/include/shard_reader.h index 5548473cd7..3263b2006d 100644 --- a/mindspore/ccsrc/mindrecord/include/shard_reader.h +++ b/mindspore/ccsrc/mindrecord/include/shard_reader.h @@ -115,9 +115,10 @@ class ShardReader { /// \brief get the number of rows in database /// \param[in] file_path the path of ONE file, any file in dataset is fine + /// \param[in] op smart pointer refer to ShardCategory or ShardSample object /// \param[out] count # of rows /// \return MSRStatus the status of MSRStatus - MSRStatus CountTotalRows(const std::string &file_path, int64_t *count); + MSRStatus CountTotalRows(const std::string &file_path, const std::shared_ptr &op, int64_t *count); /// \brief shuffle task with incremental seed /// \return void @@ -197,6 +198,9 @@ class ShardReader { /// \brief get NLP flag bool get_nlp_flag(); + /// \brief get all classes + MSRStatus GetAllClasses(const std::string &category_field, std::set &categories); + protected: /// \brief sqlite call back function static int SelectCallback(void *p_data, int num_fields, char **p_fields, char **p_col_names); @@ -249,8 +253,8 @@ class ShardReader { const std::vector> &operators); /// \brief create category-applied task list - int CreateTasksByCategory(const std::vector> &row_group_summary, - const std::vector> &operators); + MSRStatus CreateTasksByCategory(const std::vector> &row_group_summary, + const std::shared_ptr &op); /// \brief create task list in row-reader mode MSRStatus CreateTasksByRow(const std::vector> &row_group_summary, @@ -284,6 +288,12 @@ class ShardReader { MSRStatus ReadBlob(const int &shard_id, const uint64_t &page_offset, const int &page_length, const int &buf_id); + /// \brief get classes in one shard + void GetClassesInShard(sqlite3 *db, int shard_id, const std::string sql, std::set &categories); + + /// \brief get number of classes + int64_t GetNumClasses(const std::string &file_path, const std::string &category_field); + protected: uint64_t header_size_; // header size uint64_t page_size_; // page size diff --git a/mindspore/ccsrc/mindrecord/include/shard_sample.h b/mindspore/ccsrc/mindrecord/include/shard_sample.h index 15353fd0ff..b16fc5cc4f 100644 --- a/mindspore/ccsrc/mindrecord/include/shard_sample.h +++ b/mindspore/ccsrc/mindrecord/include/shard_sample.h @@ -41,8 +41,11 @@ class ShardSample : public ShardOperator { const std::pair get_partitions() const; MSRStatus execute(ShardTask &tasks) override; + MSRStatus suf_execute(ShardTask &tasks) override; + int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; + private: int numerator_; int denominator_; diff --git a/mindspore/ccsrc/mindrecord/include/shard_shuffle.h b/mindspore/ccsrc/mindrecord/include/shard_shuffle.h index 464881aa7a..027a5ad527 100644 --- a/mindspore/ccsrc/mindrecord/include/shard_shuffle.h +++ b/mindspore/ccsrc/mindrecord/include/shard_shuffle.h @@ -24,7 +24,7 @@ namespace mindspore { namespace mindrecord { class ShardShuffle : public ShardOperator { public: - explicit ShardShuffle(uint32_t seed = 0); + explicit ShardShuffle(uint32_t seed = 0, ShuffleType shuffle_type = kShuffleCategory); ~ShardShuffle() override{}; @@ -32,6 +32,7 @@ class ShardShuffle : public ShardOperator { private: uint32_t shuffle_seed_; + ShuffleType shuffle_type_; }; } // namespace mindrecord } // namespace mindspore diff --git a/mindspore/ccsrc/mindrecord/include/shard_task.h b/mindspore/ccsrc/mindrecord/include/shard_task.h index 30ea352ef3..b276b5150f 100644 --- a/mindspore/ccsrc/mindrecord/include/shard_task.h +++ b/mindspore/ccsrc/mindrecord/include/shard_task.h @@ -41,7 +41,9 @@ class ShardTask { std::tuple, std::vector, json> &get_task_by_id(size_t id); - static ShardTask Combine(std::vector &category_tasks); + std::tuple, std::vector, json> &get_random_task(); + + static ShardTask Combine(std::vector &category_tasks, bool replacement, int64_t num_elements); uint32_t categories = 1; diff --git a/mindspore/ccsrc/mindrecord/io/shard_reader.cc b/mindspore/ccsrc/mindrecord/io/shard_reader.cc index fd3fede5a2..9cd02d9120 100644 --- a/mindspore/ccsrc/mindrecord/io/shard_reader.cc +++ b/mindspore/ccsrc/mindrecord/io/shard_reader.cc @@ -315,6 +315,43 @@ MSRStatus ShardReader::ReadAllRowsInShard(int shard_id, const std::string &sql, return ConvertLabelToJson(labels, fs, offsets, shard_id, columns, column_values); } +MSRStatus ShardReader::GetAllClasses(const std::string &category_field, std::set &categories) { + auto ret = ShardIndexGenerator::GenerateFieldName(std::make_pair(column_schema_id_[category_field], category_field)); + if (SUCCESS != ret.first) { + return FAILED; + } + std::string sql = "SELECT DISTINCT " + ret.second + " FROM INDEXES"; + std::vector threads = std::vector(shard_count_); + for (int x = 0; x < shard_count_; x++) { + threads[x] = std::thread(&ShardReader::GetClassesInShard, this, database_paths_[x], x, sql, std::ref(categories)); + } + + for (int x = 0; x < shard_count_; x++) { + threads[x].join(); + } + return SUCCESS; +} + +void ShardReader::GetClassesInShard(sqlite3 *db, int shard_id, const std::string sql, + std::set &categories) { + if (nullptr == db) { + return; + } + std::vector> columns; + char *errmsg = nullptr; + int ret = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &columns, &errmsg); + if (ret != SQLITE_OK) { + sqlite3_free(errmsg); + sqlite3_close(db); + MS_LOG(ERROR) << "Error in select sql statement, sql:" << common::SafeCStr(sql) << ", error: " << errmsg; + return; + } + MS_LOG(INFO) << "Get" << static_cast(columns.size()) << " records from shard " << shard_id << " index."; + for (int i = 0; i < static_cast(columns.size()); ++i) { + categories.emplace(columns[i][0]); + } +} + ROW_GROUPS ShardReader::ReadAllRowGroup(std::vector &columns) { std::string fields = "ROW_GROUP_ID, PAGE_OFFSET_BLOB, PAGE_OFFSET_BLOB_END"; std::vector>> offsets(shard_count_, std::vector>{}); @@ -667,11 +704,64 @@ MSRStatus ShardReader::Finish() { return SUCCESS; } -MSRStatus ShardReader::CountTotalRows(const std::string &file_path, int64_t *count) { +int64_t ShardReader::GetNumClasses(const std::string &file_path, const std::string &category_field) { + ShardHeader sh = ShardHeader(); + if (sh.Build(file_path) == FAILED) { + return -1; + } + auto header = std::make_shared(sh); + auto file_paths = header->get_shard_addresses(); + auto shard_count = file_paths.size(); + auto index_fields = header->get_fields(); + + std::map map_schema_id_fields; + for (auto &field : index_fields) { + map_schema_id_fields[field.second] = field.first; + } + auto ret = + ShardIndexGenerator::GenerateFieldName(std::make_pair(map_schema_id_fields[category_field], category_field)); + if (SUCCESS != ret.first) { + return -1; + } + std::string sql = "SELECT DISTINCT " + ret.second + " FROM INDEXES"; + std::vector threads = std::vector(shard_count); + std::set categories; + for (int x = 0; x < shard_count; x++) { + sqlite3 *db = nullptr; + int rc = sqlite3_open_v2(common::SafeCStr(file_paths[x] + ".db"), &db, SQLITE_OPEN_READONLY, nullptr); + if (SQLITE_OK != rc) { + MS_LOG(ERROR) << "Can't open database, error: " << sqlite3_errmsg(db); + return -1; + } + threads[x] = std::thread(&ShardReader::GetClassesInShard, this, db, x, sql, std::ref(categories)); + } + + for (int x = 0; x < shard_count; x++) { + threads[x].join(); + } + return categories.size(); +} + +MSRStatus ShardReader::CountTotalRows(const std::string &file_path, const std::shared_ptr &op, + int64_t *count) { if (Init(file_path) == FAILED) { return FAILED; } - *count = num_rows_; + int64_t num_samples = num_rows_; + if (std::dynamic_pointer_cast(op)) { + auto category_op = std::dynamic_pointer_cast(op); + std::string category_field = category_op->GetCategoryField(); + auto num_classes = GetNumClasses(file_path, category_field); + num_samples = category_op->GetNumSamples(num_rows_, num_classes); + } else if (std::dynamic_pointer_cast(op)) { + num_samples = op->GetNumSamples(num_rows_, 0); + } else { + } + if (-1 == num_samples) { + MS_LOG(ERROR) << "Failed to get dataset size."; + return FAILED; + } + *count = num_samples; return SUCCESS; } @@ -793,6 +883,8 @@ MSRStatus ShardReader::Launch(bool isSimpleReader) { thread_set_[x] = std::thread(&ShardReader::ConsumerByRow, this, x); } } + + MS_LOG(INFO) << "Launch read thread successfully."; return SUCCESS; } @@ -828,44 +920,67 @@ MSRStatus ShardReader::CreateTasksByBlock(const std::vector> &row_group_summary, - const std::vector> &operators) { +MSRStatus ShardReader::CreateTasksByCategory(const std::vector> &row_group_summary, + const std::shared_ptr &op) { vector columns = GetAllColumns(); CheckIfColumnInIndex(columns); - int category_operator = -1; - for (uint32_t i = 0; i < operators.size(); ++i) { - const auto &op = operators[i]; - if (std::dynamic_pointer_cast(op)) category_operator = static_cast(i); + auto category_op = std::dynamic_pointer_cast(op); + auto categories = category_op->get_categories(); + int64_t num_elements = category_op->GetNumElements(); + if (num_elements <= 0) { + MS_LOG(ERROR) << "Parameter num_element is not positive"; + return FAILED; + } + if (categories.empty() == true) { + std::string category_field = category_op->GetCategoryField(); + int64_t num_categories = category_op->GetNumCategories(); + if (num_categories <= 0) { + MS_LOG(ERROR) << "Parameter num_categories is not positive"; + return FAILED; + } + std::set categories_set; + auto ret = GetAllClasses(category_field, categories_set); + if (SUCCESS != ret) { + return FAILED; + } + int i = 0; + for (auto it = categories_set.begin(); it != categories_set.end() && i < num_categories; ++it) { + categories.emplace_back(category_field, *it); + i++; + } } - - if (category_operator == -1) return category_operator; - - auto categories = std::dynamic_pointer_cast(operators[category_operator])->get_categories(); - // Generate task list, a task will create a batch std::vector categoryTasks(categories.size()); for (uint32_t categoryNo = 0; categoryNo < categories.size(); ++categoryNo) { + int category_index = 0; for (const auto &rg : row_group_summary) { + if (category_index >= num_elements) break; auto shard_id = std::get<0>(rg); auto group_id = std::get<1>(rg); auto details = ReadRowGroupCriteria(group_id, shard_id, categories[categoryNo], columns); if (SUCCESS != std::get<0>(details)) { - return -2; + return FAILED; } auto offsets = std::get<4>(details); auto number_of_rows = offsets.size(); for (uint32_t iStart = 0; iStart < number_of_rows; iStart += 1) { - categoryTasks[categoryNo].InsertTask(shard_id, group_id, std::get<4>(details)[iStart], - std::get<5>(details)[iStart]); + if (category_index < num_elements) { + categoryTasks[categoryNo].InsertTask(shard_id, group_id, std::get<4>(details)[iStart], + std::get<5>(details)[iStart]); + category_index++; + } } } MS_LOG(INFO) << "Category #" << categoryNo << " has " << categoryTasks[categoryNo].Size() << " tasks"; } - tasks_ = ShardTask::Combine(categoryTasks); - return category_operator; + tasks_ = ShardTask::Combine(categoryTasks, category_op->GetReplacement(), num_elements); + if (SUCCESS != (*category_op)(tasks_)) { + return FAILED; + } + return SUCCESS; } MSRStatus ShardReader::CreateTasksByRow(const std::vector> &row_group_summary, @@ -896,14 +1011,26 @@ MSRStatus ShardReader::CreateTasksByRow(const std::vector> &row_group_summary, const std::vector> &operators) { if (block_reader_) { - CreateTasksByBlock(row_group_summary, operators); + if (SUCCESS != CreateTasksByBlock(row_group_summary, operators)) { + return FAILED; + } } else { - int category_operator = CreateTasksByCategory(row_group_summary, operators); - if (category_operator == -1) { - CreateTasksByRow(row_group_summary, operators); + int category_operator = -1; + for (uint32_t i = 0; i < operators.size(); ++i) { + const auto &op = operators[i]; + if (std::dynamic_pointer_cast(op)) { + category_operator = static_cast(i); + break; + } } - if (category_operator == -2) { - return FAILED; + if (-1 == category_operator) { + if (SUCCESS != CreateTasksByRow(row_group_summary, operators)) { + return FAILED; + } + } else { + if (SUCCESS != CreateTasksByCategory(row_group_summary, operators[category_operator])) { + return FAILED; + } } } diff --git a/mindspore/ccsrc/mindrecord/meta/shard_category.cc b/mindspore/ccsrc/mindrecord/meta/shard_category.cc index 859a3b343f..80816e7a79 100644 --- a/mindspore/ccsrc/mindrecord/meta/shard_category.cc +++ b/mindspore/ccsrc/mindrecord/meta/shard_category.cc @@ -18,11 +18,30 @@ namespace mindspore { namespace mindrecord { -ShardCategory::ShardCategory(const std::vector> &categories) - : categories_(categories) {} +ShardCategory::ShardCategory(const std::vector> &categories, int64_t num_elements, + bool replacement) + : categories_(categories), + category_field_(""), + num_elements_(num_elements), + num_categories_(0), + replacement_(replacement) {} -const std::vector> &ShardCategory::get_categories() const { return categories_; } +ShardCategory::ShardCategory(const std::string &category_field, int64_t num_elements, int64_t num_categories, + bool replacement) + : categories_({}), + category_field_(category_field), + num_elements_(num_elements), + num_categories_(num_categories), + replacement_(replacement) {} MSRStatus ShardCategory::execute(ShardTask &tasks) { return SUCCESS; } + +int64_t ShardCategory::GetNumSamples(int64_t dataset_size, int64_t num_classes) { + if (dataset_size == 0) return dataset_size; + if (dataset_size > 0 && num_categories_ > 0 && num_elements_ > 0) { + return std::min(num_categories_, num_classes) * num_elements_; + } + return -1; +} } // namespace mindrecord } // namespace mindspore diff --git a/mindspore/ccsrc/mindrecord/meta/shard_pk_sample.cc b/mindspore/ccsrc/mindrecord/meta/shard_pk_sample.cc new file mode 100644 index 0000000000..8e2e892e63 --- /dev/null +++ b/mindspore/ccsrc/mindrecord/meta/shard_pk_sample.cc @@ -0,0 +1,46 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "mindrecord/include/shard_pk_sample.h" + +using mindspore::LogStream; +using mindspore::ExceptionType::NoExceptionType; +using mindspore::MsLogLevel::ERROR; + +namespace mindspore { +namespace mindrecord { +ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elements) + : ShardCategory(category_field, num_elements, std::numeric_limits::max(), true), shuffle_(false) {} + +ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories) + : ShardCategory(category_field, num_elements, num_categories, true), shuffle_(false) {} + +ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories, + uint32_t seed) + : ShardCategory(category_field, num_elements, num_categories, true), shuffle_(true) { + shuffle_op_ = std::make_shared(seed, kShuffleSample); // do shuffle and replacement +} + +MSRStatus ShardPkSample::suf_execute(ShardTask &tasks) { + if (shuffle_ == true) { + if (SUCCESS != (*shuffle_op_)(tasks)) { + return FAILED; + } + } + return SUCCESS; +} +} // namespace mindrecord +} // namespace mindspore diff --git a/mindspore/ccsrc/mindrecord/meta/shard_sample.cc b/mindspore/ccsrc/mindrecord/meta/shard_sample.cc index ef627b0c09..a9cfce0d01 100644 --- a/mindspore/ccsrc/mindrecord/meta/shard_sample.cc +++ b/mindspore/ccsrc/mindrecord/meta/shard_sample.cc @@ -56,6 +56,24 @@ ShardSample::ShardSample(const std::vector &indices, uint32_t seed) shuffle_op_ = std::make_shared(seed); } +int64_t ShardSample::GetNumSamples(int64_t dataset_size, int64_t num_classes) { + if (sampler_type_ == kCustomTopNSampler) { + return no_of_samples_; + } + + if (sampler_type_ == kCustomTopPercentSampler) { + if (dataset_size % denominator_ == 0) { + return dataset_size / denominator_ * numerator_; + } else { + return dataset_size / denominator_ * numerator_ + 1; + } + } + if (sampler_type_ == kSubsetRandomSampler) { + return indices_.size(); + } + return -1; +} + const std::pair ShardSample::get_partitions() const { if (numerator_ == 1 && denominator_ > 1) { return std::pair(denominator_, partition_id_); diff --git a/mindspore/ccsrc/mindrecord/meta/shard_shuffle.cc b/mindspore/ccsrc/mindrecord/meta/shard_shuffle.cc index f8ad2c341d..757dcb7b74 100644 --- a/mindspore/ccsrc/mindrecord/meta/shard_shuffle.cc +++ b/mindspore/ccsrc/mindrecord/meta/shard_shuffle.cc @@ -20,25 +20,33 @@ namespace mindspore { namespace mindrecord { -ShardShuffle::ShardShuffle(uint32_t seed) : shuffle_seed_(seed) {} +ShardShuffle::ShardShuffle(uint32_t seed, ShuffleType shuffle_type) + : shuffle_seed_(seed), shuffle_type_(shuffle_type) {} MSRStatus ShardShuffle::execute(ShardTask &tasks) { if (tasks.categories < 1) { return FAILED; } - uint32_t individual_size = tasks.Size() / tasks.categories; - std::vector> new_permutations(tasks.categories, std::vector(individual_size)); - for (uint32_t i = 0; i < tasks.categories; i++) { - for (uint32_t j = 0; j < individual_size; j++) new_permutations[i][j] = static_cast(j); - std::shuffle(new_permutations[i].begin(), new_permutations[i].end(), std::default_random_engine(shuffle_seed_)); - } - shuffle_seed_++; - tasks.permutation_.clear(); - for (uint32_t j = 0; j < individual_size; j++) { + if (shuffle_type_ == kShuffleSample) { + if (tasks.permutation_.empty() == true) { + tasks.MakePerm(); + } + std::shuffle(tasks.permutation_.begin(), tasks.permutation_.end(), std::default_random_engine(shuffle_seed_)); + } else { // shuffle unit like: (a1, b1, c1),(a2, b2, c2),..., (an, bn, cn) + uint32_t individual_size = tasks.Size() / tasks.categories; + std::vector> new_permutations(tasks.categories, std::vector(individual_size)); for (uint32_t i = 0; i < tasks.categories; i++) { - tasks.permutation_.push_back(new_permutations[i][j] * static_cast(tasks.categories) + static_cast(i)); + for (uint32_t j = 0; j < individual_size; j++) new_permutations[i][j] = static_cast(j); + std::shuffle(new_permutations[i].begin(), new_permutations[i].end(), std::default_random_engine(shuffle_seed_)); + } + tasks.permutation_.clear(); + for (uint32_t j = 0; j < individual_size; j++) { + for (uint32_t i = 0; i < tasks.categories; i++) { + tasks.permutation_.push_back(new_permutations[i][j] * static_cast(tasks.categories) + static_cast(i)); + } } } + shuffle_seed_++; return SUCCESS; } } // namespace mindrecord diff --git a/mindspore/ccsrc/mindrecord/meta/shard_task.cc b/mindspore/ccsrc/mindrecord/meta/shard_task.cc index 3744d881a4..be566d1601 100644 --- a/mindspore/ccsrc/mindrecord/meta/shard_task.cc +++ b/mindspore/ccsrc/mindrecord/meta/shard_task.cc @@ -35,8 +35,6 @@ void ShardTask::InsertTask(int shard_id, int group_id, const std::vector, std::vector, json> task) { @@ -44,9 +42,6 @@ void ShardTask::InsertTask(std::tuple, std::vector(std::get<0>(task)) << ", label: " << std::get<2>(task).dump() << ", size of task_list_: " << task_list_.size() << "."; task_list_.push_back(std::move(task)); - MS_LOG(DEBUG) << "Out of insert task, shard_id: " << std::get<0>(std::get<0>(task)) - << ", group_id: " << std::get<1>(std::get<0>(task)) << ", label: " << std::get<2>(task).dump() - << ", size of task_list_: " << task_list_.size() << "."; } void ShardTask::PopBack() { task_list_.pop_back(); } @@ -69,18 +64,39 @@ std::tuple, std::vector, json> &ShardTask::get_ta return task_list_[id]; } -ShardTask ShardTask::Combine(std::vector &category_tasks) { +std::tuple, std::vector, json> &ShardTask::get_random_task() { + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution<> dis(0, task_list_.size() - 1); + return task_list_[dis(gen)]; +} +ShardTask ShardTask::Combine(std::vector &category_tasks, bool replacement, int64_t num_elements) { ShardTask res; if (category_tasks.empty()) return res; auto total_categories = category_tasks.size(); res.categories = static_cast(total_categories); - auto minTasks = category_tasks[0].Size(); - for (uint32_t i = 1; i < total_categories; i++) { - minTasks = std::min(minTasks, category_tasks[i].Size()); - } - for (uint32_t task_no = 0; task_no < minTasks; task_no++) { + if (replacement == false) { + auto minTasks = category_tasks[0].Size(); + for (uint32_t i = 1; i < total_categories; i++) { + minTasks = std::min(minTasks, category_tasks[i].Size()); + } + for (uint32_t task_no = 0; task_no < minTasks; task_no++) { + for (uint32_t i = 0; i < total_categories; i++) { + res.InsertTask(std::move(category_tasks[i].get_task_by_id(static_cast(task_no)))); + } + } + } else { + auto maxTasks = category_tasks[0].Size(); + for (uint32_t i = 1; i < total_categories; i++) { + maxTasks = std::max(maxTasks, category_tasks[i].Size()); + } + if (num_elements != std::numeric_limits::max()) { + maxTasks = static_cast(num_elements); + } for (uint32_t i = 0; i < total_categories; i++) { - res.InsertTask(std::move(category_tasks[i].get_task_by_id(static_cast(task_no)))); + for (uint32_t j = 0; j < maxTasks; j++) { + res.InsertTask(category_tasks[i].get_random_task()); + } } } return res; diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 5b3c0f1503..28697a6c43 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -1882,7 +1882,8 @@ class MindDataset(SourceDataset): block_reader (bool, optional): Whether read data by block mode (default=False). sampler (Sampler, optional): Object used to choose samples from the dataset (default=None, sampler is exclusive - with shuffle and block_reader). Support list: SubsetRandomSampler. + with shuffle and block_reader). Support list: SubsetRandomSampler, + PkSampler Raises: ValueError: If num_shards is specified but shard_id is None. @@ -1915,8 +1916,10 @@ class MindDataset(SourceDataset): if block_reader is True: logger.warning("WARN: global shuffle is not used.") - if sampler is not None and isinstance(sampler, samplers.SubsetRandomSampler) is False: - raise ValueError("the sampler is not supported yet.") + if sampler is not None: + if isinstance(sampler, samplers.SubsetRandomSampler) is False and \ + isinstance(sampler, samplers.PKSampler) is False: + raise ValueError("the sampler is not supported yet.") # sampler exclusive if block_reader is True and sampler is not None: @@ -1952,7 +1955,7 @@ class MindDataset(SourceDataset): Number, number of batches. """ - num_rows = MindRecordOp.get_num_rows(self.dataset_file) + num_rows = MindRecordOp.get_num_rows(self.dataset_file, self.sampler) if self.partitions is not None and self.partitions[0] > 0: if num_rows % self.partitions[0] == 0: num_rows = num_rows // self.partitions[0] diff --git a/mindspore/dataset/engine/samplers.py b/mindspore/dataset/engine/samplers.py index 421a03ab8d..82759989cb 100644 --- a/mindspore/dataset/engine/samplers.py +++ b/mindspore/dataset/engine/samplers.py @@ -184,6 +184,8 @@ class PKSampler(BuiltinSampler): def create(self): return cde.PKSampler(self.num_val, self.shuffle) + def _create_for_minddataset(self): + return cde.MindrecordPkSampler(self.num_val, self.shuffle) class RandomSampler(BuiltinSampler): """ diff --git a/tests/ut/cpp/mindrecord/ut_shard_operator_test.cc b/tests/ut/cpp/mindrecord/ut_shard_operator_test.cc index 549e2140f4..bfd49069b2 100644 --- a/tests/ut/cpp/mindrecord/ut_shard_operator_test.cc +++ b/tests/ut/cpp/mindrecord/ut_shard_operator_test.cc @@ -25,6 +25,7 @@ #include "gtest/gtest.h" #include "utils/log_adapter.h" #include "mindrecord/include/shard_category.h" +#include "mindrecord/include/shard_pk_sample.h" #include "mindrecord/include/shard_reader.h" #include "mindrecord/include/shard_sample.h" #include "mindrecord/include/shard_shuffle.h" @@ -146,6 +147,57 @@ TEST_F(TestShardOperator, TestShardSamplePartition) { ASSERT_TRUE(i <= 10); } +TEST_F(TestShardOperator, TestShardPkSamplerBasic) { + MS_LOG(INFO) << common::SafeCStr(FormatInfo("Test pk sampler")); + + std::string file_name = "./imagenet.shard01"; + auto column_list = std::vector{"file_name", "label"}; + + std::vector> ops; + ops.push_back(std::make_shared("label", 2)); + + ShardReader dataset; + dataset.Open(file_name, 4, column_list, ops); + dataset.Launch(); + + int i = 0; + while (true) { + auto x = dataset.GetNext(); + if (x.empty()) break; + std::cout << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) + << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()) << std::endl; + i++; + } + dataset.Finish(); + ASSERT_TRUE(i == 20); +} // namespace mindrecord + +TEST_F(TestShardOperator, TestShardPkSamplerNumClass) { + MS_LOG(INFO) << common::SafeCStr(FormatInfo("Test pk sampler")); + + std::string file_name = "./imagenet.shard01"; + auto column_list = std::vector{"file_name", "label"}; + + std::vector> ops; + ops.push_back(std::make_shared("label", 2, 3, 0)); + + ShardReader dataset; + dataset.Open(file_name, 4, column_list, ops); + dataset.Launch(); + + int i = 0; + while (true) { + auto x = dataset.GetNext(); + if (x.empty()) break; + + std::cout << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) + << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()) << std::endl; + i++; + } + dataset.Finish(); + ASSERT_TRUE(i == 6); +} // namespace mindrecord + TEST_F(TestShardOperator, TestShardCategory) { MS_LOG(INFO) << common::SafeCStr(FormatInfo("Test read imageNet")); diff --git a/tests/ut/data/mindrecord/testImageNetData/annotation_sampler.txt b/tests/ut/data/mindrecord/testImageNetData/annotation_sampler.txt new file mode 100644 index 0000000000..fbfbba025f --- /dev/null +++ b/tests/ut/data/mindrecord/testImageNetData/annotation_sampler.txt @@ -0,0 +1,10 @@ +image_00001.jpg,164 +image_00002.jpg,164 +image_00003.jpg,164 +image_00004.jpg,599 +image_00005.jpg,599 +image_00006.jpg,599 +image_00007.jpg,13 +image_00008.jpg,13 +image_00009.jpg,13 +image_00010.jpg,13 diff --git a/tests/ut/python/dataset/test_minddataset_sampler.py b/tests/ut/python/dataset/test_minddataset_sampler.py index 3cad3877ef..584bb88041 100644 --- a/tests/ut/python/dataset/test_minddataset_sampler.py +++ b/tests/ut/python/dataset/test_minddataset_sampler.py @@ -46,7 +46,7 @@ def add_and_remove_cv_file(): if os.path.exists("{}.db".format(x)): os.remove("{}.db".format(x)) writer = FileWriter(CV_FILE_NAME, FILES_NUM) - data = get_data(CV_DIR_NAME) + data = get_data(CV_DIR_NAME, True) cv_schema_json = {"id": {"type": "int32"}, "file_name": {"type": "string"}, "label": {"type": "int32"}, @@ -61,6 +61,59 @@ def add_and_remove_cv_file(): os.remove("{}.db".format(x)) +def test_cv_minddataset_pk_sample_basic(add_and_remove_cv_file): + """tutorial for cv minderdataset.""" + columns_list = ["data", "file_name", "label"] + num_readers = 4 + sampler = ds.PKSampler(2) + data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, + sampler=sampler) + + assert data_set.get_dataset_size() == 6 + num_iter = 0 + for item in data_set.create_dict_iterator(): + logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter)) + logger.info("-------------- item[file_name]: \ + {}------------------------".format("".join([chr(x) for x in item["file_name"]]))) + logger.info("-------------- item[label]: {} ----------------------------".format(item["label"])) + num_iter += 1 + + +def test_cv_minddataset_pk_sample_shuffle(add_and_remove_cv_file): + """tutorial for cv minderdataset.""" + columns_list = ["data", "file_name", "label"] + num_readers = 4 + sampler = ds.PKSampler(3, None, True) + data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, + sampler=sampler) + + assert data_set.get_dataset_size() == 9 + num_iter = 0 + for item in data_set.create_dict_iterator(): + logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter)) + logger.info("-------------- item[file_name]: \ + {}------------------------".format("".join([chr(x) for x in item["file_name"]]))) + logger.info("-------------- item[label]: {} ----------------------------".format(item["label"])) + num_iter += 1 + + +def test_cv_minddataset_pk_sample_out_of_range(add_and_remove_cv_file): + """tutorial for cv minderdataset.""" + columns_list = ["data", "file_name", "label"] + num_readers = 4 + sampler = ds.PKSampler(5, None, True) + data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, + sampler=sampler) + assert data_set.get_dataset_size() == 15 + num_iter = 0 + for item in data_set.create_dict_iterator(): + logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter)) + logger.info("-------------- item[file_name]: \ + {}------------------------".format("".join([chr(x) for x in item["file_name"]]))) + logger.info("-------------- item[label]: {} ----------------------------".format(item["label"])) + num_iter += 1 + + def test_cv_minddataset_subset_random_sample_basic(add_and_remove_cv_file): """tutorial for cv minderdataset.""" columns_list = ["data", "file_name", "label"] @@ -69,8 +122,7 @@ def test_cv_minddataset_subset_random_sample_basic(add_and_remove_cv_file): sampler = ds.SubsetRandomSampler(indices) data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, sampler=sampler) - data = get_data(CV_DIR_NAME) - assert data_set.get_dataset_size() == 10 + assert data_set.get_dataset_size() == 5 num_iter = 0 for item in data_set.create_dict_iterator(): logger.info( @@ -93,8 +145,7 @@ def test_cv_minddataset_subset_random_sample_replica(add_and_remove_cv_file): sampler = ds.SubsetRandomSampler(indices) data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, sampler=sampler) - data = get_data(CV_DIR_NAME) - assert data_set.get_dataset_size() == 10 + assert data_set.get_dataset_size() == 6 num_iter = 0 for item in data_set.create_dict_iterator(): logger.info( @@ -117,8 +168,7 @@ def test_cv_minddataset_subset_random_sample_empty(add_and_remove_cv_file): sampler = ds.SubsetRandomSampler(indices) data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, sampler=sampler) - data = get_data(CV_DIR_NAME) - assert data_set.get_dataset_size() == 10 + assert data_set.get_dataset_size() == 0 num_iter = 0 for item in data_set.create_dict_iterator(): logger.info( @@ -133,7 +183,7 @@ def test_cv_minddataset_subset_random_sample_empty(add_and_remove_cv_file): assert num_iter == 0 -def test_cv_minddataset_subset_random_sample_out_range(add_and_remove_cv_file): +def test_cv_minddataset_subset_random_sample_out_of_range(add_and_remove_cv_file): """tutorial for cv minderdataset.""" columns_list = ["data", "file_name", "label"] num_readers = 4 @@ -141,8 +191,7 @@ def test_cv_minddataset_subset_random_sample_out_range(add_and_remove_cv_file): sampler = ds.SubsetRandomSampler(indices) data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, sampler=sampler) - data = get_data(CV_DIR_NAME) - assert data_set.get_dataset_size() == 10 + assert data_set.get_dataset_size() == 5 num_iter = 0 for item in data_set.create_dict_iterator(): logger.info( @@ -165,8 +214,7 @@ def test_cv_minddataset_subset_random_sample_negative(add_and_remove_cv_file): sampler = ds.SubsetRandomSampler(indices) data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, sampler=sampler) - data = get_data(CV_DIR_NAME) - assert data_set.get_dataset_size() == 10 + assert data_set.get_dataset_size() == 5 num_iter = 0 for item in data_set.create_dict_iterator(): logger.info( @@ -181,7 +229,7 @@ def test_cv_minddataset_subset_random_sample_negative(add_and_remove_cv_file): assert num_iter == 5 -def get_data(dir_name): +def get_data(dir_name, sampler=False): """ usage: get data from imagenet dataset params: @@ -191,7 +239,10 @@ def get_data(dir_name): if not os.path.isdir(dir_name): raise IOError("Directory {} not exists".format(dir_name)) img_dir = os.path.join(dir_name, "images") - ann_file = os.path.join(dir_name, "annotation.txt") + if sampler: + ann_file = os.path.join(dir_name, "annotation_sampler.txt") + else: + ann_file = os.path.join(dir_name, "annotation.txt") with open(ann_file, "r") as file_reader: lines = file_reader.readlines() diff --git a/tests/ut/python/dataset/test_serdes_dataset.py b/tests/ut/python/dataset/test_serdes_dataset.py index 7fdb0f1dde..0a6f86974b 100644 --- a/tests/ut/python/dataset/test_serdes_dataset.py +++ b/tests/ut/python/dataset/test_serdes_dataset.py @@ -243,7 +243,7 @@ def test_minddataset(add_and_remove_cv_file): assert ds1_json == ds2_json data = get_data(CV_DIR_NAME) - assert data_set.get_dataset_size() == 10 + assert data_set.get_dataset_size() == 5 num_iter = 0 for item in data_set.create_dict_iterator(): num_iter += 1 From 5d467874182732e1694176c8da1505d5c58c53d5 Mon Sep 17 00:00:00 2001 From: leonwanghui Date: Tue, 21 Apr 2020 10:18:24 +0800 Subject: [PATCH 029/142] Fix the video conferencing link error in README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index e465f8e3e1..3de87d3fec 100644 --- a/README.md +++ b/README.md @@ -179,7 +179,7 @@ Check out how MindSpore Open Governance [works](https://gitee.com/mindspore/comm - [MindSpore Slack](https://join.slack.com/t/mindspore/shared_invite/enQtOTcwMTIxMDI3NjM0LTNkMWM2MzI5NjIyZWU5ZWQ5M2EwMTQ5MWNiYzMxOGM4OWFhZjI4M2E5OGI2YTg3ODU1ODE2Njg1MThiNWI3YmQ) - Communication platform for developers. - IRC channel at `#mindspore` (only for meeting minutes logging purpose) -- Video Conferencing: meet.jit.si +- Video Conferencing: https://meet.jit.si - Mailing-list: https://mailweb.mindspore.cn/postorius/lists ## Contributing From 24b26ee1a8a096c950f03cb1534ec1378f423e0c Mon Sep 17 00:00:00 2001 From: leonwanghui Date: Tue, 21 Apr 2020 10:20:09 +0800 Subject: [PATCH 030/142] Move args_type_check function to _checkparam.py --- mindspore/_checkparam.py | 53 ++++++++++++++----- mindspore/_extends/__init__.py | 2 +- mindspore/_extends/pynative_helper.py | 44 --------------- mindspore/context.py | 17 +++--- mindspore/parallel/_auto_parallel_context.py | 2 +- mindspore/parallel/_cost_model_context.py | 2 +- mindspore/parallel/algo_parameter_config.py | 2 +- tests/ut/python/pynative_mode/test_backend.py | 10 ++-- 8 files changed, 60 insertions(+), 72 deletions(-) delete mode 100644 mindspore/_extends/pynative_helper.py diff --git a/mindspore/_checkparam.py b/mindspore/_checkparam.py index e9a928461f..7b8c89351c 100644 --- a/mindspore/_checkparam.py +++ b/mindspore/_checkparam.py @@ -14,8 +14,9 @@ # ============================================================================ """Check parameters.""" import re +import inspect from enum import Enum -from functools import reduce +from functools import reduce, wraps from itertools import repeat from collections.abc import Iterable @@ -181,7 +182,7 @@ class Validator: @staticmethod def check_subclass(arg_name, type_, template_type, prim_name): - """Checks whether some type is sublcass of another type""" + """Checks whether some type is subclass of another type""" if not isinstance(template_type, Iterable): template_type = (template_type,) if not any([mstype.issubclass_(type_, x) for x in template_type]): @@ -240,7 +241,6 @@ class Validator: elem_types = map(_check_tensor_type, args.items()) reduce(_check_types_same, elem_types) - @staticmethod def check_scalar_or_tensor_type_same(args, valid_values, prim_name, allow_mix=False): """ @@ -261,7 +261,7 @@ class Validator: def _check_types_same(arg1, arg2): arg1_name, arg1_type = arg1 arg2_name, arg2_type = arg2 - excp_flag = False + except_flag = False if isinstance(arg1_type, type(mstype.tensor)) and isinstance(arg2_type, type(mstype.tensor)): arg1_type = arg1_type.element_type() arg2_type = arg2_type.element_type() @@ -271,9 +271,9 @@ class Validator: arg1_type = arg1_type.element_type() if isinstance(arg1_type, type(mstype.tensor)) else arg1_type arg2_type = arg2_type.element_type() if isinstance(arg2_type, type(mstype.tensor)) else arg2_type else: - excp_flag = True + except_flag = True - if excp_flag or arg1_type != arg2_type: + if except_flag or arg1_type != arg2_type: raise TypeError(f'For \'{prim_name}\' type of `{arg2_name}` should be same as `{arg1_name}`,' f' but `{arg1_name}` is {arg1_type} and `{arg2_name}` is {arg2_type}.') return arg1 @@ -283,11 +283,12 @@ class Validator: def check_value_type(arg_name, arg_value, valid_types, prim_name): """Checks whether a value is instance of some types.""" valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,) + def raise_error_msg(): """func for raising error message when check failed""" type_names = [t.__name__ for t in valid_types] num_types = len(valid_types) - msg_prefix = f'For \'{prim_name}\' the' if prim_name else 'The' + msg_prefix = f'For \'{prim_name}\' the' if prim_name else 'The' raise TypeError(f'{msg_prefix} type of `{arg_name}` should be {"one of " if num_types > 1 else ""}' f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.') @@ -303,6 +304,7 @@ class Validator: def check_type_name(arg_name, arg_type, valid_types, prim_name): """Checks whether a type in some specified types""" valid_types = valid_types if isinstance(valid_types, Iterable) else (valid_types,) + def get_typename(t): return t.__name__ if hasattr(t, '__name__') else str(t) @@ -368,9 +370,9 @@ class ParamValidator: @staticmethod def check_isinstance(arg_name, arg_value, classes): - """Check arg isintance of classes""" + """Check arg isinstance of classes""" if not isinstance(arg_value, classes): - raise ValueError(f'The `{arg_name}` should be isintance of {classes}, but got {arg_value}.') + raise ValueError(f'The `{arg_name}` should be isinstance of {classes}, but got {arg_value}.') return arg_value @staticmethod @@ -384,7 +386,7 @@ class ParamValidator: @staticmethod def check_subclass(arg_name, type_, template_type, with_type_of=True): - """Check whether some type is sublcass of another type""" + """Check whether some type is subclass of another type""" if not isinstance(template_type, Iterable): template_type = (template_type,) if not any([mstype.issubclass_(type_, x) for x in template_type]): @@ -402,9 +404,9 @@ class ParamValidator: @staticmethod def check_bool(arg_name, arg_value): - """Check arg isintance of bool""" + """Check arg isinstance of bool""" if not isinstance(arg_value, bool): - raise ValueError(f'The `{arg_name}` should be isintance of bool, but got {arg_value}.') + raise ValueError(f'The `{arg_name}` should be isinstance of bool, but got {arg_value}.') return arg_value @staticmethod @@ -771,3 +773,30 @@ def _check_str_by_regular(target, reg=None, flag=re.ASCII): if re.match(reg, target, flag) is None: raise ValueError("'{}' is illegal, it should be match regular'{}' by flags'{}'".format(target, reg, flag)) return True + + +def args_type_check(*type_args, **type_kwargs): + """Check whether input data type is correct.""" + + def type_check(func): + sig = inspect.signature(func) + bound_types = sig.bind_partial(*type_args, **type_kwargs).arguments + + @wraps(func) + def wrapper(*args, **kwargs): + nonlocal bound_types + bound_values = sig.bind(*args, **kwargs) + argument_dict = bound_values.arguments + if "kwargs" in bound_types: + bound_types = bound_types["kwargs"] + if "kwargs" in argument_dict: + argument_dict = argument_dict["kwargs"] + for name, value in argument_dict.items(): + if name in bound_types: + if value is not None and not isinstance(value, bound_types[name]): + raise TypeError('Argument {} must be {}'.format(name, bound_types[name])) + return func(*args, **kwargs) + + return wrapper + + return type_check diff --git a/mindspore/_extends/__init__.py b/mindspore/_extends/__init__.py index 5eabfcd97c..91e1192e7e 100644 --- a/mindspore/_extends/__init__.py +++ b/mindspore/_extends/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================ """ -Extension functions. +Extension functions. Python functions that will be called in the c++ parts of MindSpore. """ diff --git a/mindspore/_extends/pynative_helper.py b/mindspore/_extends/pynative_helper.py deleted file mode 100644 index 0b93ab926b..0000000000 --- a/mindspore/_extends/pynative_helper.py +++ /dev/null @@ -1,44 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Pynative mode help module.""" -from inspect import signature -from functools import wraps - - -def args_type_check(*type_args, **type_kwargs): - """Check whether input data type is correct.""" - - def type_check(func): - sig = signature(func) - bound_types = sig.bind_partial(*type_args, **type_kwargs).arguments - - @wraps(func) - def wrapper(*args, **kwargs): - nonlocal bound_types - bound_values = sig.bind(*args, **kwargs) - argument_dict = bound_values.arguments - if "kwargs" in bound_types: - bound_types = bound_types["kwargs"] - if "kwargs" in argument_dict: - argument_dict = argument_dict["kwargs"] - for name, value in argument_dict.items(): - if name in bound_types: - if value is not None and not isinstance(value, bound_types[name]): - raise TypeError('Argument {} must be {}'.format(name, bound_types[name])) - return func(*args, **kwargs) - - return wrapper - - return type_check diff --git a/mindspore/context.py b/mindspore/context.py index 311ca745fc..f6fe8705fd 100644 --- a/mindspore/context.py +++ b/mindspore/context.py @@ -14,7 +14,7 @@ # ============================================================================ """ The context of mindspore, used to configure the current execution environment, -including execution mode, execution backend and other feature switchs. +including execution mode, execution backend and other feature switches. """ import os import threading @@ -22,7 +22,7 @@ from collections import namedtuple from types import FunctionType from mindspore import log as logger from mindspore._c_expression import MSContext -from mindspore._extends.pynative_helper import args_type_check +from mindspore._checkparam import args_type_check from mindspore.parallel._auto_parallel_context import _set_auto_parallel_context, _get_auto_parallel_context, \ _reset_auto_parallel_context @@ -38,7 +38,7 @@ def _make_directory(path: str): """Make directory.""" real_path = None if path is None or not isinstance(path, str) or path.strip() == "": - raise ValueError(f"Input path `{path}` is invaild type") + raise ValueError(f"Input path `{path}` is invalid type") # convert the relative paths path = os.path.realpath(path) @@ -63,6 +63,7 @@ class _ThreadLocalInfo(threading.local): """ Thread local Info used for store thread local attributes. """ + def __init__(self): super(_ThreadLocalInfo, self).__init__() self._reserve_class_name_in_scope = True @@ -90,6 +91,7 @@ class _ContextSwitchInfo(threading.local): Args: is_pynative (bool): Whether to adopt the PyNative mode. """ + def __init__(self, is_pynative): super(_ContextSwitchInfo, self).__init__() self.context_stack = [] @@ -209,7 +211,7 @@ class _Context: def device_target(self, target): success = self._context_handle.set_device_target(target) if not success: - raise ValueError("target device name is invalid!!!") + raise ValueError("Target device name is invalid!!!") @property def device_id(self): @@ -335,7 +337,7 @@ class _Context: @graph_memory_max_size.setter def graph_memory_max_size(self, graph_memory_max_size): - if check_input_fotmat(graph_memory_max_size): + if check_input_format(graph_memory_max_size): graph_memory_max_size_ = graph_memory_max_size[:-2] + " * 1024 * 1024 * 1024" self._context_handle.set_graph_memory_max_size(graph_memory_max_size_) else: @@ -347,7 +349,7 @@ class _Context: @variable_memory_max_size.setter def variable_memory_max_size(self, variable_memory_max_size): - if check_input_fotmat(variable_memory_max_size): + if check_input_format(variable_memory_max_size): variable_memory_max_size_ = variable_memory_max_size[:-2] + " * 1024 * 1024 * 1024" self._context_handle.set_variable_memory_max_size(variable_memory_max_size_) else: @@ -367,12 +369,13 @@ class _Context: thread_info.debug_runtime = enable -def check_input_fotmat(x): +def check_input_format(x): import re pattern = r'[1-9][0-9]*(\.)?[0-9]*GB|0\.[0-9]*GB' result = re.match(pattern, x) return result is not None + _k_context = None diff --git a/mindspore/parallel/_auto_parallel_context.py b/mindspore/parallel/_auto_parallel_context.py index c99ac4a3c7..bf4b99085e 100644 --- a/mindspore/parallel/_auto_parallel_context.py +++ b/mindspore/parallel/_auto_parallel_context.py @@ -17,7 +17,7 @@ import threading import mindspore.context as context from mindspore.parallel._dp_allreduce_fusion import _set_fusion_strategy_by_idx, _set_fusion_strategy_by_size from mindspore._c_expression import AutoParallelContext -from mindspore._extends.pynative_helper import args_type_check +from mindspore._checkparam import args_type_check class _AutoParallelContext: diff --git a/mindspore/parallel/_cost_model_context.py b/mindspore/parallel/_cost_model_context.py index 0920d66f41..54cca5516b 100644 --- a/mindspore/parallel/_cost_model_context.py +++ b/mindspore/parallel/_cost_model_context.py @@ -15,7 +15,7 @@ """Context of cost_model in auto_parallel""" import threading from mindspore._c_expression import CostModelContext -from mindspore._extends.pynative_helper import args_type_check +from mindspore._checkparam import args_type_check class _CostModelContext: diff --git a/mindspore/parallel/algo_parameter_config.py b/mindspore/parallel/algo_parameter_config.py index d1e4aa87a9..244156da33 100644 --- a/mindspore/parallel/algo_parameter_config.py +++ b/mindspore/parallel/algo_parameter_config.py @@ -16,7 +16,7 @@ import threading from mindspore._c_expression import CostModelContext -from mindspore._extends.pynative_helper import args_type_check +from mindspore._checkparam import args_type_check __all__ = ["get_algo_parameters", "reset_algo_parameters", "set_algo_parameters"] diff --git a/tests/ut/python/pynative_mode/test_backend.py b/tests/ut/python/pynative_mode/test_backend.py index 7258b69486..fae1974854 100644 --- a/tests/ut/python/pynative_mode/test_backend.py +++ b/tests/ut/python/pynative_mode/test_backend.py @@ -14,16 +14,13 @@ # ============================================================================ """ test_backend """ import os -import numpy as np import pytest from mindspore.ops import operations as P import mindspore.nn as nn -from mindspore import context +from mindspore import context, ms_function from mindspore.common.initializer import initializer from mindspore.common.parameter import Parameter -from mindspore._extends.pynative_helper import args_type_check -from mindspore.common.tensor import Tensor -from mindspore.common.api import ms_function +from mindspore._checkparam import args_type_check def setup_module(module): @@ -32,6 +29,7 @@ def setup_module(module): class Net(nn.Cell): """ Net definition """ + def __init__(self): super(Net, self).__init__() self.add = P.TensorAdd() @@ -50,6 +48,7 @@ def test_vm_backend(): output = add() assert output.asnumpy().shape == (1, 3, 3, 4) + def test_vm_set_context(): """ test_vm_set_context """ context.set_context(save_graphs=True, save_graphs_path="mindspore_ir_path", mode=context.GRAPH_MODE) @@ -59,6 +58,7 @@ def test_vm_set_context(): assert context.get_context("save_graphs_path").find("mindspore_ir_path") > 0 context.set_context(mode=context.PYNATIVE_MODE) + @args_type_check(v_str=str, v_int=int, v_tuple=tuple) def check_input(v_str, v_int, v_tuple): """ check_input """ From 399d72874be3379f9b7d707ef1b1cb8b1ef2a14c Mon Sep 17 00:00:00 2001 From: limingqi107 Date: Tue, 21 Apr 2020 10:37:08 +0800 Subject: [PATCH 031/142] fix visit kernel missing the return_types --- mindspore/ccsrc/session/anf_runtime_algorithm.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mindspore/ccsrc/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/session/anf_runtime_algorithm.cc index 2591f763c5..525ff44dd8 100644 --- a/mindspore/ccsrc/session/anf_runtime_algorithm.cc +++ b/mindspore/ccsrc/session/anf_runtime_algorithm.cc @@ -111,12 +111,12 @@ KernelWithIndex AnfRuntimeAlgorithm::VisitKernelWithReturnType(const AnfNodePtr MS_EXCEPTION_IF_NULL(value_node); int item_idx = GetValue(value_node->value()); return VisitKernelWithReturnType(cnode->input(kRealInputNodeIndexInTupleGetItem), IntToSize(item_idx), - visit_nop_node); + visit_nop_node, return_types); } else if (IsPrimitive(input0, prim::kPrimDepend) || IsPrimitive(input0, prim::kPrimControlDepend)) { - return VisitKernelWithReturnType(cnode->input(kRealInputIndexInDepend), 0, visit_nop_node); + return VisitKernelWithReturnType(cnode->input(kRealInputIndexInDepend), 0, visit_nop_node, return_types); } else if (opt::IsNopNode(cnode) && visit_nop_node) { if (cnode->inputs().size() == 2) { - return VisitKernelWithReturnType(cnode->input(1), 0, visit_nop_node); + return VisitKernelWithReturnType(cnode->input(1), 0, visit_nop_node, return_types); } else { MS_LOG(EXCEPTION) << cnode->DebugString() << "Invalid nop node"; } From 897ec89d44f6855968be444887bd7bfea9090ecb Mon Sep 17 00:00:00 2001 From: zjun Date: Tue, 21 Apr 2020 10:39:02 +0800 Subject: [PATCH 032/142] fix aicpu set attr bug --- mindspore/ccsrc/kernel/aicpu/aicpu_kernel_build.cc | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_build.cc b/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_build.cc index 808e87edc0..d6217ff1cc 100644 --- a/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_build.cc +++ b/mindspore/ccsrc/kernel/aicpu/aicpu_kernel_build.cc @@ -162,18 +162,17 @@ void SetNodeAttr(const std::shared_ptr &anf_node, mindspore::NodeDef *p ::google::protobuf::Map<::std::string, ::mindspore::AttrValue> *node_attr = proto->mutable_attrs(); for (const auto &attr_ptr : attrs_ptr) { std::string attr_name = attr_ptr->name(); - std::string real_name; auto value = primitive->GetAttr(attr_name); if (value != nullptr) { if (attr_name == kQueueName || attr_name == kSharedName) { - real_name = kChannelName; + attr_name = kChannelName; } else if (attr_name == kSeed) { - real_name = "seed"; + attr_name = "seed"; } else if (attr_name == kSeed2) { - real_name = "seed2"; + attr_name = "seed2"; } std::string type = attr_ptr->type(); - ParseAttrValue(type, real_name, value, node_attr); + ParseAttrValue(type, attr_name, value, node_attr); } } MS_LOG(INFO) << "Set node attr end!"; @@ -182,7 +181,7 @@ void SetNodeAttr(const std::shared_ptr &anf_node, mindspore::NodeDef *p void SetNodeInputs(const std::shared_ptr &anf_node, mindspore::NodeDef *proto) { size_t input_num = AnfAlgo::GetInputTensorNum(anf_node); if (input_num == 0) { - MS_LOG(INFO) << "Node [" << AnfAlgo::GetCNodeName(anf_node) << "] does not have input. "; + MS_LOG(INFO) << "Node [" << AnfAlgo::GetCNodeName(anf_node) << "] does not have input."; return; } From ab7e00589d9f84e080c3c4ddc4f874411ff54a4b Mon Sep 17 00:00:00 2001 From: zjun Date: Tue, 21 Apr 2020 10:55:00 +0800 Subject: [PATCH 033/142] Add aicpu support ms_type --- mindspore/ccsrc/kernel/aicpu/aicpu_util.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mindspore/ccsrc/kernel/aicpu/aicpu_util.cc b/mindspore/ccsrc/kernel/aicpu/aicpu_util.cc index 316df63922..a617f56f8f 100644 --- a/mindspore/ccsrc/kernel/aicpu/aicpu_util.cc +++ b/mindspore/ccsrc/kernel/aicpu/aicpu_util.cc @@ -27,6 +27,7 @@ namespace kernel { static std::map MS_PROTO_DATA_TYPE_MAP = { {mindspore::TypeId::kTypeUnknown, mindspore::DataType::MS_UNKNOWN}, {mindspore::TypeId::kNumberTypeBool, mindspore::DataType::MS_BOOL}, + {mindspore::TypeId::kNumberTypeInt, mindspore::DataType::MS_INT32}, {mindspore::TypeId::kNumberTypeInt8, mindspore::DataType::MS_INT8}, {mindspore::TypeId::kNumberTypeInt16, mindspore::DataType::MS_INT16}, {mindspore::TypeId::kNumberTypeInt32, mindspore::DataType::MS_INT32}, @@ -34,8 +35,10 @@ static std::map MS_PROTO_DATA_TYPE_MAP = { {mindspore::TypeId::kNumberTypeUInt, mindspore::DataType::MS_UINT32}, {mindspore::TypeId::kNumberTypeUInt8, mindspore::DataType::MS_UINT8}, {mindspore::TypeId::kNumberTypeUInt16, mindspore::DataType::MS_UINT16}, + {mindspore::TypeId::kNumberTypeUInt32, mindspore::DataType::MS_UINT32}, {mindspore::TypeId::kNumberTypeUInt64, mindspore::DataType::MS_UINT64}, {mindspore::TypeId::kNumberTypeFloat16, mindspore::DataType::MS_FLOAT16}, + {mindspore::TypeId::kNumberTypeFloat, mindspore::DataType::MS_FLOAT32}, {mindspore::TypeId::kNumberTypeFloat32, mindspore::DataType::MS_FLOAT32}, {mindspore::TypeId::kNumberTypeFloat64, mindspore::DataType::MS_FLOAT64}, }; From 252ed4f7c99e02cba8622cfcec674ebf648e581f Mon Sep 17 00:00:00 2001 From: Wei Luning Date: Tue, 21 Apr 2020 11:16:44 +0800 Subject: [PATCH 034/142] use the old op --- .../ccsrc/optimizer/irpass/arithmetic_simplify.h | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h b/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h index 0d48fc1463..ff6e4f6170 100644 --- a/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h +++ b/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h @@ -248,17 +248,18 @@ class AdjustAllReduceMulAdd : public AnfVisitor { if (addn->size() != 2) { return nullptr; } - AnfVisitor::Match(prim::kPrimMakeTuple, {IsNode, IsNode})(addn->input(1)); if (x_ == nullptr || y_ == nullptr || z_ == nullptr) { return nullptr; } + auto addn_op_node = addn->input(0); + auto make_tuple_op_node = addn->input(1)->cast()->input(0); auto fg = node->func_graph(); - AnfNodePtr tuple = NewCNode({NewValueNode(prim::kPrimMakeTuple), z_, x_}, fg); - AnfNodePtr add = NewCNode({NewValueNode(prim::kPrimAddN), tuple}, fg); - AnfNodePtr all_reduce = NewCNode({NewValueNode(prim::kPrimAllReduce), add}, fg); - return NewCNode({NewValueNode(prim::kPrimMul), all_reduce, y_}, fg); + AnfNodePtr tuple = NewCNode({make_tuple_op_node, z_, x_}, fg); + AnfNodePtr add = NewCNode({addn_op_node, tuple}, fg); + AnfNodePtr all_reduce = NewCNode({all_reduce_, add}, fg); + return NewCNode({mul_, all_reduce, y_}, fg); } void Visit(const AnfNodePtr &node) override { @@ -269,6 +270,7 @@ class AdjustAllReduceMulAdd : public AnfVisitor { AnfVisitor::Match(prim::kPrimMul)(node); level_ = 0; if (is_reduce_match_) { + mul_ = node->cast()->input(0); y_ = tmp_; } else { z_ = node; @@ -280,6 +282,7 @@ class AdjustAllReduceMulAdd : public AnfVisitor { if (IsPrimitiveCNode(node, prim::kPrimAllReduce)) { auto cnode = node->cast(); if (cnode->size() > 1) { + all_reduce_ = cnode->input(0); x_ = cnode->input(1); is_reduce_match_ = true; } @@ -302,6 +305,7 @@ class AdjustAllReduceMulAdd : public AnfVisitor { int level_{0}; bool is_reduce_match_{false}; AnfNodePtr x_{nullptr}, y_{nullptr}, z_{nullptr}, tmp_{nullptr}; + AnfNodePtr all_reduce_{nullptr}, mul_{nullptr}; }; class ArithmeticSimplify { From f8208c7c522cf6e303b8aec2dff16a8a748ca90b Mon Sep 17 00:00:00 2001 From: gukecai Date: Thu, 16 Apr 2020 15:46:09 +0800 Subject: [PATCH 035/142] Support GetNext Parallel --- .../device/ascend/ascend_kernel_runtime.cc | 9 +- .../device/ascend/ascend_stream_assign.cc | 154 +++++++--- .../device/ascend/ascend_stream_assign.h | 33 +- mindspore/ccsrc/device/kernel_adjust.cc | 286 +++++++++--------- mindspore/ccsrc/device/kernel_adjust.h | 25 +- .../ascend/ascend_backend_optimization.cc | 7 + .../pre_activate/mem_reuse/stream_reuse.cc | 4 +- mindspore/ccsrc/session/ascend_session.cc | 2 +- mindspore/ccsrc/utils/utils.h | 3 + .../tasksink/ascend_stream_assign_stub.cc | 4 +- 10 files changed, 304 insertions(+), 223 deletions(-) diff --git a/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc b/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc index 935e694636..44cf3f8fa8 100644 --- a/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc +++ b/mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc @@ -283,18 +283,19 @@ bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) { AscendStreamAssign &assign_instance = AscendStreamAssign::GetInstance(); // the streams' flag not HEAD_STREAM - std::vector wait_active_stream_list = assign_instance.GetWaitStreams(); - std::vector force_copy_stream_list = assign_instance.GetHcomStreams(); + std::vector wait_active_stream_list; + assign_instance.GetWaitStreams(&wait_active_stream_list); + auto force_copy_stream_list = assign_instance.hcom_streams(); MS_LOG(INFO) << "call DavinciModel total stream num:" << assign_instance.GetTotalStreamNum() - << ", total event num:" << assign_instance.GetTotalEventNum() + << ", total event num:" << assign_instance.total_event_num() << ", wait_active_stream_list size:" << wait_active_stream_list.size() << ", force_copy_stream_list size:" << force_copy_stream_list.size(); std::vector> empty_list; std::shared_ptr model = std::make_shared( task_info_list, empty_list, empty_list, empty_list, empty_list, wait_active_stream_list, force_copy_stream_list, 0, - 0, 0, 0, 0, 0, assign_instance.GetTotalStreamNum(), 1, assign_instance.GetTotalEventNum(), 0); + 0, 0, 0, 0, 0, assign_instance.GetTotalStreamNum(), 1, assign_instance.total_event_num(), 0); auto ret = graph_model_map_.insert(std::make_pair(graph->graph_id(), model)); if (!ret.second) { diff --git a/mindspore/ccsrc/device/ascend/ascend_stream_assign.cc b/mindspore/ccsrc/device/ascend/ascend_stream_assign.cc index 8c4d1f4a8f..e2cf469cd8 100644 --- a/mindspore/ccsrc/device/ascend/ascend_stream_assign.cc +++ b/mindspore/ccsrc/device/ascend/ascend_stream_assign.cc @@ -25,8 +25,8 @@ #include "session/anf_runtime_algorithm.h" #include "device/kernel_adjust.h" #include "predict/generator/utils/ir_model_util.h" -#include "device/kernel_info.h" #include "pre_activate/common/helper.h" +#include "utils/utils.h" namespace mindspore { namespace device { @@ -54,6 +54,7 @@ void AscendStreamAssign::ResetNew() { inner_parallel_streams_.clear(); processed_parallel_streams_.clear(); hcom_stream_list_.clear(); + need_first_active_streams_.clear(); } void AscendStreamAssign::AssignIndependentStreamId(const CNodePtr &cur_cnode_ptr, uint32_t processing_logic_id) { @@ -200,13 +201,12 @@ void AscendStreamAssign::AssignAllNodesStream(const shared_ptr AscendStreamAssign::TransLogicToPhysic(const vector &logic_ids) { - vector physic_ids; +void AscendStreamAssign::TransLogicToPhysic(const vector &logic_ids, vector *physic_ids) { for (auto &id : logic_ids) { auto it = logic_to_physic_map_.find(id); if (it != logic_to_physic_map_.end()) { MS_LOG(INFO) << "logic id[" << id << "] to physic id[" << it->second << "]"; - physic_ids.push_back(it->second); + (*physic_ids).push_back(it->second); } else { MS_LOG(EXCEPTION) << "logic id[" << id << "] has no correspond physic id"; } @@ -214,10 +214,9 @@ vector AscendStreamAssign::TransLogicToPhysic(const vector & auto it_independ = logic_to_independent_map_.find(id); if (it_independ != logic_to_independent_map_.end()) { MS_LOG(INFO) << "logic id[" << id << "] to independent id[" << it_independ->second << "]"; - physic_ids.push_back(it_independ->second); + (*physic_ids).push_back(it_independ->second); } } - return physic_ids; } void AscendStreamAssign::UpdateStreamActive(const CNodePtr &active_ptr) { @@ -227,7 +226,8 @@ void AscendStreamAssign::UpdateStreamActive(const CNodePtr &active_ptr) { MS_EXCEPTION_IF_NULL(primitive); vector active_logic_ids = GetValue>(primitive->GetAttr(kAttrActiveStreamList)); // out StreamAcitve active physic stream is not parallel now, if parallel, should deal here. - vector active_physic_ids = TransLogicToPhysic(active_logic_ids); + vector active_physic_ids; + TransLogicToPhysic(active_logic_ids, &active_physic_ids); ValuePtr active_physic_value = MakeValue>(active_physic_ids); AnfAlgo::SetNodeAttr(kAttrActiveStreamList, active_physic_value, active_ptr); } @@ -242,7 +242,8 @@ void AscendStreamAssign::UpdateStreamSwitch(const CNodePtr &switch_ptr, const CN MS_LOG(INFO) << "streamswtich stream id[" << AnfAlgo::GetStreamId(switch_ptr) << "], true_logic_id[" << true_logic_id << "]"; vector logic_ids{true_logic_id}; - vector physic_ids = TransLogicToPhysic(logic_ids); + vector physic_ids; + TransLogicToPhysic(logic_ids, &physic_ids); if (physic_ids.empty()) { MS_LOG(EXCEPTION) << "stream switch true logic id[" << true_logic_id << "] has no physical id"; } @@ -334,8 +335,8 @@ bool AscendStreamAssign::IsProcessedParallelStream(uint32_t stream_id) { return false; } -vector AscendStreamAssign::GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id) { - vector parallel_streams; +void AscendStreamAssign::GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id, + vector *parallel_streams) { for (size_t i = 0; i < inner_parallel_streams_.size(); i++) { auto cur_parallel_streams = inner_parallel_streams_[i]; auto it = std::find(cur_parallel_streams.begin(), cur_parallel_streams.end(), cur_stream_id); @@ -347,17 +348,17 @@ vector AscendStreamAssign::GetParallelStream(uint32_t cur_stream_id, u << "is same with streamacvite stream id" << stream_acitve_id; continue; } - parallel_streams.emplace_back(cur_parallel_streams[j]); + (*parallel_streams).emplace_back(cur_parallel_streams[j]); } // record processed parallel streams - (void)std::copy(parallel_streams.begin(), parallel_streams.end(), + (void)std::copy((*parallel_streams).begin(), (*parallel_streams).end(), std::back_inserter(processed_parallel_streams_)); - return parallel_streams; + return; } } - return vector{cur_stream_id}; + (*parallel_streams).push_back(cur_stream_id); } void AscendStreamAssign::InsertActiveNew(const std::shared_ptr &graph_ptr) { @@ -379,30 +380,32 @@ void AscendStreamAssign::InsertActiveNew(const std::shared_ptr active_index_list = GetParallelStream(cur_stream_id, pre_stream_id); + std::vector active_index_list; + GetParallelStream(cur_stream_id, pre_stream_id, &active_index_list); AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue>(active_index_list), active_ptr); - } else if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == "StreamActive" && - AnfAlgo::GetStreamDistinctionLabel(cur_cnode_ptr.get()) != UINT32_MAX) { + } + // inner_active is not a if/else relationship with the next if/else. such as:StreamActive(S7)-->StreamActive(S8) + if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kStreamActiveOpName && + AnfAlgo::GetStreamDistinctionLabel(cur_cnode_ptr.get()) != UINT32_MAX) { // 2)outter stream assign, update active op update_cnode_list.emplace_back(cur_cnode_ptr); UpdateStreamActive(cur_cnode_ptr); - } else if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == "StreamSwitch") { + } else if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kStreamSwitchOpName) { // 3)update switch op MS_LOG(INFO) << "Insert active op after switch"; - CNodePtr active_ptr = KernelAdjust::GetInstance().CreateSteamActiveOp(graph_ptr); + CNodePtr active_ptr = KernelAdjust::GetInstance().CreateStreamActiveOp(graph_ptr); update_cnode_list.emplace_back(cur_cnode_ptr); update_cnode_list.emplace_back(active_ptr); UpdateStreamSwitch(cur_cnode_ptr, active_ptr); @@ -417,6 +420,37 @@ void AscendStreamAssign::InsertActiveNew(const std::shared_ptr &graph_ptr) { + MS_LOG(INFO) << "start"; + MS_EXCEPTION_IF_NULL(graph_ptr); + CNodePtr cur_cnode_ptr = nullptr; + // key:virutal event id, value:real event id + std::unordered_map event_id_map; + uint32_t event_id; + auto cnode_ptr_list = graph_ptr->execution_order(); + for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { + cur_cnode_ptr = cnode_ptr_list[i]; + MS_EXCEPTION_IF_NULL(cur_cnode_ptr); + if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kSendOpName || AnfAlgo::GetCNodeName(cur_cnode_ptr) == kRecvOpName) { + auto primitive = AnfAlgo::GetCNodePrimitive(cur_cnode_ptr); + MS_EXCEPTION_IF_NULL(primitive); + event_id = GetValue(primitive->GetAttr(kAttrEventId)); + // before stream assign, send/recv event_id assign from kFirstEventId + if (event_id < kFirstEventId) { + continue; + } + auto it = event_id_map.find(event_id); + if (it == event_id_map.end()) { + event_id_map.insert(std::make_pair(event_id, total_event_num_)); + AnfAlgo::SetNodeAttr(kAttrEventId, MakeValue(total_event_num_), cur_cnode_ptr); + total_event_num_++; + } else { + AnfAlgo::SetNodeAttr(kAttrEventId, MakeValue(it->second), cur_cnode_ptr); + } + } + } +} + void AscendStreamAssign::UpdateStreamId(const shared_ptr &graph_ptr) { MS_LOG(INFO) << "start"; MS_EXCEPTION_IF_NULL(graph_ptr); @@ -427,7 +461,7 @@ void AscendStreamAssign::UpdateStreamId(const shared_ptr & MS_EXCEPTION_IF_NULL(cur_cnode_ptr); uint32_t cur_stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr); if (cur_stream_id < kIndependFirstStreamId) { - if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == "StreamActive") { + if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kStreamActiveOpName) { auto primitive = AnfAlgo::GetCNodePrimitive(cur_cnode_ptr); MS_EXCEPTION_IF_NULL(primitive); vector active_ids = GetValue>(primitive->GetAttr(kAttrActiveStreamList)); @@ -471,6 +505,29 @@ void AscendStreamAssign::UpdateStreamId(const shared_ptr & MS_LOG(INFO) << "end"; } +void AscendStreamAssign::GetNeedActiveStreams(const shared_ptr &graph_ptr) { + MS_EXCEPTION_IF_NULL(graph_ptr); + CNodePtr cur_cnode_ptr = nullptr; + auto cnode_ptr_list = graph_ptr->execution_order(); + for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { + cur_cnode_ptr = cnode_ptr_list[i]; + MS_EXCEPTION_IF_NULL(cur_cnode_ptr); + auto primitive = AnfAlgo::GetCNodePrimitive(cur_cnode_ptr); + MS_EXCEPTION_IF_NULL(primitive); + auto value_ptr = primitive->GetAttr(kStreamNeedActivedFirst); + if (value_ptr == nullptr) { + continue; + } + + auto need_active = GetValue(value_ptr); + if (need_active) { + auto stream_id = AnfAlgo::GetStreamId(cur_cnode_ptr); + MS_LOG(INFO) << "stream id:" << stream_id << " is need actived at first"; + need_first_active_streams_.push_back(stream_id); + } + } +} + void AscendStreamAssign::AssignStreamNew(const shared_ptr &graph_ptr) { if (IsTaskSink()) { ResetNew(); @@ -480,13 +537,15 @@ void AscendStreamAssign::AssignStreamNew(const shared_ptr InsertSendRecvForHcomParallel(graph_ptr); InsertSendRecvForIndependent(graph_ptr); UpdateStreamId(graph_ptr); + UpdateEventId(graph_ptr); + GetNeedActiveStreams(graph_ptr); MS_LOG(INFO) << "after finish stream assign"; PrintGraphExeOrders(graph_ptr); // Get info for D Model - generator::IRModelUtil::GetInstance().set_event_num(GetTotalEventNum()); - generator::IRModelUtil::GetInstance().set_stream_num(GetTotalCommonStreamNum() + GetTotalIndependStreamNum()); + generator::IRModelUtil::GetInstance().set_event_num(total_event_num()); + generator::IRModelUtil::GetInstance().set_stream_num(total_common_stream_num() + total_independ_stream_num()); // Init to 1,temporarily generator::IRModelUtil::GetInstance().set_batch_num(1); } @@ -495,7 +554,7 @@ void AscendStreamAssign::AssignStreamNew(const shared_ptr CNodePtr AscendStreamAssign::CreateSendApplyKernel(const std::shared_ptr &graph_ptr, uint32_t event_id, uint32_t stream_id) { MS_EXCEPTION_IF_NULL(graph_ptr); - auto send_op = std::make_shared("Send"); + auto send_op = std::make_shared(kSendOpName); MS_EXCEPTION_IF_NULL(send_op); auto send_apply = std::make_shared(send_op); MS_EXCEPTION_IF_NULL(send_apply); @@ -505,7 +564,7 @@ CNodePtr AscendStreamAssign::CreateSendApplyKernel(const std::shared_ptr(); MS_EXCEPTION_IF_NULL(abstract_none); send_node_ptr->set_abstract(abstract_none); @@ -516,7 +575,7 @@ CNodePtr AscendStreamAssign::CreateSendApplyKernel(const std::shared_ptr &graph_ptr, uint32_t event_id, uint32_t stream_id) { MS_EXCEPTION_IF_NULL(graph_ptr); - auto recv_op = std::make_shared("Recv"); + auto recv_op = std::make_shared(kRecvOpName); MS_EXCEPTION_IF_NULL(recv_op); auto recv_apply = std::make_shared(recv_op); MS_EXCEPTION_IF_NULL(recv_apply); @@ -526,7 +585,7 @@ CNodePtr AscendStreamAssign::CreateRecvApplyKernel(const std::shared_ptr(); MS_EXCEPTION_IF_NULL(abstract_none); @@ -605,7 +664,7 @@ bool AscendStreamAssign::IsIndependentNode(const CNodePtr &node_ptr) { return false; } - if (AnfAlgo::GetCNodeName(node_ptr) == "GetNext") { + if (AnfAlgo::GetCNodeName(node_ptr) == kGetNextOpName) { MS_LOG(INFO) << "GetNext should not be independent node"; return false; } @@ -638,20 +697,23 @@ bool AscendStreamAssign::IsTaskSink() { } } -std::vector AscendStreamAssign::GetWaitStreams() { - vector wait_active_stream_list; +void AscendStreamAssign::GetWaitStreams(vector *wait_active_stream_list) { if (total_common_stream_num_ == 0) { MS_LOG(INFO) << "total_common_stream_num is zero"; - return wait_active_stream_list; + return; } // common stream:active first common stream MS_LOG(INFO) << "active physic id[" << first_physic_id_ << "]"; for (uint32_t i = first_physic_id_ + 1; i < total_common_stream_num_; i++) { - MS_LOG(INFO) << "wait common stream id = " << i; - wait_active_stream_list.push_back(i); + auto it = std::find(need_first_active_streams_.begin(), need_first_active_streams_.end(), i); + if (it == need_first_active_streams_.end()) { + MS_LOG(INFO) << "wait common stream id = " << i; + (*wait_active_stream_list).push_back(i); + } } + // all independ stream id before first physical stream id should be actived auto it = logic_to_independent_map_.find(first_logic_id_); if (it != logic_to_independent_map_.end()) { uint32_t independent_id = it->second; @@ -675,16 +737,14 @@ std::vector AscendStreamAssign::GetWaitStreams() { if (i + total_common_stream_num_ <= max_before_physic) { continue; } - MS_LOG(INFO) << "wait independent stream id:" << i + total_common_stream_num_; - wait_active_stream_list.push_back(i + total_common_stream_num_); + // all wait streams should not in need_first_active_streams_ + auto iter = + std::find(need_first_active_streams_.begin(), need_first_active_streams_.end(), i + total_common_stream_num_); + if (iter == need_first_active_streams_.end()) { + MS_LOG(INFO) << "wait independent stream id:" << i + total_common_stream_num_; + (*wait_active_stream_list).push_back(i + total_common_stream_num_); + } } - - return wait_active_stream_list; -} - -std::vector AscendStreamAssign::GetHcomStreams() { - MS_LOG(INFO) << "hcom total stream nums:" << hcom_stream_list_.size(); - return hcom_stream_list_; } uint32_t AscendStreamAssign::GetTotalStreamNum() const { return total_common_stream_num_ + total_independ_stream_num_; } @@ -695,7 +755,7 @@ void AscendStreamAssign::PrintGraphExeOrders(const shared_ptr& graph_ptr); void AssignAllNodesStream(const std::shared_ptr& graph_ptr); void ResetNew(); void AssignStreamNew(const std::shared_ptr& graph_ptr); bool IsIndependentNode(const CNodePtr& node_ptr); - const std::unordered_map GetIndependentMap() { return logic_to_independent_map_; } - const std::unordered_map GetPhysicMap() { return logic_to_physic_map_; } - std::vector GetWaitStreams(); - std::vector GetHcomStreams(); - - private: - AscendStreamAssign() = default; - ~AscendStreamAssign() = default; - + const std::unordered_map& logic_to_independent_map() { return logic_to_independent_map_; } + const std::unordered_map& logic_to_physic_map() { return logic_to_physic_map_; } + const std::vector>& inner_parallel_streams() { return inner_parallel_streams_; } + void GetWaitStreams(vector* wait_active_stream_list); + const std::vector& hcom_streams() { return hcom_stream_list_; } CNodePtr CreateSendApplyKernel(const std::shared_ptr& graph_ptr, uint32_t event_id, uint32_t stream_id); CNodePtr CreateRecvApplyKernel(const std::shared_ptr& graph_ptr, uint32_t event_id, uint32_t stream_id); + private: + AscendStreamAssign() = default; + ~AscendStreamAssign() = default; + vector::iterator FindTargetOp(vector::iterator begin, vector::iterator end, const CNodePtr& node); bool IsHcom(const CNodePtr& apply_kernel); bool IsProcessed(uint32_t logic_id); - vector TransLogicToPhysic(const vector& logic_ids); + void TransLogicToPhysic(const vector& logic_ids, vector* physic_ids); void AssignCommonStreamId(const CNodePtr& cur_cnode_ptr, CNodePtr* pre_cnode_ptr, uint32_t* cur_index, uint32_t* cur_stream_id); void RecordIdMap(uint32_t logic_id, uint32_t physic_id); @@ -88,15 +86,17 @@ class AscendStreamAssign { bool IsTaskSink(); void AssignIndependentStreamId(const CNodePtr& cur_cnode_ptr, uint32_t deal_logic_id); void UpdateStreamId(const std::shared_ptr& graph_ptr); + void UpdateEventId(const std::shared_ptr& graph_ptr); void PrintGraphExeOrders(const std::shared_ptr& graph_ptr); void RecordFirstCommonOp(const CNodePtr& cur_cnode_ptr, uint32_t cur_node_logic_id, uint32_t cur_stream_id); uint32_t GetLogicId(const CNodePtr& cur_cnode_ptr); void SetCommonStreamNum(uint32_t cur_stream_id); void FindAllReduceParallel(const std::shared_ptr& graph_ptr); bool IsProcessedParallelStream(uint32_t stream_id); - vector GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id); + void GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id, std::vector* parallel_streams); void InsertSendRecvForIndependent(const std::shared_ptr& graph_ptr); void InsertSendRecvForHcomParallel(const std::shared_ptr& graph_ptr); + void GetNeedActiveStreams(const std::shared_ptr& graph_ptr); uint32_t total_common_stream_num_{0}; uint32_t total_independ_stream_num_{0}; @@ -112,6 +112,7 @@ class AscendStreamAssign { std::vector> inner_parallel_streams_{}; std::vector processed_parallel_streams_{}; std::vector hcom_stream_list_{}; + std::vector need_first_active_streams_{}; // new policy end }; } // namespace ascend diff --git a/mindspore/ccsrc/device/kernel_adjust.cc b/mindspore/ccsrc/device/kernel_adjust.cc index c1588d7d53..b557436db9 100644 --- a/mindspore/ccsrc/device/kernel_adjust.cc +++ b/mindspore/ccsrc/device/kernel_adjust.cc @@ -32,16 +32,8 @@ #include "utils/utils.h" #include "device/ascend/profiling/profiling_manager.h" #include "device/ascend/kernel_select_ascend.h" -#include "device/kernel_info.h" #include "runtime/base.h" - -constexpr auto kLoopCountParamName = "loop_count"; -constexpr auto kIterLoopParamName = "iter_loop"; -constexpr auto kZeroParamName = "zero"; -constexpr auto kOneParamName = "one"; -constexpr auto kStreamSwitch = "StreamSwitch"; -constexpr auto kStreamActive = "StreamActive"; -constexpr auto kAssignAdd = "AssignAdd"; +#include "device/ascend/ascend_stream_assign.h" namespace mindspore { namespace device { using device::ascend::ProfilingUtils; @@ -70,6 +62,63 @@ bool KernelAdjust::NeedInsertSwitch() { ConfigManager::GetInstance().iter_num() > 1); } +uint32_t KernelAdjust::FindFirstStreamSwitchLabel(const std::shared_ptr &kernel_graph_ptr) { + MS_EXCEPTION_IF_NULL(kernel_graph_ptr); + auto cnode_ptr_list = kernel_graph_ptr->execution_order(); + CNodePtr cur_cnode_ptr = nullptr; + uint32_t label = kInvalidDistincLabel; + for (uint32_t i = 0; i < cnode_ptr_list.size(); ++i) { + cur_cnode_ptr = cnode_ptr_list[i]; + MS_EXCEPTION_IF_NULL(cur_cnode_ptr); + if (AnfAlgo::GetCNodeName(cur_cnode_ptr) == kStreamSwitchOpName) { + label = AnfAlgo::GetStreamDistinctionLabel(cur_cnode_ptr.get()); + break; + } + } + + return label; +} + +CNodePtr KernelAdjust::CreateSendApplyKernel(const std::shared_ptr &graph_ptr, + uint32_t event_id) { + MS_EXCEPTION_IF_NULL(graph_ptr); + auto send_op = std::make_shared(kSendOpName); + MS_EXCEPTION_IF_NULL(send_op); + auto send_apply = std::make_shared(send_op); + MS_EXCEPTION_IF_NULL(send_apply); + std::vector send_input_list = {send_apply}; + CNodePtr send_node_ptr = graph_ptr->NewCNode(send_input_list); + MS_EXCEPTION_IF_NULL(send_node_ptr); + kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder; + selected_kernel_builder.SetKernelType(KernelType::RT_KERNEL); + AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), send_node_ptr.get()); + AnfAlgo::SetNodeAttr(kAttrEventId, MakeValue(event_id), send_node_ptr); + auto abstract_none = std::make_shared(); + MS_EXCEPTION_IF_NULL(abstract_none); + send_node_ptr->set_abstract(abstract_none); + return send_node_ptr; +} + +CNodePtr KernelAdjust::CreateRecvApplyKernel(const std::shared_ptr &graph_ptr, + uint32_t event_id) { + MS_EXCEPTION_IF_NULL(graph_ptr); + auto recv_op = std::make_shared(kRecvOpName); + MS_EXCEPTION_IF_NULL(recv_op); + auto recv_apply = std::make_shared(recv_op); + MS_EXCEPTION_IF_NULL(recv_apply); + std::vector recv_input_list = {recv_apply}; + CNodePtr recv_node_ptr = graph_ptr->NewCNode(recv_input_list); + MS_EXCEPTION_IF_NULL(recv_node_ptr); + kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder; + selected_kernel_builder.SetKernelType(KernelType::RT_KERNEL); + AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), recv_node_ptr.get()); + AnfAlgo::SetNodeAttr(kAttrEventId, MakeValue(event_id), recv_node_ptr); + auto abstract_none = std::make_shared(); + MS_EXCEPTION_IF_NULL(abstract_none); + recv_node_ptr->set_abstract(abstract_none); + return recv_node_ptr; +} + void KernelAdjust::InsertSwitchLoop(const std::shared_ptr &kernel_graph_ptr) { if (!NeedInsertSwitch()) { return; @@ -93,21 +142,95 @@ void KernelAdjust::InsertSwitchLoop(const std::shared_ptr } } } + + auto orders = kernel_graph_ptr->execution_order(); + if (orders.empty()) { + MS_LOG(EXCEPTION) << "graph execution order is empty"; + } + uint32_t first_cnode_stream_label = AnfAlgo::GetStreamDistinctionLabel(orders[0].get()); + std::vector exec_order; - CNodePtr stream_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input); - MS_EXCEPTION_IF_NULL(stream_switch_app); - exec_order.push_back(stream_switch_app); + CNodePtr first_stream_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input); + MS_EXCEPTION_IF_NULL(first_stream_switch_app); + AnfAlgo::SetStreamDistinctionLabel(kFirstStreamSwitchLabel, first_stream_switch_app.get()); + AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue(kGetNextLabel), first_stream_switch_app); + + CNodePtr second_stream_switch_app = CreateStreamSwitchOp(kernel_graph_ptr, switch_loop_input); + MS_EXCEPTION_IF_NULL(second_stream_switch_app); + AnfAlgo::SetStreamDistinctionLabel(kSecondStreamSwitchLabel, second_stream_switch_app.get()); + AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue(first_cnode_stream_label), second_stream_switch_app); + // add attr "stream_need_active" + AnfAlgo::SetNodeAttr(kStreamNeedActivedFirst, MakeValue(true), second_stream_switch_app); + + CNodePtr first_stream_active_app = CreateStreamActiveOp(kernel_graph_ptr); + MS_EXCEPTION_IF_NULL(first_stream_active_app); + AnfAlgo::SetStreamDistinctionLabel(first_cnode_stream_label, first_stream_active_app.get()); + std::vector first_active_streams = {kFirstStreamSwitchLabel}; + AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue>(first_active_streams), + first_stream_active_app); + + CNodePtr second_stream_active_app = CreateStreamActiveOp(kernel_graph_ptr); + MS_EXCEPTION_IF_NULL(second_stream_active_app); + // specific deal for common ctrl stream policy + uint32_t first_common_stream_switch_label = FindFirstStreamSwitchLabel(kernel_graph_ptr); + if (first_common_stream_switch_label == kInvalidDistincLabel) { + AnfAlgo::SetStreamDistinctionLabel(first_cnode_stream_label, second_stream_active_app.get()); + } else { + AnfAlgo::SetStreamDistinctionLabel(first_common_stream_switch_label, second_stream_active_app.get()); + } - CNodePtr stream_active_switch_app = CreateStreamActiveSwitchOp(kernel_graph_ptr); - MS_EXCEPTION_IF_NULL(stream_active_switch_app); + std::vector second_active_streams = {kSecondStreamSwitchLabel}; + AnfAlgo::SetNodeAttr(kAttrActiveStreamList, MakeValue>(second_active_streams), + second_stream_active_app); CNodePtr assign_add_one = CreateStreamAssignAddnOP(kernel_graph_ptr, switch_loop_input); MS_EXCEPTION_IF_NULL(assign_add_one); + AnfAlgo::SetStreamDistinctionLabel(first_cnode_stream_label, assign_add_one.get()); + + CNodePtr send = CreateSendApplyKernel(kernel_graph_ptr, kFirstEventId); + AnfAlgo::SetStreamDistinctionLabel(kGetNextLabel, send.get()); + CNodePtr recv = CreateRecvApplyKernel(kernel_graph_ptr, kFirstEventId); + AnfAlgo::SetStreamDistinctionLabel(first_cnode_stream_label, recv.get()); + + // reorder graph orders + exec_order.push_back(first_stream_switch_app); + size_t i = 0; + for (; i < orders.size(); i++) { + auto node = orders[i]; + exec_order.push_back(node); + AnfAlgo::SetStreamDistinctionLabel(kGetNextLabel, exec_order[exec_order.size() - 1].get()); + if (AnfAlgo::GetCNodeName(node) == kGetNextOpName) { + break; + } + } + + exec_order.push_back(send); + exec_order.push_back(second_stream_switch_app); + exec_order.push_back(recv); exec_order.push_back(assign_add_one); - auto original_exec_order = kernel_graph_ptr->execution_order(); - (void)std::copy(original_exec_order.begin(), original_exec_order.end(), std::back_inserter(exec_order)); - exec_order.push_back(stream_active_switch_app); + std::vector memcpy_list; + std::vector before_list; + std::vector after_list; + bool first_memcpy_found = false; + CNodePtr cur_cnode = nullptr; + for (size_t idx = i + 1; idx < orders.size(); idx++) { + cur_cnode = orders[idx]; + if (AnfAlgo::HasNodeAttr(kAttrLabelForInsertStreamActive, cur_cnode)) { + memcpy_list.emplace_back(cur_cnode); + first_memcpy_found = true; + } else if (first_memcpy_found) { + after_list.emplace_back(cur_cnode); + } else { + before_list.emplace_back(cur_cnode); + } + } + + (void)std::copy(before_list.begin(), before_list.end(), std::back_inserter(exec_order)); + (void)std::copy(memcpy_list.begin(), memcpy_list.end(), std::back_inserter(exec_order)); + exec_order.push_back(first_stream_active_app); + (void)std::copy(after_list.begin(), after_list.end(), std::back_inserter(exec_order)); + exec_order.push_back(second_stream_active_app); kernel_graph_ptr->set_execution_order(exec_order); } @@ -167,7 +290,7 @@ CNodePtr KernelAdjust::CreateStreamSwitchOp(const std::shared_ptr(); - auto stream_switch = std::make_shared(kStreamSwitch); + auto stream_switch = std::make_shared(kStreamSwitchOpName); std::vector inputs; inputs.push_back(NewValueNode(stream_switch)); inputs.push_back(switch_loop_input.at(kLoopCountParamName)); @@ -181,28 +304,19 @@ CNodePtr KernelAdjust::CreateStreamSwitchOp(const std::shared_ptr(RT_LESS); ValuePtr cond = MakeValue(condition); AnfAlgo::SetNodeAttr(kAttrSwitchCondition, cond, stream_switch_app); - // set attr:true branch graph id ,which is same to stream distinction label - if (kernel_graph_ptr->execution_order().empty()) { - MS_LOG(EXCEPTION) << "empty execution order"; - } - auto first_node = kernel_graph_ptr->execution_order()[0]; - auto first_stream = AnfAlgo::GetStreamDistinctionLabel(first_node.get()); - AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, MakeValue(first_stream), stream_switch_app); // set attr:data_type int data_type = static_cast(RT_SWITCH_INT64); ValuePtr dt = MakeValue(data_type); AnfAlgo::SetNodeAttr(kAttrDataType, dt, stream_switch_app); // set distinction label and graph id - AnfAlgo::SetGraphId(kInvalidGraphId - 1, stream_switch_app.get()); - AnfAlgo::SetStreamDistinctionLabel(kInvalidDistincLabel - 1, stream_switch_app.get()); return stream_switch_app; } -CNodePtr KernelAdjust::CreateSteamActiveOp(const std::shared_ptr &kernel_graph_ptr) { +CNodePtr KernelAdjust::CreateStreamActiveOp(const std::shared_ptr &kernel_graph_ptr) { kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder = CreateMngKernelBuilder( {kOpFormat_DEFAULT, kOpFormat_DEFAULT}, {TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt32}); abstract::AbstractBasePtr typeNone_abstract = std::make_shared(); - auto stream_active_others = std::make_shared(kStreamActive); + auto stream_active_others = std::make_shared(kStreamActiveOpName); std::vector inputs; inputs.push_back(NewValueNode(stream_active_others)); MS_EXCEPTION_IF_NULL(kernel_graph_ptr); @@ -213,57 +327,6 @@ CNodePtr KernelAdjust::CreateSteamActiveOp(const std::shared_ptr &kernel_graph_ptr) { - kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder = CreateMngKernelBuilder( - {kOpFormat_DEFAULT, kOpFormat_DEFAULT}, {TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt32}); - abstract::AbstractBasePtr typeNone_abstract = std::make_shared(); - auto stream_active_switch = std::make_shared(kStreamActive); - std::vector inputs; - inputs.push_back(NewValueNode(stream_active_switch)); - MS_EXCEPTION_IF_NULL(kernel_graph_ptr); - CNodePtr stream_active_switch_app = kernel_graph_ptr->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(stream_active_switch_app); - AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), stream_active_switch_app.get()); - stream_active_switch_app->set_abstract(typeNone_abstract); - // set attr,which stream to active - std::vector active_index_value = {kInvalidDistincLabel - 1}; - auto value = MakeValue>(active_index_value); - AnfAlgo::SetNodeAttr(kAttrActiveStreamList, value, stream_active_switch_app); - // set the distinction label of stream active - if (kernel_graph_ptr->execution_order().empty()) { - MS_LOG(EXCEPTION) << "empty execution order"; - } - auto first_node = kernel_graph_ptr->execution_order()[0]; - auto label = AnfAlgo::GetStreamDistinctionLabel(first_node.get()); - // find the first switch's distinction label - for (auto node : kernel_graph_ptr->execution_order()) { - if (AnfAlgo::GetCNodeName(node) == "StreamSwitch") { - label = AnfAlgo::GetStreamDistinctionLabel(node.get()); - break; - } - } - AnfAlgo::SetStreamDistinctionLabel(label, stream_active_switch_app.get()); - return stream_active_switch_app; -} - -CNodePtr KernelAdjust::CreateStreamActiveOtherOp(const std::shared_ptr &kernel_graph_ptr) { - kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder = CreateMngKernelBuilder( - {kOpFormat_DEFAULT, kOpFormat_DEFAULT}, {TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt32}); - abstract::AbstractBasePtr typeNone_abstract = std::make_shared(); - auto stream_active_others = std::make_shared(kStreamActive); - std::vector inputs; - inputs.push_back(NewValueNode(stream_active_others)); - MS_EXCEPTION_IF_NULL(kernel_graph_ptr); - CNodePtr stream_active_others_app = kernel_graph_ptr->NewCNode(inputs); - MS_EXCEPTION_IF_NULL(stream_active_others_app); - AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), stream_active_others_app.get()); - stream_active_others_app->set_abstract(typeNone_abstract); - // set attr - ValuePtr active_target = MakeValue(kValueTargetOther); - AnfAlgo::SetNodeAttr(kAttrActiveTarget, active_target, stream_active_others_app); - return stream_active_others_app; -} - CNodePtr KernelAdjust::CreateStreamAssignAddnOP( const std::shared_ptr &kernel_graph_ptr, const std::map &switch_loop_input) { @@ -273,7 +336,7 @@ CNodePtr KernelAdjust::CreateStreamAssignAddnOP( selected_kernel_builder.SetOutputsFormat({kOpFormat_DEFAULT}); selected_kernel_builder.SetOutputsDeviceType({kNumberTypeInt32}); // AssignAdd - auto assign_add = std::make_shared(kAssignAdd); + auto assign_add = std::make_shared(kAssignAddOpName); std::vector inputs; inputs.push_back(NewValueNode(assign_add)); inputs.push_back(switch_loop_input.at(kLoopCountParamName)); @@ -290,70 +353,9 @@ CNodePtr KernelAdjust::CreateStreamAssignAddnOP( selected_kernel_builder.SetKernelType(KernelType::TBE_KERNEL); MS_EXCEPTION_IF_NULL(switch_loop_input.at(kLoopCountParamName)); assign_add_one->set_abstract(switch_loop_input.at(kLoopCountParamName)->abstract()); - // set the distinction label of assign add - if (kernel_graph_ptr->execution_order().empty()) { - MS_LOG(EXCEPTION) << "empty execution order"; - } - auto first_node = kernel_graph_ptr->execution_order()[0]; - auto label = AnfAlgo::GetStreamDistinctionLabel(first_node.get()); - AnfAlgo::SetStreamDistinctionLabel(label, assign_add_one.get()); return assign_add_one; } -void KernelAdjust::SetStreamActiveOPs(const std::shared_ptr &kernel_graph_ptr, - const std::unordered_set &ctrl_stream_list, - const std::unordered_set &comm_stream_list, - const std::unordered_set &momentum_stream_list) { - MS_EXCEPTION_IF_NULL(kernel_graph_ptr); - for (const auto &cnode_ptr : kernel_graph_ptr->execution_order()) { - MS_EXCEPTION_IF_NULL(cnode_ptr); - if (AnfAlgo::GetCNodeName(cnode_ptr) == kStreamActive) { - auto primitive = AnfAlgo::GetCNodePrimitive(cnode_ptr); - ValuePtr active_target = primitive->GetAttr(kAttrActiveTarget); - std::vector index_list; - index_list.clear(); - if (GetValue(active_target) == kValueTargetSwitch) { - index_list.insert(index_list.end(), ctrl_stream_list.begin(), ctrl_stream_list.end()); - } else if (GetValue(active_target) == kValueTargetOther) { - for (uint32_t index : comm_stream_list) { - if (AnfAlgo::GetStreamId(cnode_ptr) == index) { - continue; - } - index_list.emplace_back(index); - } - index_list.insert(index_list.end(), momentum_stream_list.begin(), momentum_stream_list.end()); - } - ValuePtr index_list_value = MakeValue(index_list); - AnfAlgo::SetNodeAttr(kAttrActiveStreamList, index_list_value, cnode_ptr); - } - } -} - -void KernelAdjust::SetStreamSwitchOps(const std::shared_ptr &kernel_graph_ptr) { - MS_EXCEPTION_IF_NULL(kernel_graph_ptr); - CNodePtr switch_cnode_ptr = nullptr; - uint32_t target_stream_id = 0; - for (const auto &cnode_ptr : kernel_graph_ptr->execution_order()) { - MS_EXCEPTION_IF_NULL(cnode_ptr); - if (AnfAlgo::GetCNodeName(cnode_ptr) == kStreamSwitch) { - switch_cnode_ptr = cnode_ptr; - } - if (AnfAlgo::GetCNodeName(cnode_ptr) == kStreamActive) { - auto primitive = AnfAlgo::GetCNodePrimitive(cnode_ptr); - ValuePtr active_target = primitive->GetAttr(kAttrActiveTarget); - if (GetValue(active_target) == kValueTargetOther) { - target_stream_id = AnfAlgo::GetStreamId(cnode_ptr); - } - } - } - if (switch_cnode_ptr != nullptr) { - // set attr:true stream - ValuePtr true_index = MakeValue(target_stream_id); - AnfAlgo::SetNodeAttr(kAttrTrueBranchStream, true_index, switch_cnode_ptr); - MS_LOG(INFO) << "switch to true_index:" << target_stream_id; - } -} - bool KernelAdjust::StepLoadCtrlInputs(const std::shared_ptr &context, const std::shared_ptr &kernel_graph_ptr) { if (!NeedInsertSwitch()) { diff --git a/mindspore/ccsrc/device/kernel_adjust.h b/mindspore/ccsrc/device/kernel_adjust.h index ca01d51e54..3dced257c1 100644 --- a/mindspore/ccsrc/device/kernel_adjust.h +++ b/mindspore/ccsrc/device/kernel_adjust.h @@ -28,10 +28,22 @@ #include "session/session_context.h" #include "ir/meta_tensor.h" #include "device/ascend/profiling/profiling_utils.h" +#include "device/kernel_info.h" using mindspore::device::ascend::ProfilingTraceInfo; using mindspore::device::ascend::ProfilingUtils; namespace mindspore { +constexpr auto kLoopCountParamName = "loop_count"; +constexpr auto kIterLoopParamName = "iter_loop"; +constexpr auto kZeroParamName = "zero"; +constexpr auto kOneParamName = "one"; +constexpr auto kStreamNeedActivedFirst = "stream_need_active_first"; + +const uint32_t kFirstStreamSwitchLabel = kInvalidDistincLabel - 1; +const uint32_t kGetNextLabel = kInvalidDistincLabel - 2; +const uint32_t kSecondStreamSwitchLabel = kInvalidDistincLabel - 3; +const uint32_t kInvalidEventId = UINT32_MAX; +const uint32_t kFirstEventId = kInvalidEventId / 2; namespace device { class KernelAdjust { public: @@ -41,26 +53,23 @@ class KernelAdjust { } void Reorder(const std::shared_ptr &kernel_graph_ptr); void InsertSwitchLoop(const std::shared_ptr &kernel_graph_ptr); - void SetStreamActiveOPs(const std::shared_ptr &kernel_graph_ptr, - const std::unordered_set &ctrl_stream_list, - const std::unordered_set &comm_stream_list, - const std::unordered_set &momentum_stream_list); - void SetStreamSwitchOps(const std::shared_ptr &kernel_graph_ptr); bool StepLoadCtrlInputs(const std::shared_ptr &context, const std::shared_ptr &kernel_graph_ptr); void Profiling(NotNull kernel_graph_ptr); static bool NeedInsertSwitch(); - CNodePtr CreateSteamActiveOp(const std::shared_ptr &kernel_graph_ptr); + CNodePtr CreateStreamActiveOp(const std::shared_ptr &kernel_graph_ptr); private: KernelAdjust() = default; ~KernelAdjust() = default; + + CNodePtr CreateRecvApplyKernel(const std::shared_ptr &graph_ptr, uint32_t event_id); + CNodePtr CreateSendApplyKernel(const std::shared_ptr &graph_ptr, uint32_t event_id); + uint32_t FindFirstStreamSwitchLabel(const std::shared_ptr &kernel_graph_ptr); void CreateSwitchOpParameters(const std::shared_ptr &kernel_graph_ptr, std::map *switch_loop_input); CNodePtr CreateStreamSwitchOp(const std::shared_ptr &kernel_graph_ptr, const std::map &switch_loop_input); - CNodePtr CreateStreamActiveSwitchOp(const std::shared_ptr &kernel_graph_ptr); - CNodePtr CreateStreamActiveOtherOp(const std::shared_ptr &kernel_graph_ptr); CNodePtr CreateStreamAssignAddnOP(const std::shared_ptr &kernel_graph_ptr, const std::map &switch_loop_input); kernel::KernelBuildInfo::KernelBuildInfoBuilder CreateMngKernelBuilder(const std::vector &formats, diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc index 6c245d7548..0de609f441 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc @@ -62,6 +62,7 @@ #include "pre_activate/ascend/format_type/insert_transdata_for_runop.h" #include "pre_activate/ascend/enhancer/getnext_memcpy_elimination.h" #include "pre_activate/ascend/ir_fission/addn_fission.h" +#include "pre_activate/ascend/enhancer/insert_memcpy_async_for_getnext.h" #include "utils/context/ms_context.h" #include "utils/config_manager.h" #include "debug/anf_ir_dump.h" @@ -187,6 +188,12 @@ void AscendBackendIRFusionOptimization(const std::shared_ptrAddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); } + + if (context_ptr->enable_task_sink() && context_ptr->loop_sink_flag() && ConfigManager::GetInstance().iter_num() > 1) { + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + } optimizer->AddPassManager(ir_fusion_pm); (void)optimizer->Optimize(kernel_graph); kernel_graph->SetExecOrderByDefault(); diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/stream_reuse.cc b/mindspore/ccsrc/pre_activate/mem_reuse/stream_reuse.cc index d1409cdedd..77f6f96cec 100644 --- a/mindspore/ccsrc/pre_activate/mem_reuse/stream_reuse.cc +++ b/mindspore/ccsrc/pre_activate/mem_reuse/stream_reuse.cc @@ -20,8 +20,8 @@ namespace mindspore { namespace memreuse { void StreamReuse::SetStreamReuseResource() { #ifdef ENABLE_D - auto logic_physic_map = device::ascend::AscendStreamAssign::GetInstance().GetPhysicMap(); - auto logic_independent_map = device::ascend::AscendStreamAssign::GetInstance().GetIndependentMap(); + auto logic_physic_map = device::ascend::AscendStreamAssign::GetInstance().logic_to_physic_map(); + auto logic_independent_map = device::ascend::AscendStreamAssign::GetInstance().logic_to_independent_map(); MS_LOG(INFO) << "stream mem reuse for Davici"; if (!logic_independent_map.empty() && !logic_physic_map.empty()) { set_logic_physic_map(logic_physic_map); diff --git a/mindspore/ccsrc/session/ascend_session.cc b/mindspore/ccsrc/session/ascend_session.cc index ad6c58bc93..11ae3da6f7 100755 --- a/mindspore/ccsrc/session/ascend_session.cc +++ b/mindspore/ccsrc/session/ascend_session.cc @@ -610,7 +610,7 @@ void AscendSession::CopyOutputOfIf(GraphId false_graph_id) { if (context_ptr->enable_task_sink() && context_ptr->loop_sink_flag() && ConfigManager::GetInstance().iter_num() > 1) { // insert active in true graph, another active will be inserted in kernel adjust - InsertStreamActiveToGraph(true_last_id, kInvalidDistincLabel - 1); + InsertStreamActiveToGraph(true_last_id, kSecondStreamSwitchLabel); } break; } diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index eac901b74d..eac1b86273 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -114,6 +114,9 @@ constexpr auto kFusedMulAddNOpName = "FusedMulAddN"; constexpr auto kFusedMulApplyMomentumOpName = "FusedMulApplyMomentum"; constexpr auto kBiasAddOpName = "BiasAdd"; constexpr auto kConfusionMulGradOpName = "ConfusionMulGrad"; +constexpr auto kStreamSwitchOpName = "StreamSwitch"; +constexpr auto kStreamActiveOpName = "StreamActive"; +constexpr auto kAssignAddOpName = "AssignAdd"; constexpr auto kSendOpName = "Send"; constexpr auto kRecvOpName = "Recv"; constexpr auto kReluV2OpName = "ReluV2"; diff --git a/tests/ut/cpp/stub/tasksink/ascend_stream_assign_stub.cc b/tests/ut/cpp/stub/tasksink/ascend_stream_assign_stub.cc index e0b5ab0d61..9c4fe2539d 100755 --- a/tests/ut/cpp/stub/tasksink/ascend_stream_assign_stub.cc +++ b/tests/ut/cpp/stub/tasksink/ascend_stream_assign_stub.cc @@ -24,9 +24,7 @@ void AscendStreamAssign::AssignStreamNew(const KernelGraphPtr &graph) { return; uint32_t AscendStreamAssign::GetTotalStreamNum() const { return 1; } -std::vector AscendStreamAssign::GetWaitStreams() { return vector(); } - -std::vector AscendStreamAssign::GetHcomStreams() { return vector(); } +void AscendStreamAssign::GetWaitStreams(vector *wait_active_stream_list) { return; } namespace tasksink { bool TaskGenerator::GenTasks(const std::vector &anf_node_list, std::vector *const task_info_list, From 0e4824cd89a9e04cb613994a4203caf37934f108 Mon Sep 17 00:00:00 2001 From: lvliang Date: Mon, 20 Apr 2020 20:38:33 +0800 Subject: [PATCH 036/142] pynative-support-topk-and-print --- mindspore/ccsrc/device/kernel_runtime.cc | 3 ++- mindspore/ccsrc/device/kernel_runtime.h | 2 +- .../ccsrc/pre_activate/ascend/ascend_backend_optimization.cc | 1 + mindspore/ops/_op_impl/tbe/assign_add.py | 2 +- mindspore/ops/operations/debug_ops.py | 4 ++++ 5 files changed, 9 insertions(+), 3 deletions(-) diff --git a/mindspore/ccsrc/device/kernel_runtime.cc b/mindspore/ccsrc/device/kernel_runtime.cc index 7f3d31d8d0..d1a068b584 100644 --- a/mindspore/ccsrc/device/kernel_runtime.cc +++ b/mindspore/ccsrc/device/kernel_runtime.cc @@ -135,10 +135,11 @@ void KernelRuntime::AssignMemory(session::KernelGraph *graph) { } void KernelRuntime::RunOpAssignMemory(const std::vector &input_tensors, - const session::KernelGraph *graph) { + session::KernelGraph *graph) { MS_EXCEPTION_IF_NULL(graph); // assign memory for input nodes RunOpAssignInputMemory(input_tensors, graph); + AssignStaticMemoryValueNode(graph); for (const auto &cnode : graph->execution_order()) { // assign memory for output nodes RunOpAssignOutputMemory(cnode); diff --git a/mindspore/ccsrc/device/kernel_runtime.h b/mindspore/ccsrc/device/kernel_runtime.h index 61b43fd5c0..8f4f769f55 100644 --- a/mindspore/ccsrc/device/kernel_runtime.h +++ b/mindspore/ccsrc/device/kernel_runtime.h @@ -46,7 +46,7 @@ class KernelRuntime { virtual ~KernelRuntime(); virtual bool Init() = 0; virtual void AssignMemory(session::KernelGraph *graph); - void RunOpAssignMemory(const std::vector &input_tensors, const session::KernelGraph *graph); + void RunOpAssignMemory(const std::vector &input_tensors, session::KernelGraph *graph); virtual bool Run(session::KernelGraph *graph); virtual bool DumpData(session::KernelGraph *graph); virtual bool RunTask(const session::KernelGraph *graph); diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc index 6c245d7548..be24a13582 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc @@ -215,6 +215,7 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr(); auto ir_fusion_pm = std::make_shared("ir_fusion_pm"); ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); optimizer->AddPassManager(ir_fusion_pm); (void)optimizer->Optimize(kernel_graph); diff --git a/mindspore/ops/_op_impl/tbe/assign_add.py b/mindspore/ops/_op_impl/tbe/assign_add.py index fbbb9a997f..2b20a7781d 100644 --- a/mindspore/ops/_op_impl/tbe/assign_add.py +++ b/mindspore/ops/_op_impl/tbe/assign_add.py @@ -25,7 +25,7 @@ assign_add_op_info = TBERegOp("AssignAdd") \ .partial_flag(True) \ .input(0, "ref", False, "required", "all") \ .input(1, "value", False, "required", "all") \ - .output(0, "output_ref", False, "required", "all") \ + .output(0, "ref", False, "required", "all") \ .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \ .dtype_format(DataType.I8_5HD, DataType.I8_5HD, DataType.I8_5HD) \ .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \ diff --git a/mindspore/ops/operations/debug_ops.py b/mindspore/ops/operations/debug_ops.py index 1d8fdedc26..97fa883bac 100644 --- a/mindspore/ops/operations/debug_ops.py +++ b/mindspore/ops/operations/debug_ops.py @@ -210,6 +210,10 @@ class Print(PrimitiveWithInfer): def __init__(self): pass + def __call__(self, *args): + for arg in args: + print(arg) + def infer_shape(self, *inputs): return [1] From 930c91018af540d920dd29777fe61bb19b30bc61 Mon Sep 17 00:00:00 2001 From: VectorSL Date: Tue, 14 Apr 2020 20:26:30 +0800 Subject: [PATCH 037/142] update some ops --- .../kernel/gpu/cuda_impl/unary_op_impl.cu | 15 ++++++++++++- .../kernel/gpu/cuda_impl/unary_op_impl.cuh | 3 ++- .../ccsrc/kernel/gpu/gpu_kernel_factory.cc | 6 +++-- .../kernel/gpu/math/binary_op_gpu_kernel.cc | 8 +++++++ .../kernel/gpu/math/binary_op_gpu_kernel.h | 22 ++++++++++++++----- .../kernel/gpu/math/unary_op_gpu_kernel.cc | 4 ++++ .../kernel/gpu/math/unary_op_gpu_kernel.h | 20 +++++++++++------ .../ccsrc/kernel/gpu/nn/flatten_gpu_kernel.h | 6 +---- mindspore/ccsrc/vm/backend.cc | 6 +++++ 9 files changed, 69 insertions(+), 21 deletions(-) mode change 100755 => 100644 mindspore/ccsrc/vm/backend.cc diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/unary_op_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/unary_op_impl.cu index 3cebefec17..6022485251 100755 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/unary_op_impl.cu +++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/unary_op_impl.cu @@ -53,6 +53,13 @@ __global__ void ReciprocalKernel(T *input, T *output, size_t count) { return; } template +__global__ void SquareKernel(T *input, T *output, size_t count) { + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { + output[i] = input[i] * input[i]; + } + return; +} +template void Exponential(T *input, T *output, size_t count, cudaStream_t cuda_stream) { ExponentialKernel<<>>(input, output, count); return; @@ -72,12 +79,18 @@ void Reciprocal(T *input, T *output, size_t count, cudaStream_t cuda_stream) { ReciprocalKernel<<>>(input, output, count); return; } - +template +void Square(T *input, T *output, size_t count, cudaStream_t cuda_stream) { + SquareKernel<<>>(input, output, count); + return; +} template void Exponential(float *input, float *output, size_t count, cudaStream_t cuda_stream); template void Logarithm(float *input, float *output, size_t count, cudaStream_t cuda_stream); template void Negative(float *input, float *output, size_t count, cudaStream_t cuda_stream); template void Reciprocal(float *input, float *output, size_t count, cudaStream_t cuda_stream); +template void Square(float *input, float *output, size_t count, cudaStream_t cuda_stream); template void Exponential(half *input, half *output, size_t count, cudaStream_t cuda_stream); template void Logarithm(half *input, half *output, size_t count, cudaStream_t cuda_stream); template void Negative(half *input, half *output, size_t count, cudaStream_t cuda_stream); template void Reciprocal(half *input, half *output, size_t count, cudaStream_t cuda_stream); +template void Square(half *input, half *output, size_t count, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/unary_op_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/unary_op_impl.cuh index 2e7227eb32..f303c73d29 100755 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/unary_op_impl.cuh +++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/unary_op_impl.cuh @@ -26,5 +26,6 @@ template void Negative(T *input, T *output, size_t count, cudaStream_t cuda_stream); template void Reciprocal(T *input, T *output, size_t count, cudaStream_t cuda_stream); - +template +void Square(T *input, T *output, size_t count, cudaStream_t cuda_stream); #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNARYOPIMPL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/gpu_kernel_factory.cc b/mindspore/ccsrc/kernel/gpu/gpu_kernel_factory.cc index 21f5d084a9..fba2b24512 100644 --- a/mindspore/ccsrc/kernel/gpu/gpu_kernel_factory.cc +++ b/mindspore/ccsrc/kernel/gpu/gpu_kernel_factory.cc @@ -41,8 +41,9 @@ void GpuKernelFactory::CheckIOParam(const std::string &kernel_name, const Kernel size_t attr_index) { if (kernel_info->GetInputNum() != iter_second->at(attr_index).first.GetInputSize()) { if (iter_second->at(attr_index).first.GetAllSame()) { + auto dtype = iter_second->at(attr_index).first.GetInputAttr(0).first; for (size_t attr = 1; attr < kernel_info->GetInputNum(); ++attr) { - (void)iter_second->at(attr_index).first.AddInputAttr(kernel_info->GetInputDeviceType(0)); + (void)iter_second->at(attr_index).first.AddInputAttr(dtype); } } else { MS_LOG(EXCEPTION) << "op[" << kernel_name << "] Input size is mismatching!"; @@ -50,8 +51,9 @@ void GpuKernelFactory::CheckIOParam(const std::string &kernel_name, const Kernel } if (kernel_info->GetOutputNum() != iter_second->at(attr_index).first.GetOutputSize()) { if (iter_second->at(attr_index).first.GetAllSame()) { + auto dtype = iter_second->at(attr_index).first.GetOutputAttr(0).first; for (size_t attr = 1; attr < kernel_info->GetOutputNum(); ++attr) { - (void)iter_second->at(attr_index).first.AddOutputAttr(kernel_info->GetOutputDeviceType(0)); + (void)iter_second->at(attr_index).first.AddOutputAttr(dtype); } } else { MS_LOG(EXCEPTION) << "op[" << kernel_name << "] Output size is mismatching!"; diff --git a/mindspore/ccsrc/kernel/gpu/math/binary_op_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/math/binary_op_gpu_kernel.cc index 56a0905e4e..4fe2acb726 100644 --- a/mindspore/ccsrc/kernel/gpu/math/binary_op_gpu_kernel.cc +++ b/mindspore/ccsrc/kernel/gpu/math/binary_op_gpu_kernel.cc @@ -38,5 +38,13 @@ MS_REG_GPU_KERNEL_ONE( MS_REG_GPU_KERNEL_ONE( Sub, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), BinaryOpGpuKernel, half) +MS_REG_GPU_KERNEL_ONE( + Maximum, + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + BinaryOpGpuKernel, float) +MS_REG_GPU_KERNEL_ONE( + Maximum, + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + BinaryOpGpuKernel, half) } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/math/binary_op_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/math/binary_op_gpu_kernel.h index 522ec2b37e..b929bbee50 100644 --- a/mindspore/ccsrc/kernel/gpu/math/binary_op_gpu_kernel.h +++ b/mindspore/ccsrc/kernel/gpu/math/binary_op_gpu_kernel.h @@ -27,12 +27,16 @@ #include "kernel/gpu/kernel_constants.h" namespace mindspore { namespace kernel { -enum BinaryOpType { BINARY_OP_ADD = 0, BINARY_OP_SUB, BINARY_OP_MUL, BINARY_OP_DIV, BINARY_OP_INVALID_TYPE = 255 }; -const std::map kBinaryOpTypeMap = { - {"Sub", BINARY_OP_SUB}, - {"Mul", BINARY_OP_MUL}, - {"RealDiv", BINARY_OP_DIV}, +enum BinaryOpType { + BINARY_OP_ADD = 0, + BINARY_OP_SUB, + BINARY_OP_MUL, + BINARY_OP_DIV, + BINARY_OP_MAX, + BINARY_OP_INVALID_TYPE = 255 }; +static const std::map kBinaryOpTypeMap = { + {"Sub", BINARY_OP_SUB}, {"Mul", BINARY_OP_MUL}, {"RealDiv", BINARY_OP_DIV}, {"Maximum", BINARY_OP_MAX}}; template class BinaryOpGpuKernel : public GpuKernel { public: @@ -84,6 +88,10 @@ class BinaryOpGpuKernel : public GpuKernel { inputB_addr = workspace_addr; break; } + case BINARY_OP_MAX: { + inputB_addr = input_addr2; + break; + } default: { MS_LOG(EXCEPTION) << "Binary operation " << binary_op_type_ << " is not supported."; } @@ -201,6 +209,10 @@ class BinaryOpGpuKernel : public GpuKernel { tensor_op_ = CUDNN_OP_TENSOR_ADD; break; } + case BINARY_OP_MAX: { + tensor_op_ = CUDNN_OP_TENSOR_MAX; + break; + } default: { MS_LOG(EXCEPTION) << "Binary operation " << binary_op_type_ << " is not supported."; } diff --git a/mindspore/ccsrc/kernel/gpu/math/unary_op_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/math/unary_op_gpu_kernel.cc index d69706663e..bfdbe11422 100644 --- a/mindspore/ccsrc/kernel/gpu/math/unary_op_gpu_kernel.cc +++ b/mindspore/ccsrc/kernel/gpu/math/unary_op_gpu_kernel.cc @@ -38,5 +38,9 @@ MS_REG_GPU_KERNEL_ONE(ZerosLike, KernelAttr().AddInputAttr(kNumberTypeFloat32).A UnaryOpGpuKernel, float) MS_REG_GPU_KERNEL_ONE(ZerosLike, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), UnaryOpGpuKernel, half) +MS_REG_GPU_KERNEL_ONE(Square, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + UnaryOpGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(Square, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + UnaryOpGpuKernel, half) } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/math/unary_op_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/math/unary_op_gpu_kernel.h index af78ea4e73..5b2414f8f1 100644 --- a/mindspore/ccsrc/kernel/gpu/math/unary_op_gpu_kernel.h +++ b/mindspore/ccsrc/kernel/gpu/math/unary_op_gpu_kernel.h @@ -33,13 +33,15 @@ enum UnaryOptype { UNARY_OP_NEG, UNARY_OP_RECIPROCAL, UNARY_OP_ZEROSLIKE, + UNARY_OP_SQUARE, UNARY_OP_INVALID_TYPE = 255 }; -const std::map kUnaryOpTypeMap = {{"Exp", UNARY_OP_EXP}, - {"Log", UNARY_OP_LOG}, - {"Neg", UNARY_OP_NEG}, - {"Reciprocal", UNARY_OP_RECIPROCAL}, - {"ZerosLike", UNARY_OP_ZEROSLIKE}}; +static const std::map kUnaryOpTypeMap = {{"Exp", UNARY_OP_EXP}, + {"Log", UNARY_OP_LOG}, + {"Neg", UNARY_OP_NEG}, + {"Reciprocal", UNARY_OP_RECIPROCAL}, + {"ZerosLike", UNARY_OP_ZEROSLIKE}, + {"Square", UNARY_OP_SQUARE}}; template class UnaryOpGpuKernel : public GpuKernel { public: @@ -74,6 +76,10 @@ class UnaryOpGpuKernel : public GpuKernel { Reciprocal(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast(stream_ptr)); break; } + case UNARY_OP_SQUARE: { + Square(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast(stream_ptr)); + break; + } case UNARY_OP_ZEROSLIKE: { return true; } @@ -93,12 +99,12 @@ class UnaryOpGpuKernel : public GpuKernel { } size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); if (input_num != 1) { - MS_LOG(ERROR) << "Input number is " << input_num << ", but negative op needs 1 inputs."; + MS_LOG(ERROR) << "Input number is " << input_num << ", but unary op needs 1 inputs."; return false; } size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); if (output_num != 1) { - MS_LOG(ERROR) << "Output number is " << output_num << ", but negative op needs 1 output."; + MS_LOG(ERROR) << "Output number is " << output_num << ", but unary op needs 1 output."; return false; } auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); diff --git a/mindspore/ccsrc/kernel/gpu/nn/flatten_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/flatten_gpu_kernel.h index 37d0aadfbc..975dbd0082 100644 --- a/mindspore/ccsrc/kernel/gpu/nn/flatten_gpu_kernel.h +++ b/mindspore/ccsrc/kernel/gpu/nn/flatten_gpu_kernel.h @@ -48,14 +48,10 @@ class FlattenGpuFwdKernel : public GpuKernel { } bool Init(const CNodePtr &kernel_node) override { auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + input_size_ = sizeof(T); for (size_t i = 0; i < shape.size(); ++i) { - if (input_size_ == 0) { - input_size_ = 1; - } input_size_ *= shape[i]; } - input_size_ = input_size_ * sizeof(T); - InitSizeLists(); return true; } diff --git a/mindspore/ccsrc/vm/backend.cc b/mindspore/ccsrc/vm/backend.cc old mode 100755 new mode 100644 index e69d25d2dc..d754667cce --- a/mindspore/ccsrc/vm/backend.cc +++ b/mindspore/ccsrc/vm/backend.cc @@ -189,6 +189,12 @@ VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args) { } else if (utils::isa(arg)) { auto value = utils::cast(arg).object_; inputs.push_back(py::cast(value)); + } else if (utils::isa(arg)) { + auto args_new = utils::cast(arg); + (void)std::transform(args_new.begin(), args_new.end(), std::back_inserter(inputs), + [](const BaseRef &v) { return utils::cast(v); }); + } else { + MS_LOG(WARNING) << "Invalid input type."; } } From eb053a6233c4fe3d363ad20aa70f3e9711168063 Mon Sep 17 00:00:00 2001 From: xulei2020 <“xulei83@huawei.com”> Date: Sat, 18 Apr 2020 15:19:06 +0800 Subject: [PATCH 038/142] add filterOp code --- .../dataset/engine/datasetops/filter_op.cc | 48 ++++++------------- .../dataset/engine/datasetops/filter_op.h | 8 ++-- 2 files changed, 17 insertions(+), 39 deletions(-) diff --git a/mindspore/ccsrc/dataset/engine/datasetops/filter_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/filter_op.cc index 22b1155fc9..e6662dea0f 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/filter_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/filter_op.cc @@ -116,20 +116,14 @@ Status FilterOp::WorkerEntry(int32_t worker_id) { continue; } - // Thread local variables to avoid lock. When in_columns_ is empty and workers will write - // the name of the first column into input_columns (thread local) instead of in_columns_ (thread global). - std::vector input_columns = in_columns_; - // Indices of the columns to process. - std::vector to_process_indices; - - RETURN_IF_NOT_OK(WorkerEntryInit(in_buffer.get(), &to_process_indices, &input_columns)); + RETURN_IF_NOT_OK(CheckColumns(in_buffer.get(), &in_columns_)); // if the databuffer was all filtered, it is marked as kFilterEmpty. // if the databuffer was partially filtered, it is marked as kFilterPartial. // if the databuffer was not filtered, it is marked as kFilterFull. int32_t num_rows = in_buffer->NumRows(); std::unique_ptr new_tensor_table; - RETURN_IF_NOT_OK(WorkerCompute(in_buffer.get(), to_process_indices, &new_tensor_table)); + RETURN_IF_NOT_OK(WorkerCompute(in_buffer.get(), &new_tensor_table)); if (new_tensor_table->empty()) { RETURN_IF_NOT_OK( @@ -147,17 +141,22 @@ Status FilterOp::WorkerEntry(int32_t worker_id) { return Status::OK(); } -Status FilterOp::WorkerCompute(DataBuffer *in_buffer, const std::vector &to_proess_indices, - std::unique_ptr *out) { +Status FilterOp::WorkerCompute(DataBuffer *in_buffer, std::unique_ptr *out) { *out = std::make_unique(); int32_t num_rows = in_buffer->NumRows(); for (int32_t i = 0; i < num_rows; i++) { TensorRow to_process; TensorRow cur_row; RETURN_IF_NOT_OK(in_buffer->PopRow(&cur_row)); - - (void)std::transform(to_proess_indices.begin(), to_proess_indices.end(), std::back_inserter(to_process), - [&cur_row](const size_t &it) -> std::shared_ptr { return cur_row[it]; }); + if (in_columns_.empty() == true) { + MS_LOG(INFO) << "Input columns in filter operator is empty, will apply to the all column in the current table."; + to_process = cur_row; + } else { + std::unordered_map col_map = in_buffer->column_name_map(); + (void)std::transform( + in_columns_.begin(), in_columns_.end(), std::back_inserter(to_process), + [&cur_row, &col_map](const auto &it) -> std::shared_ptr { return cur_row[col_map[it]]; }); + } bool predicate = true; RETURN_IF_NOT_OK(InvokePredicateFunc(to_process, &predicate)); if (predicate) { @@ -202,9 +201,8 @@ Status FilterOp::Collector() { return Status::OK(); } -// initialize some internal data structure used by WorkerEntry(). -Status FilterOp::WorkerEntryInit(const DataBuffer *in_buf, std::vector *to_process_indices, - std::vector *input_columns) { +// Private function for checking the column legality. +Status FilterOp::CheckColumns(const DataBuffer *in_buf, std::vector *input_columns) { int32_t num_rows = in_buf->NumRows(); int32_t num_cols = in_buf->NumCols(); if (num_rows == 0 || num_cols == 0) { @@ -213,24 +211,6 @@ Status FilterOp::WorkerEntryInit(const DataBuffer *in_buf, std::vector * std::unordered_map col_name_id_map = in_buf->column_name_map(); // Check if there is invalid column name in the inColumns. RETURN_IF_NOT_OK(ValidateInColumns(col_name_id_map, input_columns)); - - if (input_columns->empty()) { - MS_LOG(INFO) << "Input columns in filter operator is empty, will apply to the all column in the current table."; - // sort the input colunms by column index. - std::vector> sort_vec(col_name_id_map.begin(), col_name_id_map.end()); - std::sort(sort_vec.begin(), sort_vec.end(), - [](const std::pair &a, const std::pair &b) { - return a.second < b.second; - }); - - (void)std::transform(sort_vec.begin(), sort_vec.end(), std::back_inserter(*input_columns), - [](const auto &it) -> std::string { return it.first; }); - } - - // initialize to_process_indices. - (void)std::transform(input_columns->begin(), input_columns->end(), std::back_inserter(*to_process_indices), - [&col_name_id_map](const auto &it) -> size_t { return col_name_id_map[it]; }); - return Status::OK(); } diff --git a/mindspore/ccsrc/dataset/engine/datasetops/filter_op.h b/mindspore/ccsrc/dataset/engine/datasetops/filter_op.h index 50697d398f..b182bf8ce6 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/filter_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/filter_op.h @@ -141,8 +141,7 @@ class FilterOp : public ParallelOp { // @param to_proess_indices Indices of columns to be processed. // @param out data buffer that are filtered by predicate. // @return Status The error code return. - Status WorkerCompute(DataBuffer *in_buffer, const std::vector &to_proess_indices, - std::unique_ptr *out); + Status WorkerCompute(DataBuffer *in_buffer, std::unique_ptr *out); // Collector databuffer. // @return Status The error code return. @@ -166,13 +165,12 @@ class FilterOp : public ParallelOp { Status ValidateInColumns(const std::unordered_map &col_name_id_map, std::vector *input_columns); - // Private function that initialize some internal data structure used by WorkerEntry(). + // Private function for checking the column legality // @param in_buf A raw pointer to the DataBuffer. A raw pointer is fine because this function does not manage memory // and is not shared with other threads. // @param[out] to_process_indices Indices of columns that will feed to predicate. // @param input_columns The vector of input column names used in the current thread. - Status WorkerEntryInit(const DataBuffer *in_buf, std::vector *to_process_indices, - std::vector *input_columns); + Status CheckColumns(const DataBuffer *in_buf, std::vector *input_columns); }; } // namespace dataset From 65a237633d11e580c00b073493c0a0a10abf4233 Mon Sep 17 00:00:00 2001 From: chenzomi Date: Mon, 20 Apr 2020 19:30:35 +0800 Subject: [PATCH 039/142] fix bug in cross entropy error --- .../gpu/cuda_impl/cross_entropy_cuda_impl.cu | 47 --------------- .../gpu/cuda_impl/cross_entropy_cuda_impl.cuh | 26 -------- .../gpu/cuda_impl/cross_entropy_impl.cu | 59 ++++--------------- .../gpu/cuda_impl/cross_entropy_impl.cuh | 9 +-- ...max_cross_entropy_with_logits_gpu_kernel.h | 10 ++-- 5 files changed, 17 insertions(+), 134 deletions(-) delete mode 100644 mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_cuda_impl.cu delete mode 100644 mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_cuda_impl.cuh diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_cuda_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_cuda_impl.cu deleted file mode 100644 index a3d2e3558c..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_cuda_impl.cu +++ /dev/null @@ -1,47 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include "cross_entropy_cuda_impl.cuh" -#include "include/cuda_runtime.h" - -__global__ void CalCrossEntropyWithGradKernel(const float *softmax_logits, const float *log_softmax_logits, - const float *labels, const int batch_size, const int num_classes, - float *loss, float *dx) { - extern __shared__ float loss_shared[]; - const float mean_scale = 1.0f / static_cast(batch_size); - - loss_shared[threadIdx.x] = 0; - for (int i = threadIdx.x * num_classes; i < (threadIdx.x + 1) * num_classes; ++i) { - loss_shared[threadIdx.x] -= log_softmax_logits[i] * labels[i]; - dx[i] = (softmax_logits[i] - labels[i]) * mean_scale; - } - __syncthreads(); - if (threadIdx.x == 0) { - *loss = 0; - for (int i = 0; i < batch_size; i++) { - *loss += loss_shared[i]; - } - *loss *= mean_scale; - } -} - -void CalCrossEntropyWithGrad(const float *softmax_logits, const float *log_softmax_logits, const float *labels, - const int batch_size, const int num_classes, float *loss, float *dx, - cudaStream_t cuda_stream) { - CalCrossEntropyWithGradKernel<<<1, batch_size, batch_size * sizeof(float), cuda_stream>>>( - softmax_logits, log_softmax_logits, labels, batch_size, num_classes, loss, dx); -} diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_cuda_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_cuda_impl.cuh deleted file mode 100644 index 25b1624a46..0000000000 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_cuda_impl.cuh +++ /dev/null @@ -1,26 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CROSSENTROPYCUDAIMPL_H_ -#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CROSSENTROPYCUDAIMPL_H_ - -#include "device/gpu/cuda_common.h" - -void CalCrossEntropyWithGrad(const float *softmax_logits, const float *log_softmax_logits, const float *labels, - const int batch_size, const int num_classes, float *loss, float *dx, - cudaStream_t cuda_stream); - -#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CROSSENTROPYCUDAIMPL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_impl.cu index 4d0503ba97..11c16581d6 100644 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_impl.cu +++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_impl.cu @@ -52,38 +52,12 @@ __global__ void CrossEntropyGradWithSparseKernel(const T *logits, const S *label } template -__global__ void CrossEntropyWithoutSparseKernel(const T *logits, const S *labels, const size_t batch_size, - const size_t class_num, T *losses) { - T epsilon = 1e-6; - for (size_t i = 0; i < batch_size; ++i) { - T logit = 0.0; - for (size_t j = 0; j < class_num; j++) { - if (fabs(labels[i * class_num + j] - 1.0) <= 1e-8) { - logit = logits[i * class_num + j]; - break; - } - } - if (logit <= 0) { - logit += epsilon; - } - losses[i] = -logf(logit); +__global__ void CrossEntropyKernel(const T *logits, const S *labels, const size_t class_num, T *losses, T *dlogits) { + losses[threadIdx.x] = 0; + for (int i = threadIdx.x * class_num; i < (threadIdx.x + 1) * class_num; ++i) { + losses[threadIdx.x] -= logf(logits[i]) * labels[i]; + dlogits[i] = logits[i] - labels[i]; } - return; -} - -template -__global__ void CrossEntropyGradWithoutSparseKernel(const T *logits, const S *labels, const size_t batch_size, - const size_t class_num, T *grad) { - for (size_t i = 0; i < batch_size; i++) { - for (size_t j = blockIdx.x * blockDim.x + threadIdx.x; j < class_num; j += blockDim.x * gridDim.x) { - if (fabs(labels[i * class_num + j] - 1.0) <= 1e-8) { - grad[i * class_num + j] = (logits[i * class_num + j] - 1) / batch_size; - } else { - grad[i * class_num + j] = logits[i * class_num + j] / batch_size; - } - } - } - return; } template @@ -102,18 +76,9 @@ void CrossEntropyGradWithSparse(const T *logits, const S *labels, const size_t b } template -void CrossEntropyWithoutSparse(const T *logits, const S *labels, const size_t batch_size, const size_t class_num, - T *losses, cudaStream_t cuda_stream) { - CrossEntropyWithoutSparseKernel<<<1, 1, 0, cuda_stream>>>(logits, labels, batch_size, class_num, losses); - return; -} - -template -void CrossEntropyGradWithoutSparse(const T *logits, const S *labels, const size_t batch_size, const size_t class_num, - T *grad, cudaStream_t cuda_stream) { - CrossEntropyGradWithoutSparseKernel<<>>( - logits, labels, batch_size, class_num, grad); - return; +void CrossEntropy(const T *logits, const S *labels, const size_t batch_size, const size_t class_num, T *losses, + T *dlogits, cudaStream_t cuda_stream) { + CrossEntropyKernel<<<1, batch_size, 0, cuda_stream>>>(logits, labels, class_num, losses, dlogits); } template void CrossEntropyWithSparse(const float *logits, const int *labels, const size_t batch_size, @@ -126,8 +91,6 @@ template void CrossEntropyGradWithSparse(const float *logits, const template void CrossEntropyGradWithSparse(const float *logits, const int64_t *labels, const size_t batch_size, const size_t class_num, float *grad, cudaStream_t cuda_stream); -template void CrossEntropyWithoutSparse(const float *logits, const float *labels, const size_t batch_size, - const size_t class_num, float *losses, cudaStream_t cuda_stream); -template void CrossEntropyGradWithoutSparse(const float *logits, const float *labels, - const size_t batch_size, const size_t class_num, float *grad, - cudaStream_t cuda_stream); +template void CrossEntropy(const float *logits, const float *labels, const size_t batch_size, + const size_t class_num, float *losses, float *dlogits, + cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_impl.cuh index 00ec13553d..54ae072892 100644 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_impl.cuh +++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/cross_entropy_impl.cuh @@ -28,11 +28,6 @@ void CrossEntropyGradWithSparse(const T *logits, const S *labels, const size_t b T *grad, cudaStream_t cuda_stream); template -void CrossEntropyWithoutSparse(const T *logits, const S *labels, const size_t batch_size, const size_t class_num, - T *losses, cudaStream_t cuda_stream); - -template -void CrossEntropyGradWithoutSparse(const T *logits, const S *labels, const size_t batch_size, const size_t class_num, - T *grad, cudaStream_t cuda_stream); - +void CrossEntropy(const T *logits, const S *labels, const size_t batch_size, const size_t class_num, T *losses, + T *dlogits, cudaStream_t cuda_stream); #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CROSSENTROPY_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/softmax_cross_entropy_with_logits_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/softmax_cross_entropy_with_logits_gpu_kernel.h index 3822a326fb..4d50d4753d 100644 --- a/mindspore/ccsrc/kernel/gpu/nn/softmax_cross_entropy_with_logits_gpu_kernel.h +++ b/mindspore/ccsrc/kernel/gpu/nn/softmax_cross_entropy_with_logits_gpu_kernel.h @@ -58,8 +58,8 @@ class SoftmaxCrossEntropyWithLogitsGpuKernel : public GpuKernel { } T *logits_addr = GetDeviceAddress(inputs, 0); S *labels_addr = GetDeviceAddress(inputs, 1); - T *output1_addr = GetDeviceAddress(outputs, 0); - T *output2_addr = GetDeviceAddress(outputs, 1); + T *loss_addr = GetDeviceAddress(outputs, 0); + T *dlogits_addr = GetDeviceAddress(outputs, 1); T *softmax_output_logits = GetDeviceAddress(workspace, 0); const float alpha = 1; @@ -69,10 +69,8 @@ class SoftmaxCrossEntropyWithLogitsGpuKernel : public GpuKernel { softmax_output_descriptor_, softmax_output_logits), "cudnnSoftmaxForward failed."); - CrossEntropyWithoutSparse(softmax_output_logits, labels_addr, batch_size_, channel_size_, output1_addr, - reinterpret_cast(stream_ptr)); - CrossEntropyGradWithoutSparse(softmax_output_logits, labels_addr, batch_size_, channel_size_, output2_addr, - reinterpret_cast(stream_ptr)); + CrossEntropy(softmax_output_logits, labels_addr, batch_size_, channel_size_, loss_addr, dlogits_addr, + reinterpret_cast(stream_ptr)); return true; } bool Init(const CNodePtr &kernel_node) override { From 94c99998ae3556e0a54dc5637643c568e0fd9263 Mon Sep 17 00:00:00 2001 From: zhaojichen Date: Tue, 21 Apr 2020 03:17:24 -0400 Subject: [PATCH 040/142] add AvgPooling layer --- mindspore/nn/layer/__init__.py | 4 +- mindspore/nn/layer/pooling.py | 84 ++++++++++++++++++++++++++++++ tests/ut/python/nn/test_pooling.py | 16 ++++++ 3 files changed, 102 insertions(+), 2 deletions(-) diff --git a/mindspore/nn/layer/__init__.py b/mindspore/nn/layer/__init__.py index 098489a91d..b9f79b6cf7 100644 --- a/mindspore/nn/layer/__init__.py +++ b/mindspore/nn/layer/__init__.py @@ -24,7 +24,7 @@ from .conv import Conv2d, Conv2dTranspose from .lstm import LSTM from .basic import Dropout, Flatten, Dense, ClipByNorm, Norm, OneHot, Pad, Unfold from .embedding import Embedding -from .pooling import AvgPool2d, MaxPool2d +from .pooling import AvgPool2d, MaxPool2d, AvgPool1d from .image import ImageGradients, SSIM, PSNR __all__ = ['Softmax', 'LogSoftmax', 'ReLU', 'ReLU6', 'Tanh', 'GELU', 'Sigmoid', @@ -35,6 +35,6 @@ __all__ = ['Softmax', 'LogSoftmax', 'ReLU', 'ReLU6', 'Tanh', 'GELU', 'Sigmoid', 'LSTM', 'Dropout', 'Flatten', 'Dense', 'ClipByNorm', 'Norm', 'OneHot', 'Embedding', - 'AvgPool2d', 'MaxPool2d', 'Pad', 'Unfold', + 'AvgPool2d', 'MaxPool2d', 'AvgPool1d', 'Pad', 'Unfold', 'ImageGradients', 'SSIM', 'PSNR', ] diff --git a/mindspore/nn/layer/pooling.py b/mindspore/nn/layer/pooling.py index 53d97807cf..17700ff7b4 100644 --- a/mindspore/nn/layer/pooling.py +++ b/mindspore/nn/layer/pooling.py @@ -208,3 +208,87 @@ class AvgPool2d(_PoolNd): def construct(self, x): return self.avg_pool(x) + + +class AvgPool1d(_PoolNd): + r""" + Average pooling for temporal data. + + Applies a 2D average pooling over an input Tensor which can be regarded as a composition of 2D input planes. + + Typically the input is of shape :math:`(N_{in}, C_{in}, H_{in}, W_{in})`, AvgPool2d outputs + regional average in the :math:`(H_{in}, W_{in})`-dimension. Given kernel size + :math:`ks = (h_{ker}, w_{ker})` and stride :math:`s = (s_0, s_1)`, the operation is as follows. + + .. math:: + \text{output}(N_i, C_j, h, w) = \frac{1}{h_{ker} * w_{ker}} \sum_{m=0}^{h_{ker}-1} \sum_{n=0}^{w_{ker}-1} + \text{input}(N_i, C_j, s_0 \times h + m, s_1 \times w + n) + + Note: + pad_mode for training only supports "same" and "valid". + + Args: + kernel_size (Union[int, tuple[int]]): The size of kernel used to take the average value, + is an int number that represents height and width are both kernel_size, + or a tuple of two int numbers that represent height and width respectively. + Default: 1. + stride (Union[int, tuple[int]]): The distance of kernel moving, an int number that represents + the height and width of movement are both strides, or a tuple of two int numbers that + represent height and width of movement respectively. Default: 1. + pad_mode (str): The optional values for pad mode, is "same" or "valid", not case sensitive. + Default: "valid". + + - same: Adopts the way of completion. Output height and width will be the same as + the input. Total number of padding will be calculated for horizontal and vertical + direction and evenly distributed to top and bottom, left and right if possible. + Otherwise, the last extra padding will be done from the bottom and the right side. + + - valid: Adopts the way of discarding. The possibly largest height and width of output + will be return without padding. Extra pixels will be discarded. + + + Inputs: + - **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. + + Outputs: + Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`. + + Examples: + >>> pool = nn.AvgPool2d(kernel_size=3, strides=1) + >>> x = Tensor(np.random.randint(0, 10, [1, 2, 4, 4]), mindspore.float32) + [[[[5. 5. 9. 9.] + [8. 4. 3. 0.] + [2. 7. 1. 2.] + [1. 8. 3. 3.]] + [[6. 8. 2. 4.] + [3. 0. 2. 1.] + [0. 8. 9. 7.] + [2. 1. 4. 9.]]]] + >>> output = pool(x) + >>> output.shape() + (1, 2, 2, 2) + >>> output + [[[[4.888889 4.4444447] + [4.111111 3.4444444]] + [[4.2222223 4.5555553] + [3.2222223 4.5555553]]]] + """ + + def __init__(self, + kernel_size=1, + stride=1, + pad_mode="valid"): + super(AvgPool1d, self).__init__(kernel_size, stride, pad_mode) + if not isinstance(kernel_size, int): + raise ValueError("kernel_size should be 1 int number but got {}". + format(kernel_size)) + if not isinstance(stride, int): + raise ValueError("stride should be 1 int number but got {}".format(stride)) + self.kernel_size = (1, kernel_size) + self.stride = (1, stride) + self.avg_pool = P.AvgPool(ksize=self.kernel_size, + strides=self.stride, + padding=self.pad_mode) + + def construct(self, x): + return self.avg_pool(x) diff --git a/tests/ut/python/nn/test_pooling.py b/tests/ut/python/nn/test_pooling.py index 10bb7632b2..428e050ea2 100644 --- a/tests/ut/python/nn/test_pooling.py +++ b/tests/ut/python/nn/test_pooling.py @@ -56,3 +56,19 @@ def test_compile_max(): net = MaxNet(3, stride=1, padding=0) x = Tensor(np.random.randint(0, 255, [1, 3, 6, 6]).astype(np.float32)) _executor.compile(net, x) + + +class Avg1dNet(nn.Cell): + def __init__(self, + kernel_size, + stride=None): + super(Avg1dNet, self).__init__() + self.avg1d = nn.AvgPool1d(kernel_size, stride) + + def construct(self, x): + return self.avg1d(x) + +def test_avg1d(): + net = Avg1dNet(3, 1) + input = Tensor(np.random.randint(0, 255, [1, 3, 6, 6]).astype(np.float32)) + _executor.compile(net, input) \ No newline at end of file From c0229fa951b1ac537fd2e83eb3d539f3911198e6 Mon Sep 17 00:00:00 2001 From: chenzomi Date: Tue, 21 Apr 2020 15:27:29 +0800 Subject: [PATCH 041/142] change hswish and hsigmoid accroding to primitive --- mindspore/_akg/gpu/__init__.py | 8 ++++---- mindspore/_akg/gpu/hsigmoid.py | 8 ++++---- mindspore/_akg/gpu/hsigmoid_grad.py | 8 ++++---- mindspore/_akg/gpu/hswish.py | 8 ++++---- mindspore/_akg/gpu/hswish_grad.py | 10 +++++----- .../predict/converter/lite_model/op_attr_packer.cc | 4 ++-- mindspore/nn/layer/activation.py | 2 +- mindspore/ops/_op_impl/akg/gpu/__init__.py | 4 ++++ mindspore/ops/operations/nn_ops.py | 2 +- 9 files changed, 29 insertions(+), 25 deletions(-) diff --git a/mindspore/_akg/gpu/__init__.py b/mindspore/_akg/gpu/__init__.py index 2ac6d1adb1..08961d3989 100644 --- a/mindspore/_akg/gpu/__init__.py +++ b/mindspore/_akg/gpu/__init__.py @@ -26,7 +26,7 @@ from .squeeze_grad import SqueezeGrad, gpu_schedule_SqueezeGrad from .mean import SimpleMean, gpu_schedule_SimpleMean from .mean_grad import SimpleMeanGrad, gpu_schedule_SimpleMeanGrad from .mul import Mul, gpu_schedule_Mul -from .hsigmoid import Hsigmoid, gpu_schedule_Hsigmoid -from .hsigmoid_grad import HsigmoidGrad, gpu_schedule_HsigmoidGrad -from .hswish import Hswish, gpu_schedule_Hswish -from .hswish_grad import HswishGrad, gpu_schedule_HswishGrad +from .hsigmoid import HSigmoid, gpu_schedule_HSigmoid +from .hsigmoid_grad import HSigmoidGrad, gpu_schedule_HSigmoidGrad +from .hswish import HSwish, gpu_schedule_HSwish +from .hswish_grad import HSwishGrad, gpu_schedule_HSwishGrad diff --git a/mindspore/_akg/gpu/hsigmoid.py b/mindspore/_akg/gpu/hsigmoid.py index b9d5ea74c9..b313c2fd5a 100644 --- a/mindspore/_akg/gpu/hsigmoid.py +++ b/mindspore/_akg/gpu/hsigmoid.py @@ -33,9 +33,9 @@ def topi_nn_hsigmoid(x): (x(*i) + 3) / 6))) -def Hsigmoid(x): +def HSigmoid(x): """ - Hsigmoid + HSigmoid Args: x: @@ -45,9 +45,9 @@ def Hsigmoid(x): return topi_nn_hsigmoid(x) -def gpu_schedule_Hsigmoid(outs): +def gpu_schedule_HSigmoid(outs): """ - gpu schedule Hsigmoid + gpu schedule HSigmoid Args: outs: diff --git a/mindspore/_akg/gpu/hsigmoid_grad.py b/mindspore/_akg/gpu/hsigmoid_grad.py index d3e7ac6345..bdde4ed3ca 100644 --- a/mindspore/_akg/gpu/hsigmoid_grad.py +++ b/mindspore/_akg/gpu/hsigmoid_grad.py @@ -12,14 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Hsigmoid grad""" +"""HSigmoid grad""" import _akg.topi as topi import _akg.tvm as tvm -def HsigmoidGrad(y_grad, x): +def HSigmoidGrad(y_grad, x): """ - HsigmoidGrad + HSigmoidGrad Args: y_grad: x: @@ -32,7 +32,7 @@ def HsigmoidGrad(y_grad, x): y_grad(*i) / 6))) -def gpu_schedule_HsigmoidGrad(outs): +def gpu_schedule_HSigmoidGrad(outs): """ gpu schedule ReLU6Grad Args: diff --git a/mindspore/_akg/gpu/hswish.py b/mindspore/_akg/gpu/hswish.py index 904c38c2a2..44fcf10918 100644 --- a/mindspore/_akg/gpu/hswish.py +++ b/mindspore/_akg/gpu/hswish.py @@ -33,9 +33,9 @@ def topi_nn_hswish(x): x(*i) * (x(*i) + 3) / 6))) -def Hswish(x): +def HSwish(x): """ - Hswish + HSwish Args: x: @@ -45,9 +45,9 @@ def Hswish(x): return topi_nn_hswish(x) -def gpu_schedule_Hswish(outs): +def gpu_schedule_HSwish(outs): """ - gpu schedule Hswish + gpu schedule HSwish Args: outs: diff --git a/mindspore/_akg/gpu/hswish_grad.py b/mindspore/_akg/gpu/hswish_grad.py index 5b38f07c84..cadbf0f663 100644 --- a/mindspore/_akg/gpu/hswish_grad.py +++ b/mindspore/_akg/gpu/hswish_grad.py @@ -12,14 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""HswishGrad""" +"""HSwishGrad""" import _akg.topi as topi import _akg.tvm as tvm -def HswishGrad(y_grad, x): +def HSwishGrad(y_grad, x): """ - HswishGrad + HSwishGrad Args: y_grad: x: @@ -34,9 +34,9 @@ def HswishGrad(y_grad, x): return res6 -def gpu_schedule_HswishGrad(outs): +def gpu_schedule_HSwishGrad(outs): """ - gpu schedule HswishGrad + gpu schedule HSwishGrad Args: outs: diff --git a/mindspore/ccsrc/predict/converter/lite_model/op_attr_packer.cc b/mindspore/ccsrc/predict/converter/lite_model/op_attr_packer.cc index f186758de5..e6fec3d540 100644 --- a/mindspore/ccsrc/predict/converter/lite_model/op_attr_packer.cc +++ b/mindspore/ccsrc/predict/converter/lite_model/op_attr_packer.cc @@ -48,8 +48,8 @@ OpAttrFactory::OpAttrFactory() { {"Softsign", ActivationPacker}, {"Softplus", ActivationPacker}, {"Tanh", ActivationPacker}, - {"Hswish", ActivationPacker}, - {"Hsigmoid", ActivationPacker}, + {"HSwish", ActivationPacker}, + {"HSigmoid", ActivationPacker}, {"MaxPool", PoolingPacker}, {"MaxPool2D", PoolingPacker}, {"MeanPool", PoolingPacker}, diff --git a/mindspore/nn/layer/activation.py b/mindspore/nn/layer/activation.py index 6485e27228..8845247a65 100644 --- a/mindspore/nn/layer/activation.py +++ b/mindspore/nn/layer/activation.py @@ -346,7 +346,7 @@ class HSwish(Cell): where :math:`x_{i}` is the :math:`i`-th slice along the given dim of the input Tensor. Inputs: - - **input_data** (Tensor) - The input of Hswish. + - **input_data** (Tensor) - The input of HSwish. Outputs: Tensor, with the same type and shape as the `input_data`. diff --git a/mindspore/ops/_op_impl/akg/gpu/__init__.py b/mindspore/ops/_op_impl/akg/gpu/__init__.py index 2135794b5f..8ffc796ae3 100644 --- a/mindspore/ops/_op_impl/akg/gpu/__init__.py +++ b/mindspore/ops/_op_impl/akg/gpu/__init__.py @@ -23,3 +23,7 @@ from .relu6_grad import _relu6_grad_akg from .squeeze import _squeeze_akg from .squeeze_grad import _squeeze_grad_akg from .tile import _tile_akg +from .hsigmoid import _hsigmoid_akg +from .hsigmoid_grad import _hsigmoid_grad_akg +from .hswish import _hswish_akg +from .hswish_grad import _hswish_grad_akg diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 1fb65e3b76..bc88316ee5 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -258,7 +258,7 @@ class HSwish(PrimitiveWithInfer): where :math:`x_{i}` is the :math:`i`-th slice along the given dim of the input Tensor. Inputs: - - **input_data** (Tensor) - The input of Hswish. + - **input_data** (Tensor) - The input of HSwish. Outputs: Tensor, with the same type and shape as the `input_data`. From e170a0355ce78f671183a11c4ce2b538d5e42864 Mon Sep 17 00:00:00 2001 From: zhaojichen Date: Tue, 21 Apr 2020 03:50:31 -0400 Subject: [PATCH 042/142] add AvgPooling layer --- mindspore/nn/layer/pooling.py | 37 ++++++++++------------------------- 1 file changed, 10 insertions(+), 27 deletions(-) diff --git a/mindspore/nn/layer/pooling.py b/mindspore/nn/layer/pooling.py index 17700ff7b4..299891232e 100644 --- a/mindspore/nn/layer/pooling.py +++ b/mindspore/nn/layer/pooling.py @@ -214,27 +214,23 @@ class AvgPool1d(_PoolNd): r""" Average pooling for temporal data. - Applies a 2D average pooling over an input Tensor which can be regarded as a composition of 2D input planes. + Applies a 1D average pooling over an input Tensor which can be regarded as a composition of 1D input planes. - Typically the input is of shape :math:`(N_{in}, C_{in}, H_{in}, W_{in})`, AvgPool2d outputs - regional average in the :math:`(H_{in}, W_{in})`-dimension. Given kernel size - :math:`ks = (h_{ker}, w_{ker})` and stride :math:`s = (s_0, s_1)`, the operation is as follows. + Typically the input is of shape :math:`(N_{in}, C_{in}, H_{in}, W_{in})`, AvgPool1d outputs + regional average in the :math:`(W_{in})`-dimension. Given kernel size + :math:`ks = (w_{ker})` and stride :math:`s = (s_0)`, the operation is as follows. .. math:: - \text{output}(N_i, C_j, h, w) = \frac{1}{h_{ker} * w_{ker}} \sum_{m=0}^{h_{ker}-1} \sum_{n=0}^{w_{ker}-1} - \text{input}(N_i, C_j, s_0 \times h + m, s_1 \times w + n) + \text{output}(N_i, C_j, h_k, w) = \frac{1}{w_{ker}} \sum_{n=0}^{w_{ker}-1} + \text{input}(N_i, C_j, h_k, s_0 \times w + n) Note: pad_mode for training only supports "same" and "valid". Args: - kernel_size (Union[int, tuple[int]]): The size of kernel used to take the average value, - is an int number that represents height and width are both kernel_size, - or a tuple of two int numbers that represent height and width respectively. - Default: 1. - stride (Union[int, tuple[int]]): The distance of kernel moving, an int number that represents - the height and width of movement are both strides, or a tuple of two int numbers that - represent height and width of movement respectively. Default: 1. + kernel_size (int): The size of kernel window used to take the average value, Default: 1. + stride (int): The distance of kernel moving, an int number that represents + the width of movement is strides, Default: 1. pad_mode (str): The optional values for pad mode, is "same" or "valid", not case sensitive. Default: "valid". @@ -254,24 +250,11 @@ class AvgPool1d(_PoolNd): Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`. Examples: - >>> pool = nn.AvgPool2d(kernel_size=3, strides=1) + >>> pool = nn.AvgPool1d(kernel_size=3, strides=1) >>> x = Tensor(np.random.randint(0, 10, [1, 2, 4, 4]), mindspore.float32) - [[[[5. 5. 9. 9.] - [8. 4. 3. 0.] - [2. 7. 1. 2.] - [1. 8. 3. 3.]] - [[6. 8. 2. 4.] - [3. 0. 2. 1.] - [0. 8. 9. 7.] - [2. 1. 4. 9.]]]] >>> output = pool(x) >>> output.shape() - (1, 2, 2, 2) >>> output - [[[[4.888889 4.4444447] - [4.111111 3.4444444]] - [[4.2222223 4.5555553] - [3.2222223 4.5555553]]]] """ def __init__(self, From 6c87c6c03d5cce192ca177fd52ef47c7c7022c97 Mon Sep 17 00:00:00 2001 From: cjh9368 Date: Tue, 21 Apr 2020 16:33:58 +0800 Subject: [PATCH 043/142] predict use cmake -s flags rather than strip --- predict/CMakeLists.txt | 1 + predict/src/CMakeLists.txt | 14 -------------- 2 files changed, 1 insertion(+), 14 deletions(-) diff --git a/predict/CMakeLists.txt b/predict/CMakeLists.txt index 2641932769..39ca6b27e8 100755 --- a/predict/CMakeLists.txt +++ b/predict/CMakeLists.txt @@ -6,6 +6,7 @@ set(CMAKE_BUILD_TYPE "Release") set(CMAKE_CXX_STANDARD 11) set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fvisibility=hidden") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=hidden") +set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -s") option(ENABLE_ASAN "Enable Google Sanitizer to find memory bugs" OFF) option(ENABLE_PREDICT_ARM64 "predict arm64" OFF) diff --git a/predict/src/CMakeLists.txt b/predict/src/CMakeLists.txt index c32c047c82..92c45473d7 100644 --- a/predict/src/CMakeLists.txt +++ b/predict/src/CMakeLists.txt @@ -52,20 +52,6 @@ else() target_link_libraries(mspredict pthread tvm_kernel libsecurec.a) endif() -if("${CMAKE_BUILD_TYPE}" STREQUAL "Release") - if(ENABLE_PREDICT_ARM64) - add_custom_command(TARGET mspredict POST_BUILD - COMMAND ${ANDROID_NDK}/toolchains/aarch64-linux-android-4.9/prebuilt/linux-x86_64/aarch64-linux-android/bin/strip "${PREDICT_BUILD_DIR}/src/libmspredict.so" - COMMAND ${ANDROID_NDK}/toolchains/aarch64-linux-android-4.9/prebuilt/linux-x86_64/aarch64-linux-android/bin/strip "${PREDICT_BUILD_DIR}/module/tvm_kernel/lite/libtvm_kernel.so" - ) - else() - add_custom_command(TARGET mspredict POST_BUILD - COMMAND strip "${PREDICT_BUILD_DIR}/src/libmspredict.so" - COMMAND strip "${PREDICT_BUILD_DIR}/module/tvm_kernel/lite/libtvm_kernel.so" - ) - endif() -endif() - add_dependencies(mspredict tvm_kernel) add_dependencies(mspredict securec) add_dependencies(mspredict gtest) From 92695a0da8a908a9ee86fb44827ef6b359484783 Mon Sep 17 00:00:00 2001 From: zhaojichen Date: Tue, 21 Apr 2020 04:44:33 -0400 Subject: [PATCH 044/142] add AvgPooling layer --- mindspore/nn/layer/pooling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mindspore/nn/layer/pooling.py b/mindspore/nn/layer/pooling.py index 299891232e..fef9494ea4 100644 --- a/mindspore/nn/layer/pooling.py +++ b/mindspore/nn/layer/pooling.py @@ -218,7 +218,7 @@ class AvgPool1d(_PoolNd): Typically the input is of shape :math:`(N_{in}, C_{in}, H_{in}, W_{in})`, AvgPool1d outputs regional average in the :math:`(W_{in})`-dimension. Given kernel size - :math:`ks = (w_{ker})` and stride :math:`s = (s_0)`, the operation is as follows. + :math:`ks = w_{ker}` and stride :math:`s = s_0`, the operation is as follows. .. math:: \text{output}(N_i, C_j, h_k, w) = \frac{1}{w_{ker}} \sum_{n=0}^{w_{ker}-1} From bdbb3599b2168fa3612285948647deacdcf7ade1 Mon Sep 17 00:00:00 2001 From: leilei_snow Date: Tue, 21 Apr 2020 07:18:03 +0000 Subject: [PATCH 045/142] Check value between min_lr and max_lr --- mindspore/nn/dynamic_lr.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mindspore/nn/dynamic_lr.py b/mindspore/nn/dynamic_lr.py index 0c5a160380..bb23d6275d 100644 --- a/mindspore/nn/dynamic_lr.py +++ b/mindspore/nn/dynamic_lr.py @@ -233,6 +233,8 @@ def cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch): validator.check_integer('total_step', total_step, 0, Rel.GT, None) validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT, None) validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT, None) + if min_lr >= max_lr: + raise ValueError('`max_lr` should be greater than `min_lr`.') delta = 0.5 * (max_lr - min_lr) lr = [] From 834a407103ddeb49df8c77b56787ba1fc43db110 Mon Sep 17 00:00:00 2001 From: leilei_snow Date: Tue, 21 Apr 2020 08:48:17 +0000 Subject: [PATCH 046/142] Add the function of checking nan or inf --- mindspore/_checkparam.py | 11 +++++++++++ mindspore/nn/dynamic_lr.py | 23 ++++++++++++++++++++--- 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/mindspore/_checkparam.py b/mindspore/_checkparam.py index 7b8c89351c..ae42741371 100644 --- a/mindspore/_checkparam.py +++ b/mindspore/_checkparam.py @@ -15,6 +15,7 @@ """Check parameters.""" import re import inspect +import math from enum import Enum from functools import reduce, wraps from itertools import repeat @@ -318,6 +319,16 @@ class Validator: raise ValueError(f'{msg_prefix} type of `{arg_name}` should be one of {type_names},' f' but got {get_typename(arg_type)}.') + @staticmethod + def check_float_legal_value(arg_name, arg_value, prim_name): + """Checks whether a legal value of float type""" + msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The" + if isinstance(arg_value, float): + if math.isinf(arg_value) or math.isnan(arg_value): + raise ValueError(f"{msg_prefix} `{arg_name}` must be legal value, but got {arg_value}.") + return arg_value + raise TypeError(f"{msg_prefix} `{arg_name}` must be float.") + class ParamValidator: """Parameter validator. NOTICE: this class will be replaced by `class Validator`""" diff --git a/mindspore/nn/dynamic_lr.py b/mindspore/nn/dynamic_lr.py index 0c5a160380..dbc23ecfdc 100644 --- a/mindspore/nn/dynamic_lr.py +++ b/mindspore/nn/dynamic_lr.py @@ -28,7 +28,7 @@ def piecewise_constant_lr(milestone, learning_rates): `milestone`. Let the output learning rate be `y`. .. math:: - y[i] = x_t for i \in [M_{t-1}, M_t) + y[i] = x_t,\ for\ i \in [M_{t-1}, M_t) Args: milestone (list[int]): A list of milestone. This list is a monotone increasing list. @@ -52,7 +52,7 @@ def piecewise_constant_lr(milestone, learning_rates): last_item = 0 for i, item in enumerate(milestone): validator.check_integer(f'milestone[{i}]', item, 0, Rel.GT, None) - validator.check_value_type(f'learning_rates[{i}]', learning_rates[i], [float], None) + validator.check_float_legal_value(f'learning_rates[{i}]', learning_rates[i], None) if item < last_item: raise ValueError(f'The value of milestone[{i}] must be greater than milestone[{i - 1}]') lr += [learning_rates[i]] * (item - last_item) @@ -66,7 +66,9 @@ def _check_inputs(learning_rate, decay_rate, total_step, step_per_epoch, decay_e validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT, None) validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT, None) validator.check_float_positive('learning_rate', learning_rate, None) + validator.check_float_legal_value('learning_rate', learning_rate, None) validator.check_float_positive('decay_rate', decay_rate, None) + validator.check_float_legal_value('decay_rate', decay_rate, None) validator.check_value_type('is_stair', is_stair, [bool], None) @@ -229,7 +231,9 @@ def cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch): [0.1, 0.1, 0.05500000000000001, 0.05500000000000001, 0.01, 0.01] """ validator.check_float_positive('min_lr', min_lr, None) + validator.check_float_legal_value('min_lr', min_lr, None) validator.check_float_positive('max_lr', max_lr, None) + validator.check_float_legal_value('max_lr', max_lr, None) validator.check_integer('total_step', total_step, 0, Rel.GT, None) validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT, None) validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT, None) @@ -280,11 +284,14 @@ def polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_e [0.1, 0.1, 0.07363961030678928, 0.07363961030678928, 0.01, 0.01] """ validator.check_float_positive('learning_rate', learning_rate, None) + validator.check_float_legal_value('learning_rate', learning_rate, None) validator.check_float_positive('end_learning_rate', end_learning_rate, None) + validator.check_float_legal_value('end_learning_rate', end_learning_rate, None) + validator.check_float_positive('power', power, None) + validator.check_float_legal_value('power', power, None) validator.check_integer('total_step', total_step, 0, Rel.GT, None) validator.check_integer('step_per_epoch', step_per_epoch, 0, Rel.GT, None) validator.check_integer('decay_epoch', decay_epoch, 0, Rel.GT, None) - validator.check_value_type('power', power, [float], None) validator.check_value_type('update_decay_epoch', update_decay_epoch, [bool], None) function = lambda x, y: (x, min(x, y)) @@ -298,3 +305,13 @@ def polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_e decay_epoch, tmp_epoch = function(decay_epoch, current_epoch) lr.append(delta * (1 - tmp_epoch / decay_epoch) ** power + end_learning_rate) return lr + + +__all__ = [ + 'piecewise_constant_lr', + 'exponential_decay_lr', + 'natural_exp_decay_lr', + 'inverse_decay_lr', + 'cosine_decay_lr', + 'polynomial_decay_lr' +] From db80f4ff928213b08ec2e49b21a90c2e707a7467 Mon Sep 17 00:00:00 2001 From: qianlong Date: Mon, 20 Apr 2020 21:27:11 +0800 Subject: [PATCH 047/142] The num_samples and numRows in schema for TFRecordDataset are conflict --- .../datasetops/source/storage_client.cc | 6 ++- .../engine/datasetops/source/tf_reader_op.cc | 3 ++ mindspore/dataset/engine/datasets.py | 12 +++-- .../datasetSchemaNoRow.json | 45 +++++++++++++++++++ .../datasetNoRowsSchema.json | 15 +++++++ tests/ut/python/dataset/test_storage.py | 12 +++++ tests/ut/python/dataset/test_tfreader_op.py | 30 +++++++++++++ 7 files changed, 119 insertions(+), 4 deletions(-) create mode 100644 tests/ut/data/dataset/testTFTestAllTypes/datasetSchemaNoRow.json create mode 100644 tests/ut/data/dataset/test_tf_file_3_images/datasetNoRowsSchema.json diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/storage_client.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/storage_client.cc index 862edcf63a..7f081af2b7 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/storage_client.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/storage_client.cc @@ -162,7 +162,11 @@ Status StorageClient::numRowsFromFile(uint32_t &num_rows) const { std::ifstream in(schemaFile); nlohmann::json js; in >> js; - num_rows = js.value("numRows", 0); + if (js.find("numRows") == js.end()) { + num_rows = MAX_INTEGER_INT32; + } else { + num_rows = js.value("numRows", 0); + } if (num_rows == 0) { std::string err_msg = "Storage client has not properly done dataset " diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc index a72be1f703..6132f628d7 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/tf_reader_op.cc @@ -163,6 +163,9 @@ Status TFReaderOp::Init() { if (total_rows_ == 0) { total_rows_ = data_schema_->num_rows(); } + if (total_rows_ < 0) { + RETURN_STATUS_UNEXPECTED("The num_sample or numRows for TFRecordDataset should be greater than 0"); + } // Build the index with our files such that each file corresponds to a key id. RETURN_IF_NOT_OK(filename_index_->insert(dataset_files_list_)); diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 28697a6c43..855e4609bb 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -1455,7 +1455,7 @@ class StorageDataset(SourceDataset): Args: dataset_files (list[str]): List of files to be read. - schema (str): Path to the json schema file. + schema (str): Path to the json schema file. If numRows(parsed from schema) is not exist, read the full dataset. distribution (str, optional): Path of distribution config file (default=""). columns_list (list[str], optional): List of columns to be read (default=None, read all columns). num_parallel_workers (int, optional): Number of parallel working threads (default=None). @@ -2193,7 +2193,10 @@ class TFRecordDataset(SourceDataset): schema (str or Schema, optional): Path to the json schema file or schema object (default=None). If the schema is not provided, the meta data from the TFData file is considered the schema. columns_list (list[str], optional): List of columns to be read (default=None, read all columns) - num_samples (int, optional): number of samples(rows) to read (default=None, reads the full dataset). + num_samples (int, optional): number of samples(rows) to read (default=None). + If num_samples is None and numRows(parsed from schema) is not exist, read the full dataset; + If num_samples is None and numRows(parsed from schema) is greater than 0, read numRows rows; + If both num_samples and numRows(parsed from schema) are greater than 0, read num_samples rows. num_parallel_workers (int, optional): number of workers to read the data (default=None, number set in the config). shuffle (bool, Shuffle level, optional): perform reshuffling of the data every epoch (default=Shuffle.GLOBAL). @@ -2711,10 +2714,10 @@ class Schema: """ def __init__(self, schema_file=None): + self.num_rows = None if schema_file is None: self.columns = [] self.dataset_type = '' - self.num_rows = 0 else: if not os.path.isfile(schema_file) or not os.access(schema_file, os.R_OK): raise ValueError("The file %s does not exist or permission denied!" % schema_file) @@ -2859,6 +2862,9 @@ class Schema: raise RuntimeError("DatasetType field is missing.") if self.columns is None: raise RuntimeError("Columns are missing.") + if self.num_rows is not None: + if not isinstance(self.num_rows, int) or self.num_rows <= 0: + raise ValueError("numRows must be greater than 0") def __str__(self): return self.to_json() diff --git a/tests/ut/data/dataset/testTFTestAllTypes/datasetSchemaNoRow.json b/tests/ut/data/dataset/testTFTestAllTypes/datasetSchemaNoRow.json new file mode 100644 index 0000000000..92abf66ef8 --- /dev/null +++ b/tests/ut/data/dataset/testTFTestAllTypes/datasetSchemaNoRow.json @@ -0,0 +1,45 @@ +{ + "datasetType": "TF", + "columns": { + "col_sint16": { + "type": "int16", + "rank": 1, + "shape": [1] + }, + "col_sint32": { + "type": "int32", + "rank": 1, + "shape": [1] + }, + "col_sint64": { + "type": "int64", + "rank": 1, + "shape": [1] + }, + "col_float": { + "type": "float32", + "rank": 1, + "shape": [1] + }, + "col_1d": { + "type": "int64", + "rank": 1, + "shape": [2] + }, + "col_2d": { + "type": "int64", + "rank": 2, + "shape": [2, 2] + }, + "col_3d": { + "type": "int64", + "rank": 3, + "shape": [2, 2, 2] + }, + "col_binary": { + "type": "uint8", + "rank": 1, + "shape": [1] + } + } +} diff --git a/tests/ut/data/dataset/test_tf_file_3_images/datasetNoRowsSchema.json b/tests/ut/data/dataset/test_tf_file_3_images/datasetNoRowsSchema.json new file mode 100644 index 0000000000..e00fd39c10 --- /dev/null +++ b/tests/ut/data/dataset/test_tf_file_3_images/datasetNoRowsSchema.json @@ -0,0 +1,15 @@ +{ + "datasetType": "TF", + "columns": { + "image": { + "type": "uint8", + "rank": 1, + "t_impl": "cvmat" + }, + "label" : { + "type": "uint64", + "rank": 1, + "t_impl": "flex" + } + } +} diff --git a/tests/ut/python/dataset/test_storage.py b/tests/ut/python/dataset/test_storage.py index b37a52f37d..92a689a689 100644 --- a/tests/ut/python/dataset/test_storage.py +++ b/tests/ut/python/dataset/test_storage.py @@ -37,3 +37,15 @@ def test_case_storage(): filename = "storage_result.npz" save_and_check(data1, parameters, filename, generate_golden=GENERATE_GOLDEN) + + +def test_case_no_rows(): + DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] + SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetNoRowsSchema.json" + + dataset = ds.StorageDataset(DATA_DIR, SCHEMA_DIR, columns_list=["image"]) + assert dataset.get_dataset_size() == 3 + count = 0 + for data in dataset.create_tuple_iterator(): + count += 1 + assert count == 3 diff --git a/tests/ut/python/dataset/test_tfreader_op.py b/tests/ut/python/dataset/test_tfreader_op.py index 3add50e1cb..c5d9471f8b 100644 --- a/tests/ut/python/dataset/test_tfreader_op.py +++ b/tests/ut/python/dataset/test_tfreader_op.py @@ -37,6 +37,36 @@ def test_case_tf_shape(): assert (len(output_shape[-1]) == 1) +def test_case_tf_read_all_dataset(): + schema_file = "../data/dataset/testTFTestAllTypes/datasetSchemaNoRow.json" + ds1 = ds.TFRecordDataset(FILES, schema_file) + assert ds1.get_dataset_size() == 12 + count = 0 + for data in ds1.create_tuple_iterator(): + count += 1 + assert count == 12 + + +def test_case_num_samples(): + schema_file = "../data/dataset/testTFTestAllTypes/datasetSchema7Rows.json" + ds1 = ds.TFRecordDataset(FILES, schema_file, num_samples=8) + assert ds1.get_dataset_size() == 8 + count = 0 + for data in ds1.create_dict_iterator(): + count += 1 + assert count == 8 + + +def test_case_num_samples2(): + schema_file = "../data/dataset/testTFTestAllTypes/datasetSchema7Rows.json" + ds1 = ds.TFRecordDataset(FILES, schema_file) + assert ds1.get_dataset_size() == 7 + count = 0 + for data in ds1.create_dict_iterator(): + count += 1 + assert count == 7 + + def test_case_tf_shape_2(): ds1 = ds.TFRecordDataset(FILES, SCHEMA_FILE) ds1 = ds1.batch(2) From 7947fc119b6d5883f4bb283f7d595d893ed812fe Mon Sep 17 00:00:00 2001 From: VectorSL Date: Tue, 21 Apr 2020 16:32:58 +0800 Subject: [PATCH 048/142] gpu add lessequal --- mindspore/_akg/gpu/less_equal.py | 40 +++++++++++++++ mindspore/_akg/ops/math/less_equal.py | 54 +++++++++++++++++++++ mindspore/ops/_op_impl/akg/gpu/lessequal.py | 32 ++++++++++++ 3 files changed, 126 insertions(+) create mode 100644 mindspore/_akg/gpu/less_equal.py create mode 100644 mindspore/_akg/ops/math/less_equal.py create mode 100644 mindspore/ops/_op_impl/akg/gpu/lessequal.py diff --git a/mindspore/_akg/gpu/less_equal.py b/mindspore/_akg/gpu/less_equal.py new file mode 100644 index 0000000000..c58346e929 --- /dev/null +++ b/mindspore/_akg/gpu/less_equal.py @@ -0,0 +1,40 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""less_equal""" +import _akg.tvm +from _akg.ops.math import less_equal +from _akg.topi.generic import schedule_elemwise + +def LessEqual(x, y): + """LessEqual.""" + return less_equal.less_equal(x, y) + + +def gpu_schedule_LessEqual(outs): + """ + GPU schedule for LessEqual. + + Args: + outs (tvm.tensor.Tensor): Outputs of compute. + + Returns: + sch (schedule.Schedule): The created schedule. + """ + device = 'cuda' + ctx = _akg.tvm.context(device, 0) + if not ctx.exist: + raise SystemError("Skip because %s is not enabled" % device) + with _akg.tvm.target.create(device): + sch = schedule_elemwise(outs) + return sch diff --git a/mindspore/_akg/ops/math/less_equal.py b/mindspore/_akg/ops/math/less_equal.py new file mode 100644 index 0000000000..5a566fbbca --- /dev/null +++ b/mindspore/_akg/ops/math/less_equal.py @@ -0,0 +1,54 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""operator dsl function: lessequal""" +import _akg.tvm +import _akg.topi +from _akg.utils.dsl_create import produce_shapes +from _akg.utils import validation_check as vc_util + + +@vc_util.check_input_type(_akg.tvm.tensor.Tensor, _akg.tvm.tensor.Tensor) +def less_equal(input1, input2): + """ + Check whether input1 lessequals to input2. + + Args: + input1 (tvm.tensor.Tensor): Tensor. + input2 (tvm.tensor.Tensor): Tensor. + + Returns: + tvm.tensor.Tensor. If input1 lessequal to input2 return True, else return False. + """ + shape1 = [x.value for x in input1.shape] + shape2 = [x.value for x in input2.shape] + vc_util.check_shape(shape1) + vc_util.check_shape(shape2) + + shape1, shape2, shape = produce_shapes(shape1, shape2) + + vc_util.elemwise_dtype_check(input1.dtype, input2.dtype) + dtype = input1.dtype + + # get lessequal compute + t_value = _akg.tvm.compute(shape, lambda *indice: _akg.tvm.const(1, dtype), "T") + f_value = _akg.tvm.compute(shape, lambda *indice: _akg.tvm.const(0, dtype), "F") + + input1_bro = _akg.topi.broadcast_to(input1, shape) + input2_bro = _akg.topi.broadcast_to(input2, shape) + c_out = _akg.tvm.compute(shape, lambda *indice: _akg.tvm.expr.Select(input1_bro[indice] <= input2_bro[indice], + t_value[indice], f_value[indice]), name="C") + res = _akg.tvm.compute(shape, lambda *indice: c_out(*indice).astype("bool"), name="res") + + return res diff --git a/mindspore/ops/_op_impl/akg/gpu/lessequal.py b/mindspore/ops/_op_impl/akg/gpu/lessequal.py new file mode 100644 index 0000000000..a3e4d4dc35 --- /dev/null +++ b/mindspore/ops/_op_impl/akg/gpu/lessequal.py @@ -0,0 +1,32 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""LessEqual op""" +from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType + +equal_op_info = AkgRegOp("LessEqual") \ + .fusion_type("OPAQUE") \ + .input(0, "x") \ + .input(1, "y") \ + .output(0, "output") \ + .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.BOOL_Default) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.BOOL_Default) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.BOOL_Default) \ + .get_op_info() + + +@op_info_register(equal_op_info) +def _lessequal_akg(): + """LessEqual register""" + return From 5b5a56587ecc641f4a567b8b70779b817dc9d055 Mon Sep 17 00:00:00 2001 From: VectorSL Date: Tue, 21 Apr 2020 16:00:44 +0800 Subject: [PATCH 049/142] gpu add akg logical_and and logical_or --- mindspore/_akg/gpu/logical_and.py | 40 ++++++++++++++++++ mindspore/_akg/gpu/logical_or.py | 40 ++++++++++++++++++ mindspore/_akg/ops/math/logical_and.py | 41 +++++++++++++++++++ mindspore/_akg/ops/math/logical_or.py | 41 +++++++++++++++++++ mindspore/ops/_op_impl/akg/gpu/logical_and.py | 29 +++++++++++++ mindspore/ops/_op_impl/akg/gpu/logical_or.py | 29 +++++++++++++ 6 files changed, 220 insertions(+) create mode 100644 mindspore/_akg/gpu/logical_and.py create mode 100644 mindspore/_akg/gpu/logical_or.py create mode 100644 mindspore/_akg/ops/math/logical_and.py create mode 100644 mindspore/_akg/ops/math/logical_or.py create mode 100644 mindspore/ops/_op_impl/akg/gpu/logical_and.py create mode 100644 mindspore/ops/_op_impl/akg/gpu/logical_or.py diff --git a/mindspore/_akg/gpu/logical_and.py b/mindspore/_akg/gpu/logical_and.py new file mode 100644 index 0000000000..6453901458 --- /dev/null +++ b/mindspore/_akg/gpu/logical_and.py @@ -0,0 +1,40 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""logical_and""" +import _akg.tvm +from _akg.ops.math import logical_and +from _akg.topi.generic import schedule_elemwise + +def LogicalAnd(x, y): + """LogicalAnd.""" + return logical_and.logical_and(x, y) + + +def gpu_schedule_LogicalAnd(outs): + """ + GPU schedule for LogicalAnd. + + Args: + outs (tvm.tensor.Tensor): outputs of compute. + + Returns: + sch (schedule.Schedule): The created schedule. + """ + device = 'cuda' + ctx = _akg.tvm.context(device, 0) + if not ctx.exist: + raise SystemError("Skip because %s is not enabled" % device) + with _akg.tvm.target.create(device): + sch = schedule_elemwise(outs) + return sch diff --git a/mindspore/_akg/gpu/logical_or.py b/mindspore/_akg/gpu/logical_or.py new file mode 100644 index 0000000000..1bd49bedbc --- /dev/null +++ b/mindspore/_akg/gpu/logical_or.py @@ -0,0 +1,40 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""logical_or""" +import _akg.tvm +from _akg.ops.math import logical_or +from _akg.topi.generic import schedule_elemwise + +def LogicalOr(x, y): + """LogicalOr.""" + return logical_or.logical_or(x, y) + + +def gpu_schedule_LogicalOr(outs): + """ + GPU schedule for LogicalOr. + + Args: + outs (tvm.tensor.Tensor): outputs of compute. + + Returns: + sch (schedule.Schedule): The created schedule. + """ + device = 'cuda' + ctx = _akg.tvm.context(device, 0) + if not ctx.exist: + raise SystemError("Skip because %s is not enabled" % device) + with _akg.tvm.target.create(device): + sch = schedule_elemwise(outs) + return sch diff --git a/mindspore/_akg/ops/math/logical_and.py b/mindspore/_akg/ops/math/logical_and.py new file mode 100644 index 0000000000..480d4e1741 --- /dev/null +++ b/mindspore/_akg/ops/math/logical_and.py @@ -0,0 +1,41 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""operator dsl function: logical_and""" +import _akg.tvm +import _akg.topi +from _akg.utils import validation_check as vc_util + +@vc_util.check_input_type(_akg.tvm.tensor.Tensor, _akg.tvm.tensor.Tensor) +def logical_and(input1, input2): + """ + Compute logical_and of input1 and input2. + + Args: + input1 (tvm.tensor.Tensor): Tensor. + input2 (tvm.tensor.Tensor): Tensor. + + Returns: + tvm.tensor.Tensor. LogicalAnd of input1 and input2. + """ + + vc_util.elemwise_dtype_check(input1.dtype, input2.dtype) + + shape1 = [x.value for x in input1.shape] + shape2 = [x.value for x in input2.shape] + vc_util.check_shape(shape1) + vc_util.check_shape(shape2) + + res = _akg.topi.logical_and(input1, input2) + return res diff --git a/mindspore/_akg/ops/math/logical_or.py b/mindspore/_akg/ops/math/logical_or.py new file mode 100644 index 0000000000..8fb0b80567 --- /dev/null +++ b/mindspore/_akg/ops/math/logical_or.py @@ -0,0 +1,41 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""operator dsl function: logical_or""" +import _akg.tvm +import _akg.topi +from _akg.utils import validation_check as vc_util + +@vc_util.check_input_type(_akg.tvm.tensor.Tensor, _akg.tvm.tensor.Tensor) +def logical_or(input1, input2): + """ + Compute logical_or of input1 and input2. + + Args: + input1 (tvm.tensor.Tensor): Tensor. + input2 (tvm.tensor.Tensor): Tensor. + + Returns: + tvm.tensor.Tensor. LogicalOr of input1 and input2. + """ + + vc_util.elemwise_dtype_check(input1.dtype, input2.dtype) + + shape1 = [x.value for x in input1.shape] + shape2 = [x.value for x in input2.shape] + vc_util.check_shape(shape1) + vc_util.check_shape(shape2) + + res = _akg.topi.logical_or(input1, input2) + return res diff --git a/mindspore/ops/_op_impl/akg/gpu/logical_and.py b/mindspore/ops/_op_impl/akg/gpu/logical_and.py new file mode 100644 index 0000000000..da5b696512 --- /dev/null +++ b/mindspore/ops/_op_impl/akg/gpu/logical_and.py @@ -0,0 +1,29 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""LogicalAnd op""" +from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType + +logicaland_op_info = AkgRegOp("LogicalAnd") \ + .fusion_type("OPAQUE") \ + .input(0, "x") \ + .input(1, "y") \ + .output(0, "output") \ + .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default, DataType.BOOL_Default) \ + .get_op_info() + +@op_info_register(logicaland_op_info) +def _logical_and_akg(): + """LogicalAnd register""" + return diff --git a/mindspore/ops/_op_impl/akg/gpu/logical_or.py b/mindspore/ops/_op_impl/akg/gpu/logical_or.py new file mode 100644 index 0000000000..3a642511c6 --- /dev/null +++ b/mindspore/ops/_op_impl/akg/gpu/logical_or.py @@ -0,0 +1,29 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""LogicalOr op""" +from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType + +logicalor_op_info = AkgRegOp("LogicalOr") \ + .fusion_type("OPAQUE") \ + .input(0, "x") \ + .input(1, "y") \ + .output(0, "output") \ + .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default, DataType.BOOL_Default) \ + .get_op_info() + +@op_info_register(logicalor_op_info) +def _logical_or_akg(): + """LogicalOr register""" + return From 1c6a690a2d34d8b6d67c465bcd653eec1a710d8b Mon Sep 17 00:00:00 2001 From: VectorSL Date: Tue, 21 Apr 2020 16:26:28 +0800 Subject: [PATCH 050/142] gpu add akg logialnot sub --- mindspore/_akg/gpu/logical_not.py | 40 +++++++++++++++++++ mindspore/_akg/gpu/sub.py | 40 +++++++++++++++++++ mindspore/_akg/ops/math/logical_not.py | 32 +++++++++++++++ mindspore/ops/_op_impl/akg/gpu/logical_not.py | 28 +++++++++++++ mindspore/ops/_op_impl/akg/gpu/sub.py | 31 ++++++++++++++ 5 files changed, 171 insertions(+) create mode 100644 mindspore/_akg/gpu/logical_not.py create mode 100644 mindspore/_akg/gpu/sub.py create mode 100644 mindspore/_akg/ops/math/logical_not.py create mode 100644 mindspore/ops/_op_impl/akg/gpu/logical_not.py create mode 100644 mindspore/ops/_op_impl/akg/gpu/sub.py diff --git a/mindspore/_akg/gpu/logical_not.py b/mindspore/_akg/gpu/logical_not.py new file mode 100644 index 0000000000..0a38107187 --- /dev/null +++ b/mindspore/_akg/gpu/logical_not.py @@ -0,0 +1,40 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""logical_not""" +import _akg.tvm +from _akg.ops.math import logical_not +from _akg.topi.generic import schedule_elemwise + +def LogicalNot(x): + """LogicalNot.""" + return logical_not.logical_not(x) + + +def gpu_schedule_LogicalNot(outs): + """ + GPU schedule for LogicalNot. + + Args: + outs (tvm.tensor.Tensor): outputs of compute. + + Returns: + sch (schedule.Schedule): The created schedule. + """ + device = 'cuda' + ctx = _akg.tvm.context(device, 0) + if not ctx.exist: + raise SystemError("Skip because %s is not enabled" % device) + with _akg.tvm.target.create(device): + sch = schedule_elemwise(outs) + return sch diff --git a/mindspore/_akg/gpu/sub.py b/mindspore/_akg/gpu/sub.py new file mode 100644 index 0000000000..611e4228fd --- /dev/null +++ b/mindspore/_akg/gpu/sub.py @@ -0,0 +1,40 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""sub""" +import _akg.tvm +from _akg.ops.math import sub +from _akg.topi.generic import schedule_elemwise + +def Sub(x, y): + """Sub.""" + return sub.sub(x, y) + + +def gpu_schedule_Sub(outs): + """ + GPU schedule for Sub. + + Args: + outs (tvm.tensor.Tensor): outputs of compute. + + Returns: + sch (schedule.Schedule): The created schedule. + """ + device = 'cuda' + ctx = _akg.tvm.context(device, 0) + if not ctx.exist: + raise SystemError("Skip because %s is not enabled" % device) + with _akg.tvm.target.create(device): + sch = schedule_elemwise(outs) + return sch diff --git a/mindspore/_akg/ops/math/logical_not.py b/mindspore/_akg/ops/math/logical_not.py new file mode 100644 index 0000000000..9befe7e816 --- /dev/null +++ b/mindspore/_akg/ops/math/logical_not.py @@ -0,0 +1,32 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""operator dsl function: logical_not""" +import _akg.tvm +import _akg.topi +from _akg.utils import validation_check as vc_util + +@vc_util.check_input_type(_akg.tvm.tensor.Tensor) +def logical_not(input1): + """ + Compute logical_not of input1. + + Args: + input1 (tvm.tensor.Tensor): Tensor. + + Returns: + tvm.tensor.Tensor. + """ + res = _akg.topi.logical_not(input1) + return res diff --git a/mindspore/ops/_op_impl/akg/gpu/logical_not.py b/mindspore/ops/_op_impl/akg/gpu/logical_not.py new file mode 100644 index 0000000000..4b3c7bf647 --- /dev/null +++ b/mindspore/ops/_op_impl/akg/gpu/logical_not.py @@ -0,0 +1,28 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""LogicalNot op""" +from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType + +logical_not_op_info = AkgRegOp("LogicalNot") \ + .fusion_type("OPAQUE") \ + .input(0, "x") \ + .output(0, "output") \ + .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default) \ + .get_op_info() + +@op_info_register(logical_not_op_info) +def _logical_not_akg(): + """LogicalNot AutoDiff register""" + return diff --git a/mindspore/ops/_op_impl/akg/gpu/sub.py b/mindspore/ops/_op_impl/akg/gpu/sub.py new file mode 100644 index 0000000000..06b92fb49e --- /dev/null +++ b/mindspore/ops/_op_impl/akg/gpu/sub.py @@ -0,0 +1,31 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Sub op""" +from mindspore.ops.op_info_register import op_info_register, AkgRegOp, DataType + +sub_op_info = AkgRegOp("Sub") \ + .fusion_type("OPAQUE") \ + .input(0, "x") \ + .input(1, "y") \ + .output(0, "output") \ + .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ + .get_op_info() + +@op_info_register(sub_op_info) +def _sub_akg(): + """Sub AutoDiff register""" + return From 742395da12930efc18e3ba4b6922282f2bb9409a Mon Sep 17 00:00:00 2001 From: leilei_snow Date: Tue, 21 Apr 2020 02:50:38 +0000 Subject: [PATCH 051/142] update piecewise_constant_lr support tuple input --- mindspore/nn/dynamic_lr.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mindspore/nn/dynamic_lr.py b/mindspore/nn/dynamic_lr.py index 0c5a160380..266587c5c3 100644 --- a/mindspore/nn/dynamic_lr.py +++ b/mindspore/nn/dynamic_lr.py @@ -31,8 +31,8 @@ def piecewise_constant_lr(milestone, learning_rates): y[i] = x_t for i \in [M_{t-1}, M_t) Args: - milestone (list[int]): A list of milestone. This list is a monotone increasing list. - learning_rates (list[float]): A list of learning rates. + milestone (Union[list[int], tuple[int]]): A list of milestone. This list is a monotone increasing list. + learning_rates (Union[list[float], tuple[int]]): A list of learning rates. Returns: list[float]. The size of list is :math:`M_N`. From 6ffed2262506441c536f261f3fb2596adbf37a8c Mon Sep 17 00:00:00 2001 From: leilei_snow Date: Mon, 20 Apr 2020 14:37:18 +0000 Subject: [PATCH 052/142] fix wrong formula of polynomial_decay_lr --- mindspore/nn/dynamic_lr.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/mindspore/nn/dynamic_lr.py b/mindspore/nn/dynamic_lr.py index 0c5a160380..63fe795a64 100644 --- a/mindspore/nn/dynamic_lr.py +++ b/mindspore/nn/dynamic_lr.py @@ -251,11 +251,11 @@ def polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_e .. math:: decayed\_learning\_rate[i] = (learning\_rate - end\_learning\_rate) * - (1 - tmp\_epoch / decay\_epoch)^{power} + end\_learning\_rate + (1 - tmp\_epoch / tmp\_decay\_epoch)^{power} + end\_learning\_rate - Where :math:`tmp\_epoch=min(current\_epoch, decay\_epoch), current\_epoch=floor(\frac{i}{step\_per\_epoch})`. - If `update_decay_epoch` is true, update the value of `decay_epoch` every epoch. The formula is - :math:`decay\_epoch = decay\_epoch * ceil(current\_epoch / decay\_epoch)` + Where :math:`tmp\_epoch=min(current\_epoch, decay\_epoch),\ current\_epoch=floor(\frac{i}{step\_per\_epoch})`, and + :math:`tmp\_decay\_epoch = decay\_epoch`. If `update_decay_epoch` is true, update the value of `tmp_decay_epoch` + every epoch. The formula is :math:`tmp\_decay\_epoch = decay\_epoch * ceil(current\_epoch / decay\_epoch)` Args: learning_rate (float): The initial value of learning rate. @@ -287,9 +287,10 @@ def polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_e validator.check_value_type('power', power, [float], None) validator.check_value_type('update_decay_epoch', update_decay_epoch, [bool], None) + origin_decay_epoch = decay_epoch function = lambda x, y: (x, min(x, y)) if update_decay_epoch: - function = lambda x, y: (x * max(math.ceil(y / x), 1), y) + function = lambda x, y: (origin_decay_epoch * max(math.ceil(y / origin_decay_epoch), 1), y) lr = [] delta = learning_rate - end_learning_rate From 6dd72f654acb2daaeb04503dbdb6b12ee61ddb91 Mon Sep 17 00:00:00 2001 From: fary86 Date: Tue, 7 Apr 2020 15:50:48 +0800 Subject: [PATCH 053/142] Add prim name to error message for nn_ops.py --- mindspore/_checkparam.py | 46 +- mindspore/context.py | 2 +- mindspore/ops/operations/nn_ops.py | 766 ++++++++++------------- tests/ut/python/nn/test_dynamic_lr.py | 20 +- tests/ut/python/nn/test_ssim.py | 2 +- tests/ut/python/ops/test_nn_ops.py | 20 +- tests/ut/python/ops/test_nn_ops_check.py | 463 ++++++++++++++ 7 files changed, 821 insertions(+), 498 deletions(-) create mode 100755 tests/ut/python/ops/test_nn_ops_check.py diff --git a/mindspore/_checkparam.py b/mindspore/_checkparam.py index 7b8c89351c..f0b7fa0af1 100644 --- a/mindspore/_checkparam.py +++ b/mindspore/_checkparam.py @@ -117,10 +117,12 @@ class Validator: """Integer value judgment.""" rel_fn = Rel.get_fns(rel) type_mismatch = not isinstance(arg_value, int) or isinstance(arg_value, bool) + excp_cls = TypeError if type_mismatch else ValueError if type_mismatch or not rel_fn(arg_value, value): rel_str = Rel.get_strs(rel).format(value) msg_prefix = f'For \'{prim_name}\' the' if prim_name else "The" - raise ValueError(f'{msg_prefix} `{arg_name}` should be an int and must {rel_str}, but got {arg_value}.') + raise excp_cls(f'{msg_prefix} `{arg_name}` should be an int and must {rel_str}, but got `{arg_value}`' + f' with type `{type(arg_value).__name__}`.') return arg_value @staticmethod @@ -137,10 +139,11 @@ class Validator: """Method for checking whether an int value is in some range.""" rel_fn = Rel.get_fns(rel) type_mismatch = not isinstance(arg_value, int) + excp_cls = TypeError if type_mismatch else ValueError if type_mismatch or not rel_fn(arg_value, lower_limit, upper_limit): rel_str = Rel.get_strs(rel).format(lower_limit, upper_limit) - raise ValueError(f'For \'{prim_name}\' the `{arg_name}` should be an int in range {rel_str},' - f' but got {arg_value}.') + raise excp_cls(f'For \'{prim_name}\' the `{arg_name}` should be an int in range {rel_str},' + f' but got `{arg_value}` with type `{type(arg_value).__name__}`.') return arg_value @staticmethod @@ -192,19 +195,23 @@ class Validator: @staticmethod def check_const_input(arg_name, arg_value, prim_name): - """Check valid value.""" + """Checks valid value.""" if arg_value is None: raise ValueError(f'For \'{prim_name}\' the `{arg_name}` must be a const input, but got {arg_value}.') @staticmethod - def check_scalar_type_same(args, valid_values, prim_name): - """check whether the types of inputs are the same.""" + def check_type_same(args, valid_values, prim_name): + """Checks whether the types of inputs are the same.""" def _check_tensor_type(arg): arg_key, arg_val = arg elem_type = arg_val + type_names = [] if not elem_type in valid_values: - raise TypeError(f'For \'{prim_name}\' type of `{arg_key}` should be in {valid_values},' - f' but `{arg_key}` is {elem_type}.') + for t in valid_values: + type_names.append(str(t)) + types_info = '[' + ", ".join(type_names) + ']' + raise TypeError(f'For \'{prim_name}\' type of `{arg_key}` should be in {types_info},' + f' but got {elem_type}.') return (arg_key, elem_type) def _check_types_same(arg1, arg2): @@ -212,7 +219,7 @@ class Validator: arg2_name, arg2_type = arg2 if arg1_type != arg2_type: raise TypeError(f'For \'{prim_name}\' type of `{arg2_name}` should be same as `{arg1_name}`,' - f' but `{arg1_name}` is {arg1_type} and `{arg2_name}` is {arg2_type}.') + f' but `{arg1_name}` with type {arg1_type} and `{arg2_name}` with type {arg2_type}.') return arg1 elem_types = map(_check_tensor_type, args.items()) @@ -221,25 +228,8 @@ class Validator: @staticmethod def check_tensor_type_same(args, valid_values, prim_name): """Checks whether the element types of input tensors are the same.""" - def _check_tensor_type(arg): - arg_key, arg_val = arg - Validator.check_subclass(arg_key, arg_val, mstype.tensor, prim_name) - elem_type = arg_val.element_type() - if not elem_type in valid_values: - raise TypeError(f'For \'{prim_name}\' element type of `{arg_key}` should be in {valid_values},' - f' but element type of `{arg_key}` is {elem_type}.') - return (arg_key, elem_type) - - def _check_types_same(arg1, arg2): - arg1_name, arg1_type = arg1 - arg2_name, arg2_type = arg2 - if arg1_type != arg2_type: - raise TypeError(f'For \'{prim_name}\' element type of `{arg2_name}` should be same as `{arg1_name}`,' - f' but `{arg1_name}` is {arg1_type} and `{arg2_name}` is {arg2_type}.') - return arg1 - - elem_types = map(_check_tensor_type, args.items()) - reduce(_check_types_same, elem_types) + tensor_types = [mstype.tensor_type(t) for t in valid_values] + Validator.check_type_same(args, tensor_types, prim_name) @staticmethod def check_scalar_or_tensor_type_same(args, valid_values, prim_name, allow_mix=False): diff --git a/mindspore/context.py b/mindspore/context.py index f6fe8705fd..159522a87a 100644 --- a/mindspore/context.py +++ b/mindspore/context.py @@ -34,7 +34,7 @@ GRAPH_MODE = 0 PYNATIVE_MODE = 1 -def _make_directory(path: str): +def _make_directory(path): """Make directory.""" real_path = None if path is None or not isinstance(path, str) or path.strip() == "": diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 1fb65e3b76..5fd2c24a6e 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -24,12 +24,39 @@ import numpy as np from ... import context from ..._c_expression import signature_rw as sig_rw from ..._c_expression import signature_kind as sig_kind -from ..._checkparam import ParamValidator as validator -from ..._checkparam import Rel, check_bool, check_int_positive +from ..._checkparam import Validator as validator +from ..._checkparam import Rel from ...common import dtype as mstype from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register +def _check_positive_int_or_tuple(arg_name, arg_value, prim_name, allow_four=False, ret_four=False): + """ + Checks whether an argument is a positive int or tuple with 2 or 4(when allow_four is True) positive int elements. + """ + def _raise_message(): + raise ValueError(f"For '{prim_name}' attr '{arg_name}' should be an positive int number or a tuple of two " + f"{'or four ' if allow_four else ''}positive int numbers, but got {arg_value}") + def _get_return_value(): + if isinstance(arg_value, int): + ret = (1, 1, arg_value, arg_value) if ret_four else (arg_value, arg_value) + elif len(arg_value) == 2: + ret = (1, 1, arg_value[0], arg_value[1]) if ret_four else arg_value + elif len(arg_value) == 4: + if not allow_four: + _raise_message() + ret = arg_value if ret_four else (arg_value[2], arg_value[3]) + else: + _raise_message() + return ret + validator.check_value_type(arg_name, arg_value, (int, tuple), prim_name) + ret_value = _get_return_value() + for item in ret_value: + if isinstance(item, int) and item > 0: + continue + _raise_message() + return ret_value + class Flatten(PrimitiveWithInfer): r""" Flattens a tensor without changing its batch size on the 0-th axis. @@ -53,12 +80,12 @@ class Flatten(PrimitiveWithInfer): pass def infer_shape(self, input_x): - validator.check('input_x rank', len(input_x), '', 1, Rel.GE) + validator.check_integer('input_x rank', len(input_x), 1, Rel.GE, self.name) prod = 1 if len(input_x) == 1 else reduce(operator.mul, input_x[1:]) return input_x[0], prod def infer_dtype(self, input_x): - validator.check_subclass("input_x", input_x, mstype.tensor) + validator.check_subclass("input_x", input_x, mstype.tensor, self.name) return input_x @@ -88,21 +115,21 @@ class Softmax(PrimitiveWithInfer): @prim_attr_register def __init__(self, axis=-1): self.init_prim_io_names(inputs=['x'], outputs=['output']) - validator.check_type("axis", axis, [int, tuple]) + validator.check_value_type("axis", axis, [int, tuple], self.name) if isinstance(axis, int): self.add_prim_attr('axis', (axis,)) for item in self.axis: - validator.check_type("item of axis", item, [int]) + validator.check_value_type("item of axis", item, [int], self.name) def infer_shape(self, logits): - validator.check_shape_length("axis shape", len(self.axis), 1, Rel.GE) + validator.check_integer("length of axis", len(self.axis), 1, Rel.GE, self.name) rank = len(logits) for axis_v in self.axis: - validator.check_int_range("axis", axis_v, -rank, rank, Rel.INC_LEFT) + validator.check_int_range("axis", axis_v, -rank, rank, Rel.INC_LEFT, self.name) return logits def infer_dtype(self, logits): - validator.check_subclass("logits", logits, mstype.tensor) + validator.check_subclass("logits", logits, mstype.tensor, self.name) return logits @@ -131,15 +158,15 @@ class LogSoftmax(PrimitiveWithInfer): @prim_attr_register def __init__(self, axis=-1): - validator.check_type("axis", axis, [int]) + validator.check_value_type("axis", axis, [int], self.name) def infer_shape(self, logits): rank = len(logits) - validator.check_int_range('axis', self.axis, -rank - 1, rank, Rel.INC_BOTH) + validator.check_int_range('axis', self.axis, -rank, rank, Rel.INC_LEFT, self.name) return logits def infer_dtype(self, logits): - validator.check_subclass("logits", logits, mstype.tensor) + validator.check_subclass("logits", logits, mstype.tensor, self.name) return logits @@ -171,8 +198,7 @@ class ReLU(PrimitiveWithInfer): return input_x def infer_dtype(self, input_x): - validator.check_subclass("input_x", input_x, mstype.tensor) - validator.check_typename("input_x", input_x, mstype.number_type) + validator.check_tensor_type_same({'input_x': input_x}, mstype.number_type, self.name) return input_x @@ -203,8 +229,7 @@ class ReLU6(PrimitiveWithInfer): return input_x def infer_dtype(self, input_x): - validator.check_subclass("input_x", input_x, mstype.tensor) - validator.check_typename("input_x", input_x, (mstype.float16, mstype.float32)) + validator.check_tensor_type_same({'input_x': input_x}, (mstype.float16, mstype.float32), self.name) return input_x @@ -233,14 +258,13 @@ class Elu(PrimitiveWithInfer): @prim_attr_register def __init__(self, alpha=1.0): """Init Elu""" - validator.check_type("alpha", alpha, [float]) + validator.check_value_type("alpha", alpha, [float], self.name) def infer_shape(self, input_x): return input_x def infer_dtype(self, input_x): - validator.check_subclass("input_x", input_x, mstype.tensor) - validator.check_typename("input_x_dtype", input_x, mstype.float_type) + validator.check_tensor_type_same({'input_x': input_x}, mstype.float_type, self.name) return input_x @@ -272,8 +296,7 @@ class HSwish(PrimitiveWithInfer): return xshape def infer_dtype(self, x_dtype): - validator.check_subclass("x_dtype", x_dtype, mstype.tensor) - validator.check_typename("x_dtype", x_dtype, (mstype.float16, mstype.float32)) + validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name) return x_dtype @@ -305,8 +328,7 @@ class Sigmoid(PrimitiveWithInfer): return input_x def infer_dtype(self, input_x): - validator.check_subclass("input_x", input_x, mstype.tensor) - validator.check_typename("input_x", input_x, (mstype.float16, mstype.float32)) + validator.check_tensor_type_same({"input_x": input_x}, (mstype.float16, mstype.float32), self.name) return input_x @@ -339,8 +361,7 @@ class HSigmoid(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_dtype): - validator.check_subclass("x_dtype", x_dtype, mstype.tensor) - validator.check_typename("x_dtype", x_dtype, (mstype.float16, mstype.float32)) + validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name) return x_dtype @@ -370,7 +391,7 @@ class Tanh(PrimitiveWithInfer): return input_x def infer_dtype(self, input_x): - validator.check_subclass("input_x", input_x, mstype.tensor) + validator.check_subclass("input_x", input_x, mstype.tensor, self.name) return input_x @@ -418,9 +439,9 @@ class FusedBatchNorm(Primitive): def __init__(self, mode=0, epsilon=1e-5, momentum=0.1): self.init_prim_io_names(inputs=['x', 'scale', 'b', 'mean', 'variance'], outputs=['y', 'running_mean', 'running_variance', 'save_mean', 'save_inv_variance']) - self.mode = validator.check_integer('mode', mode, [0, 1], Rel.IN) - self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT) - self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH) + self.mode = validator.check_integer('mode', mode, [0, 1], Rel.IN, self.name) + self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, self.name) + self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name) class BatchNorm(PrimitiveWithInfer): @@ -464,32 +485,34 @@ class BatchNorm(PrimitiveWithInfer): @prim_attr_register def __init__(self, is_training=False, epsilon=1e-5): - self.is_training = validator.check_type('is_training', is_training, (bool,)) - self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT) + validator.check_value_type('is_training', is_training, (bool,), self.name) + validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, self.name) self.add_prim_attr('data_format', "NCHW") self.init_prim_io_names(inputs=['x', 'scale', 'offset', 'mean', 'variance'], outputs=['y', 'batch_mean', 'batch_variance', 'reserve_space_1', 'reserve_space_2', 'reserve_space_3']) def infer_shape(self, input_x, scale, bias, mean, variance): - validator.check("BatchNorm scale shape length", len(scale), "1", 1, Rel.EQ) - validator.check("BatchNorm scale shape", scale, "BatchNorm bias shape", bias) - validator.check("BatchNorm scale shape", scale[0], "BatchNorm input_x shape[1]", input_x[1]) + validator.check_integer("scale rank", len(scale), 1, Rel.EQ, self.name) + validator.check("scale shape", scale, "bias shape", bias, Rel.EQ, self.name) + validator.check("scale shape[0]", scale[0], "input_x shape[1]", input_x[1], Rel.EQ, self.name) if not self.is_training: - validator.check("BatchNorm mean shape length", len(mean), "1", 1, Rel.EQ) - validator.check("BatchNorm mean shape", mean, "BatchNorm variance shape", variance) - validator.check("BatchNorm mean shape", mean, "BatchNorm scale shape", scale) + validator.check_integer("mean rank", len(mean), 1, Rel.EQ, self.name) + validator.check("mean shape", mean, "variance shape", variance, Rel.EQ, self.name) + validator.check("mean shape", mean, "scale shape", scale, Rel.EQ, self.name) return (input_x, scale, scale, scale, scale, scale) def infer_dtype(self, input_x, scale, bias, mean, variance): - args = {"BatchNorm scale type": scale, "BatchNorm bias type": bias} - args_moving = {"BatchNorm mean type": mean, "BatchNorm variance type": variance} - validator.check_typename("input_x", input_x, [mstype.float32, mstype.float16]) - validator.check_type_same(args, [mstype.float32, mstype.float16]) + validator.check_tensor_type_same({"input_x": input_x}, [mstype.float16, mstype.float32], self.name) + args = {"scale": scale, "bias": bias} + validator.check_tensor_type_same(args, [mstype.float16, mstype.float32], self.name) + args_moving = {"mean": mean, "variance": variance} if self.is_training: - validator.check_type_same(args_moving, [mstype.float32, mstype.float16, None]) + valid_types = [mstype.tensor_type(mstype.float16), mstype.tensor_type(mstype.float32), None] + validator.check_type_same(args_moving, valid_types, self.name) else: - validator.check_type_same(args_moving, [mstype.float32, mstype.float16]) + args_moving = {"mean": mean, "variance": variance} + validator.check_tensor_type_same(args_moving, [mstype.float16, mstype.float32], self.name) return (input_x, scale, bias, input_x, input_x, input_x) @@ -559,53 +582,28 @@ class Conv2D(PrimitiveWithInfer): group=1): """init Conv2D""" self.init_prim_io_names(inputs=['x', 'w'], outputs=['output']) - self.kernel_size = validator.check_type('kernel_size', kernel_size, (int, tuple)) - if isinstance(kernel_size, int): - self.kernel_size = (kernel_size, kernel_size) - if len(self.kernel_size) != 2 or (not isinstance(self.kernel_size[0], int)) or \ - (not isinstance(self.kernel_size[1], int)) or \ - self.kernel_size[0] < 1 or self.kernel_size[1] < 1: - raise ValueError(f"The \'kernel_size\' of \'Conv2D\' should be an positive int number or " - f"a tuple of two positive int numbers, but got {kernel_size}") - self.stride = validator.check_type('stride', stride, (int, tuple)) - if isinstance(stride, int): - self.stride = (stride, stride) - if len(self.stride) != 2 or (not isinstance(self.stride[0], int)) or \ - (not isinstance(self.stride[1], int)) or \ - self.stride[0] < 1 or self.stride[1] < 1: - raise ValueError(f"The \'stride\' of \'Conv2D\' should be an positive int number or " - f"a tuple of two positive int numbers, but got {stride}") + self.kernel_size = _check_positive_int_or_tuple('kernel_size', kernel_size, self.name) + self.stride = _check_positive_int_or_tuple('stride', stride, self.name) self.add_prim_attr('stride', (1, 1, self.stride[0], self.stride[1])) - self.dilation = validator.check_type('dilation', dilation, (tuple, int)) - if isinstance(dilation, int): - self.dilation = (1, 1, dilation, dilation) - elif len(dilation) == 2: - self.dilation = (1, 1, dilation[0], dilation[1]) - if len(self.dilation) != 4 or (not isinstance(self.dilation[0], int) or self.dilation[0] < 1) or \ - (not isinstance(self.dilation[1], int) or self.dilation[1] < 1) or \ - (not isinstance(self.dilation[2], int) or self.dilation[2] < 1) or \ - (not isinstance(self.dilation[3], int) or self.dilation[3] < 1): - raise ValueError(f"The \'dilation\' of \'Conv2D\' should be an positive int number or " - f"a tuple of two or four positive int numbers, but got {dilation}") + self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name, allow_four=True, ret_four=True) self.add_prim_attr('dilation', self.dilation) - validator.equal('type of pad', type(pad), 'not bool', not isinstance(pad, bool)) - validator.equal('type of pad', type(pad), 'int', isinstance(pad, int)) - self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad']) - self.pad = validator.check_pad_value_by_mode(self.__class__.__name__, pad_mode, pad) + validator.check_value_type('pad', pad, (int,), self.name) + self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad'], self.name) + self.pad = validator.check_pad_value_by_mode(pad_mode, pad, self.name) if self.pad_mode == 'pad': - validator.check_integer('pad', self.pad, 0, Rel.GE) + validator.check_integer('pad', self.pad, 0, Rel.GE, self.name) - self.mode = validator.check_integer('mode', mode, 1, Rel.EQ) + self.mode = validator.check_integer('mode', mode, 1, Rel.EQ, self.name) self.add_prim_attr('data_format', "NCHW") - self.out_channel = validator.check_integer('out_channel', out_channel, 0, Rel.GT) - self.group = validator.check_integer('group', group, 0, Rel.GT) + self.out_channel = validator.check_integer('out_channel', out_channel, 0, Rel.GT, self.name) + self.group = validator.check_integer('group', group, 0, Rel.GT, self.name) def infer_shape(self, x_shape, w_shape): - validator.check_integer("weight_shape", len(w_shape), 4, Rel.EQ) - validator.check_integer("x_shape", len(x_shape), 4, Rel.EQ) - validator.check_param_equal("x_shape[1]", x_shape[1] // self.group, "w_shape[1]", w_shape[1]) - validator.check_param_equal('out_channel', self.out_channel, 'w_shape[0]', w_shape[0]) - validator.check_param_equal('kernel_size', self.kernel_size, 'w_shape[2:4]', tuple(w_shape[2:4])) + validator.check_integer("weight rank", len(w_shape), 4, Rel.EQ, self.name) + validator.check_integer("x rank", len(x_shape), 4, Rel.EQ, self.name) + validator.check("x_shape[1] / group", x_shape[1] // self.group, "w_shape[1]", w_shape[1], Rel.EQ, self.name) + validator.check('out_channel', self.out_channel, 'w_shape[0]', w_shape[0], Rel.EQ, self.name) + validator.check('kernel_size', self.kernel_size, 'w_shape[2:4]', tuple(w_shape[2:4]), Rel.EQ, self.name) kernel_size_h = w_shape[2] kernel_size_w = w_shape[3] @@ -647,10 +645,9 @@ class Conv2D(PrimitiveWithInfer): return out_shape def infer_dtype(self, x_dtype, w_dtype): - args = {'x_dtype': x_dtype, 'w_dtype': w_dtype} - validator.check_subclass('input', x_dtype, mstype.tensor) - validator.check_subclass('weight', w_dtype, mstype.tensor) - validator.check_type_same(args, [mstype.int8, mstype.int32, mstype.float16, mstype.float32]) + args = {'x': x_dtype, 'w': w_dtype} + valid_types = [mstype.int8, mstype.int32, mstype.float16, mstype.float32] + validator.check_tensor_type_same(args, valid_types, self.name) return x_dtype @@ -697,49 +694,25 @@ class DepthwiseConv2dNative(PrimitiveWithInfer): group=1): """init DepthwiseConv2dNative""" self.init_prim_io_names(inputs=['x', 'w'], outputs=['output']) - validator.check_pad_value_by_mode(self.__class__.__name__, pad_mode, pad) - self.kernel_size = validator.check_type('kernel_size', kernel_size, (int, tuple)) - if isinstance(kernel_size, int): - self.kernel_size = (kernel_size, kernel_size) - if len(self.kernel_size) != 2 or (not isinstance(self.kernel_size[0], int)) or \ - (not isinstance(self.kernel_size[1], int)) or \ - self.kernel_size[0] < 1 or self.kernel_size[1] < 1: - raise ValueError(f"The \'kernel_size\' of \'DepthwiseConv2dNative\' should be an positive int number or " - f"a tuple of two positive int numbers, but got {kernel_size}") - self.stride = validator.check_type('stride', stride, (int, tuple)) - if isinstance(stride, int): - self.stride = (stride, stride) - if len(self.stride) != 2 or (not isinstance(self.stride[0], int)) or \ - (not isinstance(self.stride[1], int)) or \ - self.stride[0] < 1 or self.stride[1] < 1: - raise ValueError(f"The \'stride\' of \'DepthwiseConv2dNative\' should be an positive int number or " - f"a tuple of two positive int numbers, but got {stride}") + self.kernel_size = _check_positive_int_or_tuple('kernel_size', kernel_size, self.name) + self.stride = _check_positive_int_or_tuple('stride', stride, self.name) self.add_prim_attr('stride', (1, 1, self.stride[0], self.stride[1])) - self.dilation = validator.check_type('dilation', dilation, (tuple, int)) - if isinstance(dilation, int): - self.dilation = (dilation, dilation) - if len(self.dilation) != 2 or (not isinstance(self.dilation[0], int)) or \ - (not isinstance(self.dilation[1], int)) or \ - self.dilation[0] < 1 or self.dilation[1] < 1: - raise ValueError(f"The \'dilation\' of \'DepthwiseConv2dNative\' should be an positive int number or " - f"a tuple of two or four positive int numbers, but got {dilation}") + self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name) self.add_prim_attr('dilation', (1, 1, self.dilation[0], self.dilation[1])) - validator.equal('type of pad', type(pad), 'not bool', not isinstance(pad, bool)) - if pad_mode not in ("same", "valid", "pad"): - raise ValueError(f"Attr pad_mode of DepthwiseConv2dNative Op not passed" - f"{pad_mode} not in valid, same, pad.") - self.pad_mode = pad_mode - self.mode = validator.check_integer("mode", mode, 3, Rel.EQ) + validator.check_value_type('pad', pad, (int,), self.name) + self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad'], self.name) + self.pad = validator.check_pad_value_by_mode(pad_mode, pad, self.name) + self.mode = validator.check_integer("mode", mode, 3, Rel.EQ, self.name) self.add_prim_attr('data_format', "NCHW") - self.channel_multiplier = validator.check_integer("channel_multiplier", channel_multiplier, 0, Rel.GT) - self.group = validator.check_integer("group", group, 0, Rel.GT) - self.pad = pad + self.channel_multiplier = validator.check_integer("channel_multiplier", channel_multiplier, 0, Rel.GT, + self.name) + self.group = validator.check_integer("group", group, 0, Rel.GT, self.name) def infer_shape(self, x_shape, w_shape): - validator.check_integer("weight_shape", len(w_shape), 4, Rel.EQ) - validator.check_integer("x_shape", len(x_shape), 4, Rel.EQ) - validator.check_param_equal("x_shape[1]", x_shape[1], "w_shape[1]", w_shape[1]) - validator.check_param_equal('kernel_size', self.kernel_size, 'w_shape[2:4]', tuple(w_shape[2:4])) + validator.check_integer("weight rank", len(w_shape), 4, Rel.EQ, self.name) + validator.check_integer("x rank", len(x_shape), 4, Rel.EQ, self.name) + validator.check("x_shape[1]", x_shape[1], "w_shape[1]", w_shape[1], Rel.EQ, self.name) + validator.check('kernel_size', self.kernel_size, 'w_shape[2:4]', tuple(w_shape[2:4]), Rel.EQ, self.name) kernel_size_h = w_shape[2] kernel_size_w = w_shape[3] @@ -772,9 +745,6 @@ class DepthwiseConv2dNative(PrimitiveWithInfer): / stride_w h_out = math.floor(h_out) w_out = math.floor(w_out) - else: - raise ValueError(f"Attr pad_mode of DepthwiseConv2dNative Op not passed" - "{pad_mode} not in valid, same, pad.") self.pad_list = (pad_top, pad_bottom, pad_left, pad_right) self.add_prim_attr('pads', self.pad_list) @@ -784,8 +754,8 @@ class DepthwiseConv2dNative(PrimitiveWithInfer): return out_shape def infer_dtype(self, x_dtype, w_dtype): - args = {'x_dtype': x_dtype, 'w_dtype': w_dtype} - validator.check_type_same(args, mstype.number_type) + args = {'x': x_dtype, 'w': w_dtype} + validator.check_tensor_type_same(args, mstype.number_type, self.name) return x_dtype @@ -805,48 +775,26 @@ class _Pool(PrimitiveWithInfer): @prim_attr_register def __init__(self, ksize=1, strides=1, padding="valid"): self.init_prim_io_names(inputs=['x'], outputs=['output']) - validator.check_type('ksize', ksize, [int, tuple]) - validator.check_type('strides', strides, [int, tuple]) - self.padding = validator.check_string('padding', padding.upper(), ['VALID', 'SAME']) + validator.check_value_type('ksize', ksize, [int, tuple], self.name) + validator.check_value_type('strides', strides, [int, tuple], self.name) + self.padding = validator.check_string('padding', padding.upper(), ['VALID', 'SAME'], self.name) self.add_prim_attr("padding", self.padding) self.is_maxpoolwithargmax = (self.name == "MaxPoolWithArgmax") if not self.is_maxpoolwithargmax: self.add_prim_attr('data_format', "NCHW") - if isinstance(ksize, int): - validator.check_integer("ksize", ksize, 1, Rel.GE) - self.ksize = (1, 1, ksize, ksize) - else: - if (len(ksize) != 2 or - (not isinstance(ksize[0], int)) or - (not isinstance(ksize[1], int)) or - ksize[0] <= 0 or - ksize[1] <= 0): - raise ValueError(f"The 'ksize' passed to operator {self.name} should be an positive int number or " - f"a tuple of two positive int numbers, but got {ksize}") - self.ksize = (1, 1, ksize[0], ksize[1]) + self.ksize = _check_positive_int_or_tuple("ksize", ksize, self.name, allow_four=False, ret_four=True) if self.is_maxpoolwithargmax: self.ksize = (1, self.ksize[-2], self.ksize[-1], 1) self.add_prim_attr("ksize", self.ksize) - if isinstance(strides, int): - validator.check_integer("strides", strides, 1, Rel.GE) - self.strides = (1, 1, strides, strides) - else: - if (len(strides) != 2 or - (not isinstance(strides[0], int)) or - (not isinstance(strides[1], int)) or - strides[0] <= 0 or - strides[1] <= 0): - raise ValueError(f"The 'strides' passed to operator {self.name} should be an positive int number or " - f"a tuple of two positive int numbers, but got {strides}") - self.strides = (1, 1, strides[0], strides[1]) + self.strides = _check_positive_int_or_tuple("strides", strides, self.name, allow_four=False, ret_four=True) if self.is_maxpoolwithargmax: self.strides = (1, self.strides[-2], self.strides[-1], 1) self.add_prim_attr("strides", self.strides) def infer_shape(self, x_shape): - validator.check_integer("x_shape", len(x_shape), 4, Rel.EQ) + validator.check_integer("x rank", len(x_shape), 4, Rel.EQ, self.name) batch, channel, input_h, input_w = x_shape if self.is_maxpoolwithargmax: _, kernel_h, kernel_w, _ = self.ksize @@ -861,18 +809,16 @@ class _Pool(PrimitiveWithInfer): elif self.padding == "SAME": out_h = math.ceil(input_h / stride_h) out_w = math.ceil(input_w / stride_w) - else: - raise ValueError(f"The padding of operator {self.name} should be a str and must be 'SAME' or 'VALID', " - f"but got {self.padding}.") out_shape = [batch, channel, out_h, out_w] for shape_value in out_shape: if shape_value <= 0: - raise ValueError("The kernel size is not valid please check it if is larger than data's shape size.") + raise ValueError(f"For '{self.name}' The kernel size is not valid, " + f"please check it if is larger than data's shape size.") return out_shape def infer_dtype(self, x_dtype): - validator.check_subclass("input", x_dtype, mstype.tensor) + validator.check_subclass("input", x_dtype, mstype.tensor, self.name) return x_dtype @@ -987,7 +933,7 @@ class MaxPoolWithArgmax(_Pool): def infer_dtype(self, x_dtype): out_dtype = x_dtype - validator.check_typename("x_type", x_dtype, (mstype.float16, mstype.float32)) + validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name) argmax_dtype = mstype.uint16 return out_dtype, argmax_dtype @@ -1071,56 +1017,33 @@ class Conv2DBackpropInput(PrimitiveWithInfer): group=1): """init Conv2DBackpropInput""" self.init_prim_io_names(inputs=['out_backprop', 'filter', 'input_sizes'], outputs=['output']) - self.out_channel = validator.check_integer('out_channel', out_channel, 0, Rel.GT) - self.kernel_size = validator.check_type('kernel_size', kernel_size, (int, tuple)) - if isinstance(kernel_size, int): - self.kernel_size = (kernel_size, kernel_size) - if len(self.kernel_size) != 2 or (not isinstance(self.kernel_size[0], int)) or \ - (not isinstance(self.kernel_size[1], int)) or \ - self.kernel_size[0] < 1 or self.kernel_size[1] < 1: - raise ValueError(f"The \'kernel_size\' of \'Conv2DBackpropInput\' should be an positive int number or " - f"a tuple of two positive int numbers, but got {kernel_size}") - self.stride = validator.check_type('stride', stride, (int, tuple)) - if isinstance(stride, int): - self.stride = (stride, stride) - elif isinstance(stride, tuple) and len(stride) == 4: - self.stride = (stride[2], stride[3]) - if len(self.stride) != 2 or (not isinstance(self.stride[0], int)) or (not isinstance(self.stride[1], int)) or \ - self.stride[0] < 1 or self.stride[1] < 1: - raise ValueError(f"The \'stride\' of \'Conv2DBackpropInput\' should be an positive int number or " - f"a tuple of two or four positive int numbers, but got {stride}") + self.out_channel = validator.check_integer('out_channel', out_channel, 0, Rel.GT, self.name) + self.kernel_size = _check_positive_int_or_tuple('kernel_size', kernel_size, self.name) + self.stride = _check_positive_int_or_tuple('stride', stride, self.name, allow_four=True, ret_four=False) self.add_prim_attr('stride', self.stride) - self.dilation = validator.check_type('dilation', dilation, (tuple, int)) - if isinstance(dilation, int): - self.dilation = (1, 1, dilation, dilation) - elif len(dilation) == 2: - self.dilation = (1, 1, dilation[0], dilation[1]) - if len(self.dilation) != 4 or (not isinstance(self.dilation[0], int) or self.dilation[0] < 1) or \ - (not isinstance(self.dilation[1], int) or self.dilation[1] < 1) or \ - (not isinstance(self.dilation[2], int) or self.dilation[2] < 1) or \ - (not isinstance(self.dilation[3], int) or self.dilation[3] < 1): - raise ValueError(f"The \'dilation\' of \'Conv2DBackpropInput\' should be an positive int number or " - f"a tuple of two or four positive int numbers, but got {dilation}") + self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name, allow_four=True, ret_four=True) self.add_prim_attr('dilation', self.dilation) - validator.equal('type of pad', type(pad), 'not bool', not isinstance(pad, bool)) - validator.equal('type of pad', type(pad), 'int', isinstance(pad, int)) - self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad']) - self.pad = validator.check_pad_value_by_mode(self.__class__.__name__, pad_mode, pad) - self.mode = validator.check_integer('mode', mode, 1, Rel.EQ) - self.group = validator.check_integer('group', group, 0, Rel.GT) + validator.check_value_type('pad', pad, (int,), self.name) + self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad'], self.name) + self.pad = validator.check_pad_value_by_mode(pad_mode, pad, self.name) pad_mode = pad_mode.upper() self.add_prim_attr('pad_mode', pad_mode) + self.mode = validator.check_integer('mode', mode, 1, Rel.EQ, self.name) + self.group = validator.check_integer('group', group, 0, Rel.GT, self.name) self.add_prim_attr('data_format', "NCHW") if pad_list: - self.pad_lsit = (validator.check_integer('pad_list', x, 0, Rel.GE) for x in pad_list) + for x in pad_list: + validator.check_integer('element of pad_list', x, 0, Rel.GE, self.name) + self.pad_list = pad_list def __infer__(self, doutput, w, x_size): x_size_v = x_size['value'] - validator.check_type('x_size', x_size_v, [tuple]) + validator.check_value_type('x_size', x_size_v, [tuple], self.name) for i, dim_len in enumerate(x_size_v): - validator.check_type("x_size[%d]" % i, dim_len, [int]) - validator.check_typename('w_dtype', w['dtype'], [mstype.int8, mstype.int32, mstype.float16, mstype.float32]) - validator.check_two_types_same('doutput_dtype', doutput['dtype'], 'w_dtype', w['dtype']) + validator.check_value_type("x_size[%d]" % i, dim_len, [int], self.name) + args = {'doutput': doutput['dtype'], 'w': w['dtype']} + valid_types = [mstype.int8, mstype.int32, mstype.float16, mstype.float32] + validator.check_tensor_type_same(args, valid_types, self.name) # infer shape dout_shape = doutput['shape'] @@ -1173,16 +1096,15 @@ class BiasAdd(PrimitiveWithInfer): self.add_prim_attr('data_format', 'NCHW') def infer_shape(self, x_shape, b_shape): - if len(b_shape) != 1 or len(x_shape) < 2 or b_shape[0] != x_shape[1]: - raise ValueError("Input_x and bias shapes do not match", - "(require: rank of input_x must be at least 2, rank of bias must be 1, " - "input_x.dim[1] must equal bias.dim[0])," - " but got input_x shape {}, bias shape {}.".format(x_shape, b_shape)) + validator.check_integer("x rank", len(x_shape), 2, Rel.GE, self.name) + validator.check_integer("bias rank", len(b_shape), 1, Rel.EQ, self.name) + validator.check("b_shape[0]", b_shape[0], "x_shape[1]", x_shape[1], Rel.EQ, self.name) return x_shape def infer_dtype(self, x_type, b_type): - args = {"input_x type": x_type, "bias type": b_type} - validator.check_type_same(args, (mstype.float16, mstype.float32, mstype.int8, mstype.int32)) + args = {"input_x": x_type, "bias": b_type} + valid_types = (mstype.int8, mstype.int32, mstype.float16, mstype.float32) + validator.check_tensor_type_same(args, valid_types, self.name) return x_type @@ -1215,22 +1137,21 @@ class TopK(PrimitiveWithInfer): @prim_attr_register def __init__(self, sorted=False): - validator.check_type("sorted", sorted, [bool]) + validator.check_value_type("sorted", sorted, [bool], self.name) self.init_prim_io_names(inputs=['input', 'k'], outputs=['values', 'indices']) def __infer__(self, input_x, k): + x_dtype = input_x['dtype'] + valid_types = (mstype.int32, mstype.float16, mstype.float32) + validator.check_tensor_type_same({'x': x_dtype}, valid_types, self.name) + k_v = k['value'] + validator.check_value_type('k', k_v, (int,), self.name) x_shape = list(input_x['shape']) ndim = len(x_shape) - 1 - k_v = k['value'] x_shape[ndim] = k_v - input_dtype = input_x['dtype'] - validator.check_typename("TopK input_dtype", - input_dtype, (mstype.float16, mstype.float32, mstype.int32)) - if not isinstance(k_v, int): - raise ValueError('The k must int.', k) return {'shape': (x_shape, x_shape), - 'dtype': (input_dtype, mstype.int32), + 'dtype': (x_dtype, mstype.int32), 'value': None} @@ -1260,16 +1181,14 @@ class SoftmaxCrossEntropyWithLogits(PrimitiveWithInfer): pass def infer_shape(self, logits_shape, labels_shape): - validator.check_param_equal("SoftmaxCrossEntropyWithLogits logits_shape", logits_shape, - "SoftmaxCrossEntropyWithLogits labels_shape", labels_shape) + validator.check("logits_shape", logits_shape, "labels_shape", labels_shape, Rel.EQ, self.name) loss_shape = [logits_shape[0]] dlogits_shape = logits_shape return (loss_shape, dlogits_shape) def infer_dtype(self, logits_type, labels_type): - args = {"SoftmaxCrossEntropyWithLogits logits_type": logits_type, - "SoftmaxCrossEntropyWithLogits labels_type": labels_type} - validator.check_type_same(args, (mstype.float16, mstype.float32)) + args = {"logits": logits_type, "labels": labels_type} + validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) return (logits_type, logits_type) @@ -1308,18 +1227,15 @@ class SparseSoftmaxCrossEntropyWithLogits(PrimitiveWithInfer): self.add_prim_attr('sens', 1.0) def infer_shape(self, logits_shape, labels_shape): - validator.check_param_equal("SparseSoftmaxCrossEntropyWithLogits logits_shape", logits_shape[0], - "SparseSoftmaxCrossEntropyWithLogits labels_shape", labels_shape[0]) + validator.check("logits_shape[0]", logits_shape[0], "labels_shape[0]", labels_shape[0], Rel.EQ, self.name) loss_shape = [] if self.is_grad: return logits_shape return loss_shape def infer_dtype(self, logits_type, labels_type): - validator.check_typename("SparseSoftmaxCrossEntropyWithLogits logits_type", - logits_type, (mstype.float16, mstype.float32)) - validator.check_typename("SparseSoftmaxCrossEntropyWithLogits labels_type", - labels_type, (mstype.int32, mstype.int64)) + validator.check_tensor_type_same({"logits": logits_type}, (mstype.float16, mstype.float32), self.name) + validator.check_tensor_type_same({"labels": labels_type}, (mstype.int32, mstype.int64), self.name) return logits_type @@ -1364,14 +1280,13 @@ class ApplyMomentum(PrimitiveWithInfer): return v_shape def infer_dtype(self, v_dtype, a_dtype, l_dtype, g_dtype, m_dtype): + valid_types = [mstype.float16, mstype.float32, mstype.float64] if v_dtype != mstype.type_refkey and a_dtype != mstype.type_refkey: - validator.check_subclass("v_dtype", v_dtype, mstype.tensor) - validator.check_subclass("a_dtype", a_dtype, mstype.tensor) - validator.check_typename("v_dtype", v_dtype, [mstype.float16, mstype.float32, mstype.float64]) - validator.check_typename("a_dtype", a_dtype, [mstype.float16, mstype.float32, mstype.float64]) - validator.check_typename("l_dtype", l_dtype, [mstype.float16, mstype.float32, mstype.float64]) - validator.check_typename("g_dtype", g_dtype, [mstype.float16, mstype.float32, mstype.float64]) - validator.check_typename("m_dtype", m_dtype, [mstype.float16, mstype.float32, mstype.float64]) + validator.check_tensor_type_same({"v": v_dtype}, valid_types, self.name) + validator.check_tensor_type_same({"a": a_dtype}, valid_types, self.name) + validator.check_scalar_or_tensor_type_same({"l_dtype": l_dtype}, valid_types, self.name) + validator.check_scalar_or_tensor_type_same({"g_dtype": g_dtype}, valid_types, self.name) + validator.check_scalar_or_tensor_type_same({"m_dtype": m_dtype}, valid_types, self.name) return g_dtype @@ -1403,17 +1318,17 @@ class SmoothL1Loss(PrimitiveWithInfer): @prim_attr_register def __init__(self, sigma=1.0): - validator.check_type('sigma', sigma, [float]) - validator.check('sigma', sigma, '', 0, Rel.GT) + validator.check_value_type('sigma', sigma, [float], self.name) + validator.check('sigma', sigma, '', 0, Rel.GT, self.name) self.init_prim_io_names(inputs=['prediction', 'target'], outputs=['output']) def infer_shape(self, prediction, target): - validator.check_param_equal('prediction shape', prediction, 'target shape', target) + validator.check('prediction shape', prediction, 'target shape', target, Rel.EQ, self.name) return prediction def infer_dtype(self, prediction, target): args = {"prediction": prediction, "target": target} - validator.check_type_same(args, (mstype.float16, mstype.float32)) + validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) return prediction @@ -1446,29 +1361,30 @@ class SGD(PrimitiveWithInfer): @prim_attr_register def __init__(self, dampening=0.0, weight_decay=0.0, nesterov=False): - validator.check_type("nesterov", nesterov, [bool]) + validator.check_value_type("nesterov", nesterov, [bool], self.name) self.init_prim_io_names(inputs=['parameters', 'gradient', 'learning_rate', 'accum', 'momentum', 'stat'], outputs=['output']) def infer_shape(self, parameters_shape, gradient_shape, learning_rate_shape, accum_shape, momentum_shape, stat_shape): - validator.check(f'parameters shape {parameters_shape}', len(parameters_shape), '', 0, Rel.GT) - validator.check(f'gradient shape {gradient_shape}', len(gradient_shape), '', 0, Rel.GE) - validator.check(f'learning rate shape {learning_rate_shape}', len(learning_rate_shape), '', 0, Rel.GE) - validator.check(f'accumulation shape {accum_shape}', len(accum_shape), '', 0, Rel.GT) - validator.check(f'momentum shape {momentum_shape}', len(momentum_shape), '', 0, Rel.GE) - validator.check(f'stat shape {stat_shape}', len(stat_shape), '', 0, Rel.GE) - validator.check("gradient shape", gradient_shape, "stat shape", stat_shape) + validator.check_integer(f'parameters rank', len(parameters_shape), 0, Rel.GT, self.name) + validator.check_integer(f'gradient rank', len(gradient_shape), 0, Rel.GE, self.name) + validator.check_integer(f'learning rate rank', len(learning_rate_shape), 0, Rel.GE, self.name) + validator.check_integer(f'accumulation rank', len(accum_shape), 0, Rel.GT, self.name) + validator.check_integer(f'momentum rank', len(momentum_shape), 0, Rel.GE, self.name) + validator.check_integer(f'stat rank', len(stat_shape), 0, Rel.GE, self.name) + validator.check("gradient shape", gradient_shape, "stat shape", stat_shape, Rel.EQ, self.name) return parameters_shape def infer_dtype(self, parameters_dtype, gradient_dtype, learning_rate_dtype, accum_dtype, momentum_dtype, stat_dtype): - validator.check_typename("parameters_dtype", parameters_dtype, [mstype.float16, mstype.float32]) - validator.check_typename("gradient_dtype", gradient_dtype, [mstype.float16, mstype.float32]) - validator.check_typename("learning_rate_dtype", learning_rate_dtype, [mstype.float16, mstype.float32]) - validator.check_typename("accum_dtype", accum_dtype, [mstype.float16, mstype.float32]) - validator.check_typename("momentum_dtype", momentum_dtype, [mstype.float16, mstype.float32]) - validator.check_typename("stat_dtype", stat_dtype, [mstype.float16, mstype.float32]) + valid_types = [mstype.float16, mstype.float32] + validator.check_tensor_type_same({"parameters": parameters_dtype}, valid_types, self.name) + validator.check_tensor_type_same({"gradient": gradient_dtype}, valid_types, self.name) + validator.check_tensor_type_same({"learning_rate": learning_rate_dtype}, valid_types, self.name) + validator.check_tensor_type_same({"accum": accum_dtype}, valid_types, self.name) + validator.check_tensor_type_same({"momentum": momentum_dtype}, valid_types, self.name) + validator.check_tensor_type_same({"stat": stat_dtype}, valid_types, self.name) return parameters_dtype class ApplyRMSProp(PrimitiveWithInfer): @@ -1514,28 +1430,23 @@ class ApplyRMSProp(PrimitiveWithInfer): @prim_attr_register def __init__(self, use_locking=False): - self.use_locking = validator.check_type("use_locking", use_locking, [bool]) + self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name) def infer_shape(self, var_shape, mean_square_shape, moment_shape, grad_shape, learning_rate_shape, decay_shape, momentum_shape, epsilon_shape): - validator.check_param_equal("var_shape", var_shape, "mean_square_shape", mean_square_shape) - validator.check_param_equal("var_shape", var_shape, "moment_shape", moment_shape) - validator.check_param_equal("var_shape", var_shape, "grad_shape", grad_shape) + validator.check("var_shape", var_shape, "mean_square_shape", mean_square_shape, Rel.EQ, self.name) + validator.check("var_shape", var_shape, "moment_shape", moment_shape, Rel.EQ, self.name) + validator.check("var_shape", var_shape, "grad_shape", grad_shape, Rel.EQ, self.name) return var_shape def infer_dtype(self, var_dtype, mean_square_dtype, moment_dtype, grad_dtype, learning_rate_dtype, decay_dtype, momentum_dtype, epsilon_dtype): - validator.check_subclass("var_dtype", var_dtype, mstype.tensor) - validator.check_subclass("mean_square_dtype", mean_square_dtype, mstype.tensor) - validator.check_subclass("moment_dtype", moment_dtype, mstype.tensor) - validator.check_subclass("grad_dtype", moment_dtype, mstype.tensor) - args = {"var_dtype": var_dtype, "mean_square_dtype": mean_square_dtype, "moment_dtype": moment_dtype, - "grad_dtype": grad_dtype} - validator.check_type_same(args, mstype.number_type) - - args = {"learning_rate_dtype": learning_rate_dtype, "decay_dtype": decay_dtype, - 'momentum_dtype': momentum_dtype, "epsilon_dtype": epsilon_dtype} - validator.check_type_same(args, [mstype.float16, mstype.float32]) + args = {"var": var_dtype, "mean_square": mean_square_dtype, "moment": moment_dtype, "grad": grad_dtype} + validator.check_tensor_type_same(args, mstype.number_type, self.name) + + args = {"learning_rate": learning_rate_dtype, "decay": decay_dtype, + 'momentum': momentum_dtype, "epsilon": epsilon_dtype} + validator.check_scalar_or_tensor_type_same(args, [mstype.float16, mstype.float32], self.name) return var_dtype @@ -1587,30 +1498,25 @@ class ApplyCenteredRMSProp(PrimitiveWithInfer): @prim_attr_register def __init__(self, use_locking=False): - self.use_locking = validator.check_type("use_locking", use_locking, [bool]) + self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name) def infer_shape(self, var_shape, mean_gradient_shape, mean_square_shape, moment_shape, grad_shape, learning_rate_shape, decay_shape, momentum_shape, epsilon_shape): - validator.check_param_equal("var_shape", var_shape, "mean_gradient_shape", mean_gradient_shape) - validator.check_param_equal("var_shape", var_shape, "mean_square_shape", mean_square_shape) - validator.check_param_equal("var_shape", var_shape, "moment_shape", moment_shape) - validator.check_param_equal("var_shape", var_shape, "grad_shape", grad_shape) + validator.check("var_shape", var_shape, "mean_gradient_shape", mean_gradient_shape, Rel.EQ, self.name) + validator.check("var_shape", var_shape, "mean_square_shape", mean_square_shape, Rel.EQ, self.name) + validator.check("var_shape", var_shape, "moment_shape", moment_shape, Rel.EQ, self.name) + validator.check("var_shape", var_shape, "grad_shape", grad_shape, Rel.EQ, self.name) return var_shape def infer_dtype(self, var_dtype, mean_gradient_dtype, mean_square_dtype, moment_dtype, grad_dtype, learning_rate_dtype, rho_dtype, momentum_dtype, epsilon_dtype): - validator.check_subclass("var_dtype", var_dtype, mstype.tensor) - validator.check_subclass("mean_gradient_dtype", mean_gradient_dtype, mstype.tensor) - validator.check_subclass("mean_square_dtype", mean_square_dtype, mstype.tensor) - validator.check_subclass("moment_dtype", moment_dtype, mstype.tensor) - validator.check_subclass("grad_dtype", moment_dtype, mstype.tensor) - args = {"var_dtype": var_dtype, "mean_gradient_dtype": mean_gradient_dtype, - "mean_square_dtype": mean_square_dtype, "moment_dtype": moment_dtype, "grad_dtype": grad_dtype} - validator.check_type_same(args, mstype.number_type) - - args = {"learning_rate_dtype": learning_rate_dtype, "rho_dtype": rho_dtype, 'momentum_dtype': momentum_dtype, - "epsilon_dtype": epsilon_dtype} - validator.check_type_same(args, [mstype.float16, mstype.float32]) + args = {"var": var_dtype, "mean_gradient": mean_gradient_dtype, + "mean_square": mean_square_dtype, "moment": moment_dtype, "grad": grad_dtype} + validator.check_tensor_type_same(args, mstype.number_type, self.name) + + args = {"learning_rate": learning_rate_dtype, "rho": rho_dtype, 'momentum': momentum_dtype, + "epsilon": epsilon_dtype} + validator.check_scalar_or_tensor_type_same(args, [mstype.float16, mstype.float32], self.name) return var_dtype @@ -1651,8 +1557,8 @@ class LayerNorm(Primitive): @prim_attr_register def __init__(self, begin_norm_axis=1, begin_params_axis=1): - validator.check_type('begin_norm_axis', begin_norm_axis, [int]) - validator.check_type('begin_params_axis', begin_params_axis, [int]) + validator.check_value_type('begin_norm_axis', begin_norm_axis, [int], self.name) + validator.check_value_type('begin_params_axis', begin_params_axis, [int], self.name) class L2Normalize(PrimitiveWithInfer): @@ -1679,16 +1585,16 @@ class L2Normalize(PrimitiveWithInfer): @prim_attr_register def __init__(self, axis=0, epsilon=1e-4): - validator.check_type('axis', axis, [int]) - validator.check_type('epsilon', epsilon, [int, float]) + validator.check_value_type('axis', axis, [int], self.name) + validator.check_value_type('epsilon', epsilon, [int, float], self.name) def infer_shape(self, input_x): dim = len(input_x) - validator.check_int_range('axis value', self.axis, -dim, dim, Rel.INC_LEFT) + validator.check_int_range('axis value', self.axis, -dim, dim, Rel.INC_LEFT, self.name) return input_x def infer_dtype(self, input_x): - validator.check_subclass("x", input_x, mstype.tensor) + validator.check_subclass("x", input_x, mstype.tensor, self.name) return input_x @@ -1718,8 +1624,8 @@ class DropoutGenMask(Primitive): @prim_attr_register def __init__(self, Seed0=0, Seed1=0): self.init_prim_io_names(inputs=['shape', 'keep_prob'], outputs=['output']) - validator.check_type("Seed0", Seed0, [int]) - validator.check_type("Seed1", Seed1, [int]) + validator.check_value_type("Seed0", Seed0, [int], self.name) + validator.check_value_type("Seed1", Seed1, [int], self.name) class DropoutDoMask(PrimitiveWithInfer): @@ -1759,7 +1665,7 @@ class DropoutDoMask(PrimitiveWithInfer): input_x_shape = input_x['shape'] mask_shape = mask['shape'] keep_prob_shape = keep_prob['shape'] - validator.check("keep_prob's dim", len(keep_prob_shape), '0(scalar)', 0) + validator.check("keep_prob's dim", len(keep_prob_shape), '0(scalar)', 0, Rel.EQ, self.name) size_x = reduce(lambda x, y: x * y, input_x_shape) if len(mask_shape) != 1: raise ValueError("DropoutDoMask mask shape should be 1-dimension.") @@ -1768,13 +1674,13 @@ class DropoutDoMask(PrimitiveWithInfer): raise ValueError(f"DropoutDoMask y mask do not math input input_x shape:" "{input_x_shape}, mask shape: {mask_shape}.") - validator.check_typename("input_x type", input_x['dtype'], [mstype.float32, mstype.float16, mstype.int32]) - validator.check_typename("input_mask type", mask['dtype'], [mstype.uint8]) + validator.check_tensor_type_same({"input_x": input_x['dtype']}, [mstype.float32, mstype.float16, mstype.int32], + self.name) + validator.check_tensor_type_same({"input_mask": mask['dtype']}, [mstype.uint8], self.name) keep_prob_v = keep_prob['value'] if keep_prob_v is not None: - validator.check_const_input('keep_prob', keep_prob_v) - validator.check_number_range('keep_prob', keep_prob_v.asnumpy(), 0, 1, Rel.INC_BOTH) + validator.check_number_range('keep_prob', keep_prob_v.asnumpy(), 0, 1, Rel.INC_BOTH, self.name) out = {'shape': input_x_shape, 'dtype': input_x['dtype'], @@ -1858,23 +1764,20 @@ class OneHot(PrimitiveWithInfer): @prim_attr_register def __init__(self, axis=-1): self.init_prim_io_names(inputs=['indices', 'depth', 'on_value', 'off_value'], outputs=['output']) - validator.check_type("axis", axis, [int]) + validator.check_value_type("axis", axis, [int], self.name) def __infer__(self, indices, depth, on_value, off_value): # check type - validator.check_subclass("indices", indices['dtype'], mstype.tensor) - validator.check_typename("indices", indices['dtype'], (mstype.int32,)) - validator.check_typename("depth", depth['dtype'], mstype.int_type) - validator.check_subclass("on_value", on_value['dtype'], mstype.tensor) - validator.check_subclass("off_value", off_value['dtype'], mstype.tensor) - args = {"on_value dtype": on_value['dtype'], "off_value dtype": off_value['dtype']} - validator.check_type_same(args, (mstype.float16, mstype.float32)) + validator.check_tensor_type_same({"indices": indices['dtype']}, (mstype.int32,), self.name) + validator.check_type_name("depth", depth['dtype'], mstype.int_type, self.name) + args = {"on_value": on_value['dtype'], "off_value": off_value['dtype']} + validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) # check shape indices_shp = indices['shape'] - validator.check_int_range("axis", self.axis, -1, len(indices_shp), Rel.INC_BOTH) + validator.check_int_range("axis", self.axis, -1, len(indices_shp), Rel.INC_BOTH, self.name) depth_val = depth['value'] - validator.check_integer("depth", depth_val, 0, Rel.GE) + validator.check_integer("depth", depth_val, 0, Rel.GE, self.name) # create new dimension at end if self.axis is -1 indices_shp.insert(self.axis, depth_val) if self.axis >= 0 else indices_shp.append(depth_val) @@ -1919,8 +1822,7 @@ class Gelu(PrimitiveWithInfer): return input_x def infer_dtype(self, input_x): - validator.check_subclass("input_x", input_x, mstype.tensor) - validator.check_typename("input_x", input_x, (mstype.float16, mstype.float32)) + validator.check_tensor_type_same({"input_x": input_x}, (mstype.float16, mstype.float32), self.name) return input_x @@ -1953,10 +1855,10 @@ class GetNext(PrimitiveWithInfer): @prim_attr_register def __init__(self, types, shapes, output_num, shared_name): - validator.check_type("types", types, [list, tuple]) - validator.check_type("shapes", shapes, [list, tuple]) - validator.check("types length", len(types), "shapes length", len(shapes)) - validator.check_type("output_num", output_num, [int]) + validator.check_value_type("types", types, [list, tuple], self.name) + validator.check_value_type("shapes", shapes, [list, tuple], self.name) + validator.check("types length", len(types), "shapes length", len(shapes), Rel.EQ, self.name) + validator.check_value_type("output_num", output_num, [int], self.name) def infer_shape(self): return tuple(self.shapes) @@ -1997,24 +1899,22 @@ class PReLU(PrimitiveWithInfer): weight_dim = len(weight_shape) if weight_dim != 1: - raise ValueError(f'weight_dim must be 1, while weight_dim is {weight_dim}.') + raise ValueError(f'For \'{self.name}\' weight_dim must be 1, while weight_dim is {weight_dim}.') if input_x_dim == 1 and weight_shape[0] != 1: - raise ValueError(f'when input_x_dim is 1, weight_shape[0] must be 1, ' + raise ValueError(f'For \'{self.name}\' when input_x_dim is 1, weight_shape[0] must be 1, ' f'while weight_shape[0] is {weight_shape[0]}.') if input_x_dim != 1 and weight_shape[0] != input_x_shape[1] and weight_shape[0] != 1: - raise ValueError(f'channel of input_x and weight must be matched,' + raise ValueError(f'For \'{self.name}\' channel of input_x and weight must be matched,' f' while channel of input_x is {input_x_shape[1]},' f' weight_shape[0] is {weight_shape[0]}.') return input_x_shape def infer_dtype(self, input_x_dtype, weight_dtype): - validator.check_subclass("input_x_dtype", input_x_dtype, mstype.tensor) - validator.check_subclass("weight_dtype", weight_dtype, mstype.tensor) - validator.check_typename("input_x_dtype", input_x_dtype, (mstype.float16, mstype.float32)) - validator.check_typename("weight_dtype", weight_dtype, (mstype.float16, mstype.float32)) + args = {"input_x": input_x_dtype, "weight": weight_dtype} + validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) return input_x_dtype @@ -2027,13 +1927,13 @@ class LSTM(PrimitiveWithInfer): @prim_attr_register def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout): - self.input_size = check_int_positive(input_size) - self.hidden_size = check_int_positive(hidden_size) - self.num_layers = check_int_positive(num_layers) - self.has_bias = check_bool(has_bias) - self.bidirectional = check_bool(bidirectional) - self.dropout = validator.check_type("dropout", dropout, [float]) - self.dropout = validator.check_number_range('dropout', dropout, 0, 1, Rel.INC_BOTH) + self.input_size = validator.check_integer("input_size", input_size, 0, Rel.GT, self.name) + self.hidden_size = validator.check_integer("hidden_size", hidden_size, 0, Rel.GT, self.name) + self.num_layers = validator.check_integer("num_layers", num_layers, 0, Rel.GT, self.name) + self.has_bias = validator.check_value_type("has_bias", has_bias, (bool,), self.name) + self.bidirectional = validator.check_value_type("bidirectional", bidirectional, (bool,), self.name) + self.dropout = validator.check_value_type("dropout", dropout, [float], self.name) + self.dropout = validator.check_number_range('dropout', dropout, 0, 1, Rel.INC_BOTH, self.name) if bidirectional: self.num_directions = 2 @@ -2042,19 +1942,16 @@ class LSTM(PrimitiveWithInfer): def infer_shape(self, x_shape, h_shape, c_shape, w_shape): # (batch, seq, feature) - validator.check_integer("x_shape", len(x_shape), 3, Rel.EQ) + validator.check_integer("x rank", len(x_shape), 3, Rel.EQ, self.name) # h and c should be same shape - validator.check_integer("h_shape", len(h_shape), 3, Rel.EQ) - validator.check_integer("h_shape", len(h_shape), len(c_shape), Rel.EQ) - validator.check_integer("h_shape", h_shape[0], c_shape[0], Rel.EQ) - validator.check_integer("h_shape", h_shape[1], c_shape[1], Rel.EQ) - validator.check_integer("h_shape", h_shape[2], c_shape[2], Rel.EQ) + validator.check_integer("h rank", len(h_shape), 3, Rel.EQ, self.name) + validator.check("h_shape", h_shape, "c_shape", c_shape, Rel.EQ, self.name) # (num_layers * num_directions, batch, hidden_size) - validator.check_integer("h[0]", h_shape[0], self.num_layers * self.num_directions, Rel.EQ) - validator.check_integer("h[1]", h_shape[1], x_shape[1], Rel.EQ) - validator.check_integer("h[2]", h_shape[2], self.hidden_size, Rel.EQ) + validator.check_integer("h[0]", h_shape[0], self.num_layers * self.num_directions, Rel.EQ, self.name) + validator.check_integer("h[1]", h_shape[1], x_shape[1], Rel.EQ, self.name) + validator.check_integer("h[2]", h_shape[2], self.hidden_size, Rel.EQ, self.name) y_shape = (x_shape[0], x_shape[1], self.hidden_size * self.num_directions) @@ -2064,13 +1961,8 @@ class LSTM(PrimitiveWithInfer): return (y_shape, h_shape, c_shape, reserved_shape, state_shape) def infer_dtype(self, x_dtype, h_dtype, c_dtype, w_dtype): - validator.check_typename("x_dtype", x_dtype, (mstype.float32, mstype.float16)) - validator.check_typename("h_dtype", h_dtype, (mstype.float32, mstype.float16)) - validator.check_typename("c_dtype", c_dtype, (mstype.float32, mstype.float16)) - validator.check_typename("w_dtype", w_dtype, (mstype.float32, mstype.float16)) - validator.check_typename("datatype", x_dtype, (h_dtype.element_type(),)) - validator.check_typename("datatype", x_dtype, (c_dtype.element_type(),)) - validator.check_typename("datatype", x_dtype, (w_dtype.element_type(),)) + args = {'x': x_dtype, 'h': h_dtype, 'c': c_dtype, 'w': w_dtype} + validator.check_tensor_type_same(args, (mstype.float32, mstype.float16), self.name) return (x_dtype, x_dtype, x_dtype, x_dtype, x_dtype) @@ -2101,12 +1993,12 @@ class SigmoidCrossEntropyWithLogits(PrimitiveWithInfer): self.init_prim_io_names(inputs=['predict', 'target'], outputs=['loss']) def infer_shape(self, x_shape, y_shape): - validator.check_param_equal("x_shape", x_shape, "y_shape", y_shape) + validator.check("x_shape", x_shape, "y_shape", y_shape, Rel.EQ, self.name) return x_shape def infer_dtype(self, x_dtype, y_dtype): args = {"x_dtype": x_dtype, "y_dtype": y_dtype} - validator.check_type_same(args, mstype.number_type) + validator.check_tensor_type_same(args, mstype.number_type, self.name) return x_dtype @@ -2150,7 +2042,7 @@ class Pad(PrimitiveWithInfer): def infer_shape(self, x): paddings = np.array(self.paddings) - validator.check_integer('paddings.shape', paddings.size, len(x) * 2, Rel.EQ) + validator.check_integer('paddings.shape', paddings.size, len(x) * 2, Rel.EQ, self.name) if not np.all(paddings >= 0): raise ValueError('All elements of paddings must be >= 0.') y_shape = () @@ -2159,7 +2051,7 @@ class Pad(PrimitiveWithInfer): return y_shape def infer_dtype(self, x): - validator.check_subclass("input_x", x, mstype.tensor) + validator.check_subclass("input_x", x, mstype.tensor, self.name) return x @@ -2210,16 +2102,16 @@ class MirrorPad(PrimitiveWithInfer): @prim_attr_register def __init__(self, mode='REFLECT'): """Init Pad""" - validator.check_string('mode', mode, ['REFLECT', 'SYMMETRIC']) + validator.check_string('mode', mode, ['REFLECT', 'SYMMETRIC'], self.name) self.mode = mode def __infer__(self, input_x, paddings): - validator.check_subclass("input_x", input_x['dtype'], mstype.tensor) - validator.check_subclass("paddings", paddings['dtype'], mstype.tensor) + validator.check_subclass("input_x", input_x['dtype'], mstype.tensor, self.name) + validator.check_subclass("paddings", paddings['dtype'], mstype.tensor, self.name) x_shape = list(input_x['shape']) paddings_value = paddings['value'].asnumpy() paddings_size = paddings_value.size - validator.check_integer('paddings.shape', paddings_size, len(x_shape) * 2, Rel.EQ) + validator.check_integer('paddings.shape', paddings_size, len(x_shape) * 2, Rel.EQ, self.name) if not np.all(paddings_size >= 0): raise ValueError('All elements of paddings must be >= 0.') y_shape = () @@ -2270,10 +2162,10 @@ class ROIAlign(PrimitiveWithInfer): @prim_attr_register def __init__(self, pooled_height, pooled_width, spatial_scale, sample_num=2): """init ROIAlign""" - validator.check_type("pooled_height", pooled_height, [int]) - validator.check_type("pooled_width", pooled_width, [int]) - validator.check_type("spatial_scale", spatial_scale, [float]) - validator.check_type("sample_num", sample_num, [int]) + validator.check_value_type("pooled_height", pooled_height, [int], self.name) + validator.check_value_type("pooled_width", pooled_width, [int], self.name) + validator.check_value_type("spatial_scale", spatial_scale, [float], self.name) + validator.check_value_type("sample_num", sample_num, [int], self.name) self.pooled_height = pooled_height self.pooled_width = pooled_width self.spatial_scale = spatial_scale @@ -2338,24 +2230,24 @@ class Adam(PrimitiveWithInfer): @prim_attr_register def __init__(self, use_locking=False, use_nesterov=False): - validator.check_type("use_locking", use_locking, [bool]) - validator.check_type("use_nesterov", use_nesterov, [bool]) + validator.check_value_type("use_locking", use_locking, [bool], self.name) + validator.check_value_type("use_nesterov", use_nesterov, [bool], self.name) def infer_shape(self, var_shape, m_shape, v_shape, beta1_power_shape, beta2_power_shape, lr_shape, beta1_shape, beta2_shape, epsilon_shape, grad_shape): - validator.check_param_equal("var_shape", var_shape, "m_shape", m_shape) - validator.check_param_equal("var_shape", var_shape, "v_shape", v_shape) - validator.check_param_equal("var_shape", var_shape, "grad_shape", grad_shape) + validator.check("var_shape", var_shape, "m_shape", m_shape, Rel.EQ, self.name) + validator.check("var_shape", var_shape, "v_shape", v_shape, Rel.EQ, self.name) + validator.check("var_shape", var_shape, "grad_shape", grad_shape, Rel.EQ, self.name) return var_shape, m_shape, v_shape def infer_dtype(self, var_dtype, m_dtype, v_dtype, beta1_power_dtype, beta2_power_dtype, lr_dtype, beta1_dtype, beta2_dtype, epsilon_dtype, grad_dtype): - args = {"var_dtype": var_dtype, "m_dtype": m_dtype, "v_dtype": v_dtype, "grad_dtype": grad_dtype} - validator.check_type_same(args, mstype.number_type) + args = {"var": var_dtype, "m": m_dtype, "v": v_dtype, "grad": grad_dtype} + validator.check_tensor_type_same(args, mstype.number_type, self.name) - args = {"beta1_power_dtype": beta1_power_dtype, "beta2_power_dtype": beta2_power_dtype, 'lr_dtype': lr_dtype, - "beta1_dtype": beta1_dtype, "beta2_dtype": beta2_dtype, "epsilon_dtype": epsilon_dtype} - validator.check_type_same(args, [mstype.float16, mstype.float32]) + args = {"beta1_power": beta1_power_dtype, "beta2_power": beta2_power_dtype, 'lr': lr_dtype, + "beta1": beta1_dtype, "beta2": beta2_dtype, "epsilon": epsilon_dtype} + validator.check_scalar_or_tensor_type_same(args, [mstype.float16, mstype.float32], self.name, True) return var_dtype, m_dtype, v_dtype @@ -2397,12 +2289,12 @@ class BinaryCrossEntropy(PrimitiveWithInfer): @prim_attr_register def __init__(self, reduction='mean'): - self.reduction = validator.check_string('reduction', reduction, ['none', 'mean', 'sum']) + self.reduction = validator.check_string('reduction', reduction, ['none', 'mean', 'sum'], self.name) def infer_shape(self, x_shape, y_shape, weight_shape): - validator.check_param_equal('x_shape', x_shape, 'y_shape', y_shape) + validator.check('x_shape', x_shape, 'y_shape', y_shape, Rel.EQ, self.name) if weight_shape: - validator.check_param_equal('y_shape', y_shape, 'weight_shape', weight_shape) + validator.check('y_shape', y_shape, 'weight_shape', weight_shape, Rel.EQ, self.name) if self.reduction in ('mean', 'sum'): shape = [] else: @@ -2410,10 +2302,11 @@ class BinaryCrossEntropy(PrimitiveWithInfer): return shape def infer_dtype(self, x_type, y_type, weight_type): - args = {'x_type': x_type, 'y_type': y_type} - validator.check_type_same(args, (mstype.float16, mstype.float32)) + args = {'x': x_type, 'y': y_type} + valid_types = (mstype.float16, mstype.float32) + validator.check_tensor_type_same(args, valid_types, self.name) if weight_type: - validator.check_two_types_same('x_type', x_type, 'weight_type', weight_type) + validator.check_tensor_type_same({'x': x_type, 'weight': weight_type}, valid_types, self.name) return x_type @@ -2445,27 +2338,22 @@ class SparseApplyAdagrad(PrimitiveWithInfer): @prim_attr_register def __init__(self, lr, use_locking=False): - self.lr = validator.check_type("lr", lr, [float]) - self.use_locking = validator.check_type("use_locking", use_locking, [bool]) + self.lr = validator.check_value_type("lr", lr, [float], self.name) + self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name) def infer_shape(self, var_shape, accum_shape, grad_shape, indices_shape): - validator.check_param_equal('var shape', var_shape, 'accum shape', accum_shape) - validator.check_param_equal('len of var shape', len(var_shape), 'len of grad shape', len(grad_shape)) + validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name) + validator.check('len of var shape', len(var_shape), 'len of grad shape', len(grad_shape), Rel.EQ, self.name) if len(var_shape) > 1: - validator.check_param_equal('var_shape', var_shape[1:], 'grad_shape', grad_shape[1:]) - validator.check_integer("len of indices shape", len(indices_shape), 1, Rel.EQ) - validator.check('the first dimension of grad', grad_shape[0], - 'the shape of indices', indices_shape[0], Rel.EQ) + validator.check('var_shape[1:]', var_shape[1:], 'grad_shape[1:]', grad_shape[1:], Rel.EQ, self.name) + validator.check_integer("indices rank", len(indices_shape), 1, Rel.EQ, self.name) + validator.check('grad_shape[0]', grad_shape[0], 'indices_shape[0]', indices_shape[0], Rel.EQ, self.name) return var_shape def infer_dtype(self, var_type, accum_type, grad_type, indices_type): - validator.check_subclass("var_type", var_type, mstype.tensor) - validator.check_subclass("accum_type", accum_type, mstype.tensor) - validator.check_subclass("grad_type", grad_type, mstype.tensor) - validator.check_subclass("indices_type", indices_type, mstype.tensor) - args = {'var_type': var_type, 'accum_type': accum_type, 'grad_type': grad_type} - validator.check_type_same(args, (mstype.float32,)) - validator.check_typename('indices_type', indices_type, [mstype.int32]) + args = {'var': var_type, 'accum': accum_type, 'grad': grad_type} + validator.check_tensor_type_same(args, (mstype.float32,), self.name) + validator.check_tensor_type_same({'indices': indices_type}, [mstype.int32], self.name) return var_type @@ -2493,34 +2381,34 @@ class LARSUpdate(PrimitiveWithInfer): @prim_attr_register def __init__(self, epsilon=1e-05, hyperpara=0.001, use_clip=False): """init""" - validator.check_type("epsilon", epsilon, [float]) - validator.check_type("hyperpara", hyperpara, [float]) - validator.check_type("use_clip", use_clip, [bool]) + validator.check_value_type("epsilon", epsilon, [float], self.name) + validator.check_value_type("hyperpara", hyperpara, [float], self.name) + validator.check_value_type("use_clip", use_clip, [bool], self.name) def infer_shape(self, weight_shape, gradient_shape, norm_weight_shape, norm_gradient_shape, weight_decay_shape, learning_rate_shape): - validator.check_param_equal("Weight shape", weight_shape, "gradient shape", gradient_shape) - validator.check_param_equal("Norm weight shape", norm_weight_shape, "norm gradient shape", norm_gradient_shape) + validator.check("weight shape", weight_shape, "gradient shape", gradient_shape, Rel.EQ, self.name) + validator.check("norm weight shape", norm_weight_shape, "norm gradient shape", norm_gradient_shape, Rel.EQ, + self.name) shp_len = len(weight_decay_shape) - validator.check_shape_length("Weight decay's shape", shp_len, 1, Rel.LE) + validator.check_integer("weight decay's rank", shp_len, 1, Rel.LE, self.name) if shp_len == 1: - validator.check_integer("Weight decay's shape", weight_decay_shape[0], 1, Rel.EQ) + validator.check_integer("weight_decay_shape[0]", weight_decay_shape[0], 1, Rel.EQ, self.name) shp_len = len(learning_rate_shape) - validator.check_shape_length("Learning rate's shape", shp_len, 1, Rel.LE) + validator.check_integer("learning rate's rank", shp_len, 1, Rel.LE, self.name) if shp_len == 1: - validator.check_integer("Learning rate's shape", learning_rate_shape[0], 1, Rel.EQ) + validator.check_integer("learning_rate_shape[0]", learning_rate_shape[0], 1, Rel.EQ, self.name) return weight_shape def infer_dtype(self, weight_dtype, gradient_dtype, norm_weight_dtype, norm_gradient_dtype, weight_decay_dtype, learning_rate_dtype): args = {"Weight dtype": weight_dtype, "gradient dtype": gradient_dtype, "norm weight dtype": norm_weight_dtype, "norm gradient dtype": norm_gradient_dtype} - validator.check_type_same(args, [mstype.float16, mstype.float32, mstype.int16, mstype.int32]) - validator.check_args_tensor(args) - validator.check_typename("weight_decay_dtype", weight_decay_dtype, - [mstype.float16, mstype.float32, mstype.float64]) - validator.check_typename("learning_rate_dtype", learning_rate_dtype, - [mstype.float16, mstype.float32, mstype.float64]) + validator.check_tensor_type_same(args, [mstype.float16, mstype.float32, mstype.int16, mstype.int32], self.name) + validator.check_scalar_or_tensor_type_same({"weight_decay": weight_decay_dtype}, + [mstype.float16, mstype.float32, mstype.float64], self.name) + validator.check_scalar_or_tensor_type_same({"learning_rate": learning_rate_dtype}, + [mstype.float16, mstype.float32, mstype.float64], self.name) return weight_dtype @@ -2553,26 +2441,23 @@ class ApplyFtrl(PrimitiveWithInfer): def __init__(self, use_locking=False): self.init_prim_io_names(inputs=['var', 'accum', 'linear', 'grad', 'lr', 'l1', 'l2', 'lr_power'], outputs=['output']) - self.use_locking = validator.check_type("use_locking", use_locking, [bool]) + self.use_locking = validator.check_value_type("use_locking", use_locking, [bool], self.name) def infer_shape(self, var_shape, accum_shape, linear_shape, grad_shape, lr_shape, l1_shape, l2_shape, lr_power_shape): - validator.check_param_equal('var shape', var_shape, 'accum shape', accum_shape) - validator.check_param_equal('var shape', var_shape, 'linear shape', linear_shape) + validator.check('var shape', var_shape, 'accum shape', accum_shape, Rel.EQ, self.name) + validator.check('var shape', var_shape, 'linear shape', linear_shape, Rel.EQ, self.name) return var_shape def infer_dtype(self, var_type, accum_type, linear_type, grad_type, lr_type, l1_type, l2_type, lr_power_type): - validator.check_subclass("var_type", var_type, mstype.tensor) - validator.check_subclass("accum_type", accum_type, mstype.tensor) - validator.check_subclass("linear_type", linear_type, mstype.tensor) - validator.check_subclass("grad_type", grad_type, mstype.tensor) - args = {'var_type': var_type, 'accum_type': accum_type, 'linear_type': linear_type, 'grad_type': grad_type} - validator.check_type_same(args, (mstype.float32, mstype.float16)) - - validator.check_typename("lr", lr_type, [mstype.float16, mstype.float32]) - validator.check_typename("l1", l1_type, [mstype.float16, mstype.float32]) - validator.check_typename("l2", l2_type, [mstype.float16, mstype.float32]) - validator.check_typename("lr_power", lr_power_type, [mstype.float16, mstype.float32]) + valid_types = [mstype.float16, mstype.float32] + args = {'var': var_type, 'accum': accum_type, 'linear': linear_type, 'grad': grad_type} + validator.check_tensor_type_same(args, valid_types, self.name) + + validator.check_scalar_or_tensor_type_same({"lr": lr_type}, valid_types, self.name) + validator.check_scalar_or_tensor_type_same({"l1": l1_type}, valid_types, self.name) + validator.check_scalar_or_tensor_type_same({"l2": l2_type}, valid_types, self.name) + validator.check_scalar_or_tensor_type_same({"lr_power": lr_power_type}, valid_types, self.name) return var_type @@ -2607,36 +2492,22 @@ class ExtractImagePatches(PrimitiveWithInfer): @prim_attr_register def __init__(self, ksizes, strides, rates, padding="valid"): """init""" - validator.check_type("ksizes", ksizes, [tuple, list]) - validator.check_type("strides", strides, [tuple, list]) - validator.check_type("rates", rates, [tuple, list]) - self.padding = validator.check_string('padding', padding.upper(), ['VALID', 'SAME']) + def _check_tuple_or_list(arg_name, arg_val, prim_name): + validator.check_value_type(f"{arg_name}s", ksizes, [tuple, list], self.name) + if len(arg_val) != 4 or arg_val[0] != 1 or arg_val[3] != 1: + raise ValueError(f"For \'{prim_name}\' the format of {arg_name}s should be [1, {arg_name}_row, " + f"{arg_name}_col, 1], but got {arg_val}.") + if not isinstance(arg_val[1], int) or not isinstance(arg_val[2], int) or arg_val[1] < 1 or arg_val[2] < 1: + raise ValueError(f"For '{prim_name}' the {arg_name}_row and {arg_name}_col in {arg_name}s should be an " + f"positive integer number, but got {arg_name}_row is {arg_val[1]}, {arg_name}_col " + f"is {arg_val[2]}") + + _check_tuple_or_list("ksize", ksizes, self.name) + _check_tuple_or_list("stride", strides, self.name) + _check_tuple_or_list("rate", rates, self.name) + self.padding = validator.check_string('padding', padding.upper(), ['VALID', 'SAME'], self.name) self.add_prim_attr("padding", self.padding) - if len(ksizes) != 4 or ksizes[0] != 1 or ksizes[3] != 1: - raise ValueError("The format of ksizes should be [1, ksize_row, ksize_col, 1], " - f"but got {ksizes}.") - if not isinstance(ksizes[1], int) or not isinstance(ksizes[2], int) or \ - ksizes[1] < 1 or ksizes[2] < 1: - raise ValueError("The ksize_row and ksize_col in ksizes should be an positive integer number, " - f"but got ksize_row is {ksizes[1]}, ksize_col is {ksizes[2]}") - - if len(strides) != 4 or strides[0] != 1 or strides[3] != 1: - raise ValueError("The format of strides should be [1, stride_row, stride_col, 1], " - f"but got {strides}.") - if not isinstance(strides[1], int) or not isinstance(strides[2], int) or \ - strides[1] < 1 or strides[2] < 1: - raise ValueError("The stride_row and stride_col in strides should be an positive integer number, " - f"but got stride_row is {strides[1]}, stride_col is {strides[2]}") - - if len(rates) != 4 or rates[0] != 1 or rates[3] != 1: - raise ValueError("The format of rates should be [1, rate_row, rate_col, 1], " - f"but got {rates}.") - if not isinstance(rates[1], int) or not isinstance(rates[2], int) or \ - rates[1] < 1 or rates[2] < 1: - raise ValueError("The rate_row and rate_col in rates should be an positive integer number, " - f"but got rate_row is {rates[1]}, rate_col is {rates[2]}") - def infer_shape(self, input_x): in_batch, in_row, in_col, in_depth = input_x _, ksize_row, ksize_col, _ = self.ksizes @@ -2662,6 +2533,5 @@ class ExtractImagePatches(PrimitiveWithInfer): return out_shape def infer_dtype(self, input_x): - validator.check_subclass("input_x", input_x, mstype.tensor) - validator.check_typename("input_x_dtype", input_x, (mstype.int8, mstype.float16, mstype.float32)) + validator.check_tensor_type_same({"input_x": input_x}, (mstype.int8, mstype.float16, mstype.float32), self.name) return input_x diff --git a/tests/ut/python/nn/test_dynamic_lr.py b/tests/ut/python/nn/test_dynamic_lr.py index 96f9d5afde..8d03be1766 100644 --- a/tests/ut/python/nn/test_dynamic_lr.py +++ b/tests/ut/python/nn/test_dynamic_lr.py @@ -41,7 +41,7 @@ class TestInputs: dr.piecewise_constant_lr(milestone1, learning_rates) milestone2 = [1.0, 2.0, True] - with pytest.raises(ValueError): + with pytest.raises(TypeError): dr.piecewise_constant_lr(milestone2, learning_rates) def test_learning_rates1(self): @@ -92,13 +92,13 @@ class TestInputs: def test_total_step1(self): total_step1 = 2.0 - with pytest.raises(ValueError): + with pytest.raises(TypeError): dr.exponential_decay_lr(learning_rate, decay_rate, total_step1, step_per_epoch, decay_epoch) - with pytest.raises(ValueError): + with pytest.raises(TypeError): dr.cosine_decay_lr(min_lr, max_lr, total_step1, step_per_epoch, decay_epoch) - with pytest.raises(ValueError): + with pytest.raises(TypeError): dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step1, step_per_epoch, decay_epoch, power) def test_total_step2(self): @@ -114,13 +114,13 @@ class TestInputs: def test_step_per_epoch1(self): step_per_epoch1 = True - with pytest.raises(ValueError): + with pytest.raises(TypeError): dr.exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch1, decay_epoch) - with pytest.raises(ValueError): + with pytest.raises(TypeError): dr.cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch1, decay_epoch) - with pytest.raises(ValueError): + with pytest.raises(TypeError): dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch1, decay_epoch, power) def test_step_per_epoch2(self): @@ -136,13 +136,13 @@ class TestInputs: def test_decay_epoch1(self): decay_epoch1 = 'm' - with pytest.raises(ValueError): + with pytest.raises(TypeError): dr.exponential_decay_lr(learning_rate, decay_rate, total_step, step_per_epoch, decay_epoch1) - with pytest.raises(ValueError): + with pytest.raises(TypeError): dr.cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch1) - with pytest.raises(ValueError): + with pytest.raises(TypeError): dr.polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_epoch, decay_epoch1, power) def test_decay_epoch2(self): diff --git a/tests/ut/python/nn/test_ssim.py b/tests/ut/python/nn/test_ssim.py index cf946a1617..77d065b100 100644 --- a/tests/ut/python/nn/test_ssim.py +++ b/tests/ut/python/nn/test_ssim.py @@ -60,7 +60,7 @@ def test_ssim_max_val_zero(): net = SSIMNet(max_val) def test_ssim_filter_size_float(): - with pytest.raises(ValueError): + with pytest.raises(TypeError): net = SSIMNet(filter_size=1.1) def test_ssim_filter_size_zero(): diff --git a/tests/ut/python/ops/test_nn_ops.py b/tests/ut/python/ops/test_nn_ops.py index 09a4248c19..ab6f31095d 100644 --- a/tests/ut/python/ops/test_nn_ops.py +++ b/tests/ut/python/ops/test_nn_ops.py @@ -516,7 +516,7 @@ test_cases = [ test_cases_for_verify_exception = [ ('Conv2d_ValueError_1', { - 'block': (lambda _: P.Conv2D(3, 4, mode=-2.0), {'exception': ValueError}), + 'block': (lambda _: P.Conv2D(3, 4, mode=-2.0), {'exception': TypeError}), 'desc_inputs': [0], }), ('Conv2d_ValueError_2', { @@ -528,7 +528,7 @@ test_cases_for_verify_exception = [ 'desc_inputs': [0], }), ('MaxPoolWithArgmax_ValueError_2', { - 'block': (lambda _: P.MaxPoolWithArgmax(ksize='1'), {'exception': ValueError}), + 'block': (lambda _: P.MaxPoolWithArgmax(ksize='1'), {'exception': TypeError}), 'desc_inputs': [0], }), ('MaxPoolWithArgmax_ValueError_3', { @@ -540,7 +540,7 @@ test_cases_for_verify_exception = [ 'desc_inputs': [0], }), ('FusedBatchNorm_ValueError_1', { - 'block': (lambda _: P.FusedBatchNorm(mode="1", epsilon=1e-5, momentum=0.1), {'exception': ValueError}), + 'block': (lambda _: P.FusedBatchNorm(mode="1", epsilon=1e-5, momentum=0.1), {'exception': TypeError}), 'desc_inputs': [0], }), ('FusedBatchNorm_ValueError_2', { @@ -560,31 +560,31 @@ test_cases_for_verify_exception = [ 'desc_inputs': [0], }), ('Softmax_ValueError_1', { - 'block': (lambda _: P.Softmax("1"), {'exception': ValueError}), + 'block': (lambda _: P.Softmax("1"), {'exception': TypeError}), 'desc_inputs': [0], }), ('Softmax_ValueError_2', { - 'block': (lambda _: P.Softmax(1.1), {'exception': ValueError}), + 'block': (lambda _: P.Softmax(1.1), {'exception': TypeError}), 'desc_inputs': [0], }), ('Softmax_ValueError_3', { - 'block': (lambda _: P.Softmax(axis="1"), {'exception': ValueError}), + 'block': (lambda _: P.Softmax(axis="1"), {'exception': TypeError}), 'desc_inputs': [0], }), ('DropoutGenMask_ValueError_1', { - 'block': (lambda _: P.DropoutGenMask(Seed0="seed0"), {'exception': ValueError}), + 'block': (lambda _: P.DropoutGenMask(Seed0="seed0"), {'exception': TypeError}), 'desc_inputs': [0], }), ('DropoutGenMask_ValueError_2', { - 'block': (lambda _: P.DropoutGenMask(Seed0=1.0), {'exception': ValueError}), + 'block': (lambda _: P.DropoutGenMask(Seed0=1.0), {'exception': TypeError}), 'desc_inputs': [0], }), ('DropoutGenMask_ValueError_3', { - 'block': (lambda _: P.DropoutGenMask(Seed1="seed1"), {'exception': ValueError}), + 'block': (lambda _: P.DropoutGenMask(Seed1="seed1"), {'exception': TypeError}), 'desc_inputs': [0], }), ('DropoutGenMask_ValueError_4', { - 'block': (lambda _: P.DropoutGenMask(Seed1=2.0), {'exception': ValueError}), + 'block': (lambda _: P.DropoutGenMask(Seed1=2.0), {'exception': TypeError}), 'desc_inputs': [0], }), ('MaxPool2d_ValueError_1', { diff --git a/tests/ut/python/ops/test_nn_ops_check.py b/tests/ut/python/ops/test_nn_ops_check.py new file mode 100755 index 0000000000..c2a751aa0c --- /dev/null +++ b/tests/ut/python/ops/test_nn_ops_check.py @@ -0,0 +1,463 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" test ops """ +import functools +import numpy as np +from mindspore import ops +from mindspore.ops import functional as F +from mindspore.ops import operations as P +from mindspore.ops import composite as C +from mindspore.ops.operations import _grad_ops as G +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.common import dtype as mstype +from mindspore.common.parameter import Parameter +from ..ut_filter import non_graph_engine +from mindspore.common.api import _executor + +from ....mindspore_test_framework.mindspore_test import mindspore_test +from ....mindspore_test_framework.pipeline.forward.compile_forward\ + import (pipeline_for_compile_forward_ge_graph_for_case_by_case_config, + pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception) +from ....mindspore_test_framework.pipeline.gradient.compile_gradient\ + import pipeline_for_compile_grad_ge_graph_for_case_by_case_config + + +class Conv2DBackpropInputNet(nn.Cell): + def __init__(self, net, x_shape): + super(Conv2DBackpropInputNet, self).__init__() + self.net = net + self.x_shape = x_shape + + def construct(self, dout, w): + return self.net(dout, w, self.x_shape) + + +class TopKNet(nn.Cell): + def __init__(self, net, k): + super(TopKNet, self).__init__() + self.net = net + self.k = k + + def construct(self, x): + return self.net(x, self.k) + + +raise_set = [ + # input is scalar + ('Flatten0', { + 'block': (P.Flatten(), {'exception': TypeError, 'error_keywords': ['Flatten']}), + 'desc_inputs': [5.0], + 'skip': ['backward']}), + # dim of input is zero + ('Flatten1', { + 'block': (P.Flatten(), {'exception': ValueError, 'error_keywords': ['Flatten']}), + 'desc_inputs': [F.scalar_to_tensor(5.0)], + 'skip': ['backward']}), + + # input is scalar + ('Softmax0', { + 'block': (P.Softmax(), {'exception': TypeError, 'error_keywords': ['Softmax']}), + 'desc_inputs': [5.0], + 'skip': ['backward']}), + # axis is empty tuple + ('Softmax1', { + 'block': (P.Softmax(axis=()), {'exception': ValueError, 'error_keywords': ['Softmax']}), + 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32))], + 'skip': ['backward']}), + # axis value is not in range + ('Softmax2', { + 'block': (P.Softmax(axis=2), {'exception': ValueError, 'error_keywords': ['Softmax']}), + 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32))], + 'skip': ['backward']}), + + # input is scalar + ('LogSoftmax0', { + 'block': (P.LogSoftmax(), {'exception': TypeError, 'error_keywords': ['LogSoftmax']}), + 'desc_inputs': [5.0], + 'skip': ['backward']}), + # axis value is not in range + ('LogSoftmax1', { + 'block': (P.LogSoftmax(axis=2), {'exception': ValueError, 'error_keywords': ['LogSoftmax']}), + 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32))], + 'skip': ['backward']}), + + # input is scalar + ('ReLU0', { + 'block': (P.ReLU(), {'exception': TypeError, 'error_keywords': ['ReLU']}), + 'desc_inputs': [5.0], + 'skip': ['backward']}), + # input is Tensor(Bool) + ('ReLU1', { + 'block': (P.ReLU(), {'exception': TypeError, 'error_keywords': ['ReLU']}), + 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.bool_))], + 'skip': ['backward']}), + + # input is scalar + ('ReLU60', { + 'block': (P.ReLU6(), {'exception': TypeError, 'error_keywords': ['ReLU6']}), + 'desc_inputs': [5.0], + 'skip': ['backward']}), + # input is Tensor(int32) + ('ReLU61', { + 'block': (P.ReLU6(), {'exception': TypeError, 'error_keywords': ['ReLU6']}), + 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32))], + 'skip': ['backward']}), + + # input is scalar + ('Elu0', { + 'block': (P.Elu(), {'exception': TypeError, 'error_keywords': ['Elu']}), + 'desc_inputs': [5.0], + 'skip': ['backward']}), + # input is Tensor(int32) + ('Elu1', { + 'block': (P.Elu(alpha=0.9), {'exception': TypeError, 'error_keywords': ['Elu']}), + 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32))], + 'skip': ['backward']}), + + # input is scalar + ('Sigmoid0', { + 'block': (P.Sigmoid(), {'exception': TypeError, 'error_keywords': ['Sigmoid']}), + 'desc_inputs': [5.0], + 'skip': ['backward']}), + # input is Tensor(int32) + ('Sigmoid1', { + 'block': (P.Sigmoid(), {'exception': TypeError, 'error_keywords': ['Sigmoid']}), + 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.int32))], + 'skip': ['backward']}), + + # input is scalar + ('Tanh0', { + 'block': (P.Tanh(), {'exception': TypeError, 'error_keywords': ['Tanh']}), + 'desc_inputs': [5.0], + 'skip': ['backward']}), + + # input is scalar + ('BatchNorm0', { + 'block': (P.BatchNorm(is_training=False), {'exception': TypeError, 'error_keywords': ['BatchNorm']}), + 'desc_inputs': [5.0, 5.0, 5.0, 5.0, 5.0], + 'skip': ['backward']}), + # is_training=False and mean=None + ('BatchNorm1', { + 'block': (P.BatchNorm(is_training=False), {'exception': TypeError, 'error_keywords': ['BatchNorm']}), + 'desc_inputs': [Tensor(np.ones([5, 3]).astype(np.float32)), Tensor(np.ones([5, 3]).astype(np.float32)), + Tensor(np.ones([5, 3]).astype(np.float32)), None, None], + 'skip': ['backward']}), + # is_training=True and mean=None + ('BatchNorm2', { + 'block': (P.BatchNorm(is_training=True), {'exception': TypeError, 'error_keywords': ['BatchNorm']}), + 'desc_inputs': [Tensor(np.ones([5, 3]).astype(np.float32)), Tensor(np.ones([3]).astype(np.float32)), + Tensor(np.ones([3]).astype(np.float32)), Tensor(np.ones([3]).astype(np.float16)), + Tensor(np.ones([3]).astype(np.float32))], + 'skip': ['backward']}), + # scale and bias rank > 1 + ('BatchNorm3', { + 'block': (P.BatchNorm(is_training=True), {'exception': ValueError, 'error_keywords': ['BatchNorm']}), + 'desc_inputs': [Tensor(np.ones([5, 3]).astype(np.float32)), Tensor(np.ones([5, 3]).astype(np.float32)), + Tensor(np.ones([5, 3]).astype(np.float32)), Tensor(np.ones([3]).astype(np.float32)), + Tensor(np.ones([3]).astype(np.float32))], + 'skip': ['backward']}), + # scale and bias shape not match + ('BatchNorm4', { + 'block': (P.BatchNorm(is_training=True), {'exception': ValueError, 'error_keywords': ['BatchNorm']}), + 'desc_inputs': [Tensor(np.ones([5, 3]).astype(np.float32)), Tensor(np.ones([3]).astype(np.float32)), + Tensor(np.ones([7]).astype(np.float32)), Tensor(np.ones([3]).astype(np.float32)), + Tensor(np.ones([3]).astype(np.float32))], + 'skip': ['backward']}), + # is_training=False, mean and variance shape not match + ('BatchNorm5', { + 'block': (P.BatchNorm(is_training=False), {'exception': ValueError, 'error_keywords': ['BatchNorm']}), + 'desc_inputs': [Tensor(np.ones([5, 3]).astype(np.float32)), Tensor(np.ones([3]).astype(np.float32)), + Tensor(np.ones([3]).astype(np.float32)), Tensor(np.ones([3]).astype(np.float32)), + Tensor(np.ones([5]).astype(np.float32))], + 'skip': ['backward']}), + # is_training=False, mean and scale shape not match + ('BatchNorm6', { + 'block': (P.BatchNorm(is_training=False), {'exception': ValueError, 'error_keywords': ['BatchNorm']}), + 'desc_inputs': [Tensor(np.ones([5, 3]).astype(np.float32)), Tensor(np.ones([3]).astype(np.float32)), + Tensor(np.ones([3]).astype(np.float32)), Tensor(np.ones([5]).astype(np.float32)), + Tensor(np.ones([5]).astype(np.float32))], + 'skip': ['backward']}), + + # input is scalar + ('Conv2D0', { + 'block': (P.Conv2D(2, (5, 5)), {'exception': TypeError, 'error_keywords': ['Conv2D']}), + 'desc_inputs': [5.0, 5.0], + 'skip': ['backward']}), + # input is Tensor(bool) + ('Conv2D1', { + 'block': (P.Conv2D(2, (5, 5)), {'exception': TypeError, 'error_keywords': ['Conv2D']}), + 'desc_inputs': [Tensor(np.ones([5]).astype(np.bool_)), Tensor(np.ones([5]).astype(np.bool_))], + 'skip': ['backward']}), + # input x and w type mismatch + ('Conv2D2', { + 'block': (P.Conv2D(2, (5, 5)), {'exception': TypeError, 'error_keywords': ['Conv2D']}), + 'desc_inputs': [Tensor(np.ones([5]).astype(np.float32)), Tensor(np.ones([5]).astype(np.float16))], + 'skip': ['backward']}), + # rank of x is not 4 + ('Conv2D3', { + 'block': (P.Conv2D(2, (5, 5)), {'exception': ValueError, 'error_keywords': ['Conv2D']}), + 'desc_inputs': [Tensor(np.ones([1, 1]).astype(np.float32)), Tensor(np.ones([1,1,9,9]).astype(np.float32))], + 'skip': ['backward']}), + # rank of 2 is not 4 + ('Conv2D4', { + 'block': (P.Conv2D(2, (5, 5)), {'exception': ValueError, 'error_keywords': ['Conv2D']}), + 'desc_inputs': [Tensor(np.ones([1,1,9,9]).astype(np.float32)), Tensor(np.ones([1,1,9]).astype(np.float32))], + 'skip': ['backward']}), + # x_shape[1] / group != w_shape[1] + ('Conv2D5', { + 'block': (P.Conv2D(2, (5, 5)), {'exception': ValueError, 'error_keywords': ['Conv2D']}), + 'desc_inputs': [Tensor(np.ones([1,1,9,9]).astype(np.float32)), Tensor(np.ones([1,2,9,9]).astype(np.float32))], + 'skip': ['backward']}), + # out_channel != w_shape[0] + ('Conv2D6', { + 'block': (P.Conv2D(2, (5, 5)), {'exception': ValueError, 'error_keywords': ['Conv2D']}), + 'desc_inputs': [Tensor(np.ones([1,1,9,9]).astype(np.float32)), Tensor(np.ones([1,1,9,9]).astype(np.float32))], + 'skip': ['backward']}), + # kernel_size != w_shape[2:4] + ('Conv2D7', { + 'block': (P.Conv2D(2, (5, 5)), {'exception': ValueError, 'error_keywords': ['Conv2D']}), + 'desc_inputs': [Tensor(np.ones([1,1,9,9]).astype(np.float32)), Tensor(np.ones([2,1,5,6]).astype(np.float32))], + 'skip': ['backward']}), + + # input is scalar + ('DepthwiseConv2dNative0', { + 'block': (P.DepthwiseConv2dNative(2, (5, 5)), + {'exception': TypeError, 'error_keywords': ['DepthwiseConv2dNative']}), + 'desc_inputs': [5.0, 5.0], + 'skip': ['backward']}), + # input is Tensor(bool) + ('DepthwiseConv2dNative1', { + 'block': (P.DepthwiseConv2dNative(2, (5, 5)), + {'exception': TypeError, 'error_keywords': ['DepthwiseConv2dNative']}), + 'desc_inputs': [Tensor(np.ones([5]).astype(np.bool_)), Tensor(np.ones([5]).astype(np.bool_))], + 'skip': ['backward']}), + # input x and w type mismatch + ('DepthwiseConv2dNative2', { + 'block': (P.DepthwiseConv2dNative(2, (5, 5)), + {'exception': TypeError, 'error_keywords': ['DepthwiseConv2dNative']}), + 'desc_inputs': [Tensor(np.ones([5]).astype(np.float32)), Tensor(np.ones([5]).astype(np.float16))], + 'skip': ['backward']}), + # rank of x is not 4 + ('DepthwiseConv2dNative3', { + 'block': (P.DepthwiseConv2dNative(2, (5, 5)), + {'exception': ValueError, 'error_keywords': ['DepthwiseConv2dNative']}), + 'desc_inputs': [Tensor(np.ones([1, 1]).astype(np.float32)), Tensor(np.ones([1,1,9,9]).astype(np.float32))], + 'skip': ['backward']}), + # rank of 2 is not 4 + ('DepthwiseConv2dNative4', { + 'block': (P.DepthwiseConv2dNative(2, (5, 5)), + {'exception': ValueError, 'error_keywords': ['DepthwiseConv2dNative']}), + 'desc_inputs': [Tensor(np.ones([1,1,9,9]).astype(np.float32)), Tensor(np.ones([1,1,9]).astype(np.float32))], + 'skip': ['backward']}), + # x_shape[1] != w_shape[1] + ('DepthwiseConv2dNative5', { + 'block': (P.DepthwiseConv2dNative(2, (5, 5)), + {'exception': ValueError, 'error_keywords': ['DepthwiseConv2dNative']}), + 'desc_inputs': [Tensor(np.ones([1,1,9,9]).astype(np.float32)), Tensor(np.ones([1,2,9,9]).astype(np.float32))], + 'skip': ['backward']}), + # kernel_size != w_shape[2:4] + ('DepthwiseConv2dNative6', { + 'block': (P.DepthwiseConv2dNative(2, (5, 5)), + {'exception': ValueError, 'error_keywords': ['DepthwiseConv2dNative']}), + 'desc_inputs': [Tensor(np.ones([1,1,9,9]).astype(np.float32)), Tensor(np.ones([2,1,5,6]).astype(np.float32))], + 'skip': ['backward']}), + + # input is scalar + ('MaxPoolWithArgmax0', { + 'block': (P.MaxPoolWithArgmax(), {'exception': TypeError, 'error_keywords': ['MaxPoolWithArgmax']}), + 'desc_inputs': [5.0], + 'skip': ['backward']}), + # input is Tensor(bool) + ('MaxPoolWithArgmax1', { + 'block': (P.MaxPoolWithArgmax(), {'exception': TypeError, 'error_keywords': ['MaxPoolWithArgmax']}), + 'desc_inputs': [Tensor(np.ones([5]).astype(np.bool_))], + 'skip': ['backward']}), + # rank of x is not 4 + ('MaxPoolWithArgmax2', { + 'block': (P.MaxPoolWithArgmax(), {'exception': ValueError, 'error_keywords': ['MaxPoolWithArgmax']}), + 'desc_inputs': [Tensor(np.ones([1,1,32]).astype(np.float32))], + 'skip': ['backward']}), + # kernel size is invalid(very large) + ('MaxPoolWithArgmax3', { + 'block': (P.MaxPoolWithArgmax(ksize=50), {'exception': ValueError, 'error_keywords': ['MaxPoolWithArgmax']}), + 'desc_inputs': [Tensor(np.ones([1,1,32,32]).astype(np.float32))], + 'skip': ['backward']}), + + # input is scalar + ('MaxPool0', { + 'block': (P.MaxPool(), {'exception': TypeError, 'error_keywords': ['MaxPool']}), + 'desc_inputs': [5.0], + 'skip': ['backward']}), + # rank of x is not 4 + ('MaxPool1', { + 'block': (P.MaxPool(), {'exception': ValueError, 'error_keywords': ['MaxPool']}), + 'desc_inputs': [Tensor(np.ones([1,1,32]).astype(np.float32))], + 'skip': ['backward']}), + # rank of x is not 4 + ('MaxPool2', { + 'block': (P.MaxPool(ksize=50, strides=1), {'exception': ValueError, 'error_keywords': ['MaxPool']}), + 'desc_inputs': [Tensor(np.ones([1,1,32,32]).astype(np.float32))], + 'skip': ['backward']}), + + # input is scalar + ('AvgPool0', { + 'block': (P.AvgPool(), {'exception': TypeError, 'error_keywords': ['AvgPool']}), + 'desc_inputs': [5.0], + 'skip': ['backward']}), + # rank of x is not 4 + ('AvgPool1', { + 'block': (P.AvgPool(), {'exception': ValueError, 'error_keywords': ['AvgPool']}), + 'desc_inputs': [Tensor(np.ones([1,1,32]).astype(np.float32))], + 'skip': ['backward']}), + # rank of x is not 4 + ('AvgPool2', { + 'block': (P.AvgPool(ksize=50, strides=1), {'exception': ValueError, 'error_keywords': ['AvgPool']}), + 'desc_inputs': [Tensor(np.ones([1,1,32,32]).astype(np.float32))], + 'skip': ['backward']}), + + # input is scalar + ('Conv2DBackpropInput0', { + 'block': (Conv2DBackpropInputNet(P.Conv2DBackpropInput(2, (5, 5)), (2,3)), + {'exception': TypeError, 'error_keywords': ['Conv2DBackpropInput']}), + 'desc_inputs': [5.0, 5.0], + 'skip': ['backward']}), + # input is Tensor(bool) + ('Conv2DBackpropInput1', { + 'block': (Conv2DBackpropInputNet(P.Conv2DBackpropInput(2, (5, 5)), (2,3)), + {'exception': TypeError, 'error_keywords': ['Conv2DBackpropInput']}), + 'desc_inputs': [Tensor(np.ones([5]).astype(np.bool_)), Tensor(np.ones([5]).astype(np.bool_))], + 'skip': ['backward']}), + # types of doutput and w mismatch + ('Conv2DBackpropInput2', { + 'block': (Conv2DBackpropInputNet(P.Conv2DBackpropInput(2, (5, 5)), (2,3)), + {'exception': TypeError, 'error_keywords': ['Conv2DBackpropInput']}), + 'desc_inputs': [Tensor(np.ones([5]).astype(np.int32)), Tensor(np.ones([5]).astype(np.float32))], + 'skip': ['backward']}), + # types x_size is not tuple + ('Conv2DBackpropInput3', { + 'block': (Conv2DBackpropInputNet(P.Conv2DBackpropInput(2, (5, 5)), 2), + {'exception': TypeError, 'error_keywords': ['Conv2DBackpropInput']}), + 'desc_inputs': [Tensor(np.ones([5]).astype(np.int32)), Tensor(np.ones([5]).astype(np.float32))], + 'skip': ['backward']}), + # types x_size is not tuple(int,...) + ('Conv2DBackpropInput4', { + 'block': (Conv2DBackpropInputNet(P.Conv2DBackpropInput(2, (5, 5)), (2, 3.0)), + {'exception': TypeError, 'error_keywords': ['Conv2DBackpropInput']}), + 'desc_inputs': [Tensor(np.ones([5]).astype(np.int32)), Tensor(np.ones([5]).astype(np.float32))], + 'skip': ['backward']}), + + # input is scalar + ('BiasAdd0', { + 'block': (P.BiasAdd(), {'exception': TypeError, 'error_keywords': ['BiasAdd']}), + 'desc_inputs': [5.0, 5.0], + 'skip': ['backward']}), + # input is Tensor(bool) + ('BiasAdd1', { + 'block': (P.BiasAdd(), {'exception': TypeError, 'error_keywords': ['BiasAdd']}), + 'desc_inputs': [Tensor(np.ones([5]).astype(np.bool_)), Tensor(np.ones([5]).astype(np.bool_))], + 'skip': ['backward']}), + # types of x and bias mismatch + ('BiasAdd2', { + 'block': (P.BiasAdd(), {'exception': TypeError, 'error_keywords': ['BiasAdd']}), + 'desc_inputs': [Tensor(np.ones([5]).astype(np.int32)), Tensor(np.ones([5]).astype(np.float32))], + 'skip': ['backward']}), + # rank of x less than 2 + ('BiasAdd3', { + 'block': (P.BiasAdd(), {'exception': ValueError, 'error_keywords': ['BiasAdd']}), + 'desc_inputs': [Tensor(np.ones([5]).astype(np.float32)), Tensor(np.ones([5]).astype(np.float32))], + 'skip': ['backward']}), + # rank of bias is not equal to 1 + ('BiasAdd4', { + 'block': (P.BiasAdd(), {'exception': ValueError, 'error_keywords': ['BiasAdd']}), + 'desc_inputs': [Tensor(np.ones([5, 3]).astype(np.float32)), Tensor(np.ones([5, 3]).astype(np.float32))], + 'skip': ['backward']}), + # b_shape[0] != x_shape[1] + ('BiasAdd5', { + 'block': (P.BiasAdd(), {'exception': ValueError, 'error_keywords': ['BiasAdd']}), + 'desc_inputs': [Tensor(np.ones([5, 3]).astype(np.float32)), Tensor(np.ones([5]).astype(np.float32))], + 'skip': ['backward']}), + + # input x is scalar + ('TopK0', { + 'block': (TopKNet(P.TopK(), 5), {'exception': TypeError, 'error_keywords': ['TopK']}), + 'desc_inputs': [5.0], + 'skip': ['backward']}), + # input x is Tensor(bool) + ('TopK1', { + 'block': (TopKNet(P.TopK(), 5), {'exception': TypeError, 'error_keywords': ['TopK']}), + 'desc_inputs': [Tensor(np.ones([10]).astype(np.bool_))], + 'skip': ['backward']}), + # k is not integer + ('TopK2', { + 'block': (TopKNet(P.TopK(), 5.0), {'exception': TypeError, 'error_keywords': ['TopK']}), + 'desc_inputs': [Tensor(np.ones([10]).astype(np.float32))], + 'skip': ['backward']}), + + # input is scalar + ('SoftmaxCrossEntropyWithLogits0', { + 'block': (P.SoftmaxCrossEntropyWithLogits(), + {'exception': TypeError, 'error_keywords': ['SoftmaxCrossEntropyWithLogits']}), + 'desc_inputs': [5.0, 5.0], + 'skip': ['backward']}), + # input is Tensor(bool) + ('SoftmaxCrossEntropyWithLogits1', { + 'block': (P.SoftmaxCrossEntropyWithLogits(), + {'exception': TypeError, 'error_keywords': ['SoftmaxCrossEntropyWithLogits']}), + 'desc_inputs': [Tensor(np.ones([5]).astype(np.bool_)), Tensor(np.ones([5]).astype(np.bool_))], + 'skip': ['backward']}), + # types of logits and labels mismatch + ('SoftmaxCrossEntropyWithLogits2', { + 'block': (P.SoftmaxCrossEntropyWithLogits(), + {'exception': TypeError, 'error_keywords': ['SoftmaxCrossEntropyWithLogits']}), + 'desc_inputs': [Tensor(np.ones([5]).astype(np.float16)), Tensor(np.ones([5]).astype(np.float32))], + 'skip': ['backward']}), + # shapes of logits and labels mismatch + ('SoftmaxCrossEntropyWithLogits3', { + 'block': (P.SoftmaxCrossEntropyWithLogits(), + {'exception': ValueError, 'error_keywords': ['SoftmaxCrossEntropyWithLogits']}), + 'desc_inputs': [Tensor(np.ones([5]).astype(np.float32)), Tensor(np.ones([3]).astype(np.float32))], + 'skip': ['backward']}), + + # input is scalar + ('SparseSoftmaxCrossEntropyWithLogits0', { + 'block': (P.SparseSoftmaxCrossEntropyWithLogits(), + {'exception': TypeError, 'error_keywords': ['SparseSoftmaxCrossEntropyWithLogits']}), + 'desc_inputs': [5.0, 5.0], + 'skip': ['backward']}), + # logits is Tensor(bool) + ('SparseSoftmaxCrossEntropyWithLogits1', { + 'block': (P.SparseSoftmaxCrossEntropyWithLogits(), + {'exception': TypeError, 'error_keywords': ['SparseSoftmaxCrossEntropyWithLogits']}), + 'desc_inputs': [Tensor(np.ones([5]).astype(np.bool_)), Tensor(np.ones([5]).astype(np.bool_))], + 'skip': ['backward']}), + # labels is Tensor(bool) + ('SparseSoftmaxCrossEntropyWithLogits2', { + 'block': (P.SparseSoftmaxCrossEntropyWithLogits(), + {'exception': TypeError, 'error_keywords': ['SparseSoftmaxCrossEntropyWithLogits']}), + 'desc_inputs': [Tensor(np.ones([5]).astype(np.float32)), Tensor(np.ones([5]).astype(np.bool_))], + 'skip': ['backward']}), + # logits_shape[0] != labels_shape[0] + ('SparseSoftmaxCrossEntropyWithLogits3', { + 'block': (P.SparseSoftmaxCrossEntropyWithLogits(), + {'exception': ValueError, 'error_keywords': ['SparseSoftmaxCrossEntropyWithLogits']}), + 'desc_inputs': [Tensor(np.ones([5]).astype(np.float32)), Tensor(np.ones([3]).astype(np.int32))], + 'skip': ['backward']}), +] + + +@mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception) +def test_check_exception(): + return raise_set From b87643958fa305d86a36b88e1ce20bba910d18a0 Mon Sep 17 00:00:00 2001 From: leilei_snow Date: Tue, 21 Apr 2020 06:52:08 +0000 Subject: [PATCH 054/142] Parameter power of polynomial_decay_lr should be greater than 0 --- mindspore/nn/dynamic_lr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mindspore/nn/dynamic_lr.py b/mindspore/nn/dynamic_lr.py index beed7a0186..fb4d229a9d 100644 --- a/mindspore/nn/dynamic_lr.py +++ b/mindspore/nn/dynamic_lr.py @@ -269,7 +269,7 @@ def polynomial_decay_lr(learning_rate, end_learning_rate, total_step, step_per_e total_step (int): The total number of steps. step_per_epoch (int): The number of steps in per epoch. decay_epoch (int): A value used to calculate decayed learning rate. - power (float): A value used to calculate decayed learning rate. + power (float): A value used to calculate decayed learning rate. This parameter should be greater than 0. update_decay_epoch (bool): If true, update `decay_epoch`. Default: False. Returns: From 4740c70fc344286f4752f490ce6333143c6ad6c9 Mon Sep 17 00:00:00 2001 From: VectorSL Date: Tue, 21 Apr 2020 18:50:54 +0800 Subject: [PATCH 055/142] gpu add testcases --- mindspore/_akg/gpu/__init__.py | 5 ++ mindspore/nn/wrap/loss_scale.py | 3 + mindspore/ops/_op_impl/akg/gpu/__init__.py | 5 ++ mindspore/ops/operations/math_ops.py | 1 + tests/st/ops/gpu/test_lessequal_op.py | 49 ++++++++++++ tests/st/ops/gpu/test_logical_op.py | 92 ++++++++++++++++++++++ tests/st/ops/gpu/test_maximum_op.py | 55 +++++++++++++ 7 files changed, 210 insertions(+) create mode 100644 tests/st/ops/gpu/test_lessequal_op.py create mode 100644 tests/st/ops/gpu/test_logical_op.py create mode 100644 tests/st/ops/gpu/test_maximum_op.py diff --git a/mindspore/_akg/gpu/__init__.py b/mindspore/_akg/gpu/__init__.py index 08961d3989..f9db48c634 100644 --- a/mindspore/_akg/gpu/__init__.py +++ b/mindspore/_akg/gpu/__init__.py @@ -30,3 +30,8 @@ from .hsigmoid import HSigmoid, gpu_schedule_HSigmoid from .hsigmoid_grad import HSigmoidGrad, gpu_schedule_HSigmoidGrad from .hswish import HSwish, gpu_schedule_HSwish from .hswish_grad import HSwishGrad, gpu_schedule_HSwishGrad +from .logical_or import LogicalOr, gpu_schedule_LogicalOr +from .logical_not import LogicalNot, gpu_schedule_LogicalNot +from .logical_and import LogicalAnd, gpu_schedule_LogicalAnd +from .sub import Sub, gpu_schedule_Sub +from .less_equal import LessEqual, gpu_schedule_LessEqual diff --git a/mindspore/nn/wrap/loss_scale.py b/mindspore/nn/wrap/loss_scale.py index ba8e6cbb7c..65d66f0150 100644 --- a/mindspore/nn/wrap/loss_scale.py +++ b/mindspore/nn/wrap/loss_scale.py @@ -209,6 +209,7 @@ class TrainOneStepWithLossScaleCell(Cell): self.gpu_target = True self.float_status = P.FloatStatus() self.addn = P.AddN() + self.reshape = P.Reshape() else: self.gpu_target = False self.alloc_status = NPUAllocFloatStatus() @@ -260,6 +261,8 @@ class TrainOneStepWithLossScaleCell(Cell): else: flag_sum = self.hyper_map(F.partial(_grad_overflow), grads) flag_sum = self.addn(flag_sum) + # convert flag_sum to scalar + flag_sum = self.reshape(flag_sum, (())) if self.is_distributed: # sum overflow flag over devices flag_reduce = self.allreduce(flag_sum) diff --git a/mindspore/ops/_op_impl/akg/gpu/__init__.py b/mindspore/ops/_op_impl/akg/gpu/__init__.py index 8ffc796ae3..08beb44340 100644 --- a/mindspore/ops/_op_impl/akg/gpu/__init__.py +++ b/mindspore/ops/_op_impl/akg/gpu/__init__.py @@ -27,3 +27,8 @@ from .hsigmoid import _hsigmoid_akg from .hsigmoid_grad import _hsigmoid_grad_akg from .hswish import _hswish_akg from .hswish_grad import _hswish_grad_akg +from .sub import _sub_akg +from .logical_and import _logical_and_akg +from .logical_not import _logical_not_akg +from .logical_or import _logical_or_akg +from .lessequal import _lessequal_akg diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py index a3df6b7fba..78d813b9cc 100644 --- a/mindspore/ops/operations/math_ops.py +++ b/mindspore/ops/operations/math_ops.py @@ -1495,6 +1495,7 @@ class LogicalNot(PrimitiveWithInfer): @prim_attr_register def __init__(self): """init LogicalNot""" + self.init_prim_io_names(inputs=['x'], outputs=['output']) def infer_shape(self, x_shape): return x_shape diff --git a/tests/st/ops/gpu/test_lessequal_op.py b/tests/st/ops/gpu/test_lessequal_op.py new file mode 100644 index 0000000000..08bb28b0af --- /dev/null +++ b/tests/st/ops/gpu/test_lessequal_op.py @@ -0,0 +1,49 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import pytest +from mindspore.ops import operations as P +from mindspore.nn import Cell +from mindspore.common.tensor import Tensor +import mindspore.context as context +import numpy as np + + +class Net(Cell): + def __init__(self): + super(Net, self).__init__() + self.lessequal = P.LessEqual() + + def construct(self, x, y): + return self.lessequal(x, y) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_lessequal(): + x = Tensor(np.array([[1, 2, 3]]).astype(np.float32)) + y = Tensor(np.array([[2]]).astype(np.float32)) + expect = [[True, True, False]] + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + lessequal = Net() + output = lessequal(x, y) + assert np.all(output.asnumpy() == expect) + + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + lessequal = Net() + output = lessequal(x, y) + assert np.all(output.asnumpy() == expect) + diff --git a/tests/st/ops/gpu/test_logical_op.py b/tests/st/ops/gpu/test_logical_op.py new file mode 100644 index 0000000000..ab95aa8f3f --- /dev/null +++ b/tests/st/ops/gpu/test_logical_op.py @@ -0,0 +1,92 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import pytest +from mindspore.ops import operations as P +from mindspore.nn import Cell +from mindspore.common.tensor import Tensor +import mindspore.context as context +import numpy as np + + +class NetAnd(Cell): + def __init__(self): + super(NetAnd, self).__init__() + self.logicaland = P.LogicalAnd() + + def construct(self, x, y): + return self.logicaland(x, y) + +class NetOr(Cell): + def __init__(self): + super(NetOr, self).__init__() + self.logicalor = P.LogicalOr() + + def construct(self, x, y): + return self.logicalor(x, y) + +class NetNot(Cell): + def __init__(self): + super(NetNot, self).__init__() + self.logicalnot = P.LogicalNot() + + def construct(self, x): + return self.logicalnot(x) + +x = np.array([True, False, False]).astype(np.bool) +y = np.array([False]).astype(np.bool) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_logicaland(): + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + logicaland = NetAnd() + output = logicaland(Tensor(x), Tensor(y)) + assert np.all(output.asnumpy() == np.logical_and(x, y)) + + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + logicaland = NetAnd() + output = logicaland(Tensor(x), Tensor(y)) + assert np.all(output.asnumpy() == np.logical_and(x, y)) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_logicalor(): + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + logicalor = NetOr() + output = logicalor(Tensor(x), Tensor(y)) + assert np.all(output.asnumpy() == np.logical_or(x, y)) + + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + logicalor = NetOr() + output = logicalor(Tensor(x), Tensor(y)) + assert np.all(output.asnumpy() == np.logical_or(x, y)) + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_logicalnot(): + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + logicalnot = NetNot() + output = logicalnot(Tensor(x)) + assert np.all(output.asnumpy() == np.logical_not(x)) + + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + logicalnot = NetNot() + output = logicalnot(Tensor(x)) + assert np.all(output.asnumpy() == np.logical_not(x)) + diff --git a/tests/st/ops/gpu/test_maximum_op.py b/tests/st/ops/gpu/test_maximum_op.py new file mode 100644 index 0000000000..3193dafa61 --- /dev/null +++ b/tests/st/ops/gpu/test_maximum_op.py @@ -0,0 +1,55 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import pytest +from mindspore.ops import operations as P +from mindspore.nn import Cell +from mindspore.common.tensor import Tensor +import mindspore.context as context +import numpy as np + + +class Net(Cell): + def __init__(self): + super(Net, self).__init__() + self.max = P.Maximum() + + def construct(self, x, y): + return self.max(x, y) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_max(): + x = Tensor(np.array([[1, 2, 3]]).astype(np.float32)) + y = Tensor(np.array([[2]]).astype(np.float32)) + expect = [[2, 2, 3]] + error = np.ones(shape=[1, 3]) * 1.0e-5 + + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + max = Net() + output = max(x, y) + diff = output.asnumpy() - expect + assert np.all(diff < error) + assert np.all(-diff < error) + + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + max = Net() + output = max(x, y) + diff = output.asnumpy() - expect + assert np.all(diff < error) + assert np.all(-diff < error) + From a931fbbae260ff497da18b623962fc5d1ddbe860 Mon Sep 17 00:00:00 2001 From: xiefangqi Date: Tue, 21 Apr 2020 19:46:10 +0800 Subject: [PATCH 056/142] delete unsafe compile option --- mindspore/ccsrc/dataset/CMakeLists.txt | 3 --- 1 file changed, 3 deletions(-) diff --git a/mindspore/ccsrc/dataset/CMakeLists.txt b/mindspore/ccsrc/dataset/CMakeLists.txt index 0bc4065ac9..879a9346bc 100644 --- a/mindspore/ccsrc/dataset/CMakeLists.txt +++ b/mindspore/ccsrc/dataset/CMakeLists.txt @@ -12,9 +12,6 @@ endif() set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-format") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-attributes") -if (${CMAKE_SYSTEM_NAME} MATCHES "Windows") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,--image-base -Wl,0x10000000") -endif() ############################# Options ################################ if (ENABLE_GPUQUE) add_definitions(-D ENABLE_GPUQUE) From 1d40115afdc18cb154e087397f007574d54f80de Mon Sep 17 00:00:00 2001 From: zhoufeng Date: Tue, 21 Apr 2020 19:48:06 +0800 Subject: [PATCH 057/142] fix clang-format, pointer and referance should be always right-align --- .clang-format | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.clang-format b/.clang-format index c931e8f068..3b26784000 100644 --- a/.clang-format +++ b/.clang-format @@ -52,7 +52,7 @@ ConstructorInitializerAllOnOneLineOrOnePerLine: true ConstructorInitializerIndentWidth: 4 ContinuationIndentWidth: 2 Cpp11BracedListStyle: true -DerivePointerAlignment: true +DerivePointerAlignment: false DisableFormat: false ExperimentalAutoDetectBinPacking: false FixNamespaceComments: true From 9b28d9bd4a44e9b9b64dd1ce94f0af2b8497f018 Mon Sep 17 00:00:00 2001 From: leilei_snow Date: Tue, 21 Apr 2020 03:14:26 +0000 Subject: [PATCH 058/142] Add comment about int type. --- mindspore/nn/optim/optimizer.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mindspore/nn/optim/optimizer.py b/mindspore/nn/optim/optimizer.py index 5738044532..b4bead2a77 100755 --- a/mindspore/nn/optim/optimizer.py +++ b/mindspore/nn/optim/optimizer.py @@ -45,8 +45,10 @@ class Optimizer(Cell): learning_rate (float): A floating point value for the learning rate. Should be greater than 0. parameters (list): A list of parameter, which will be updated. The element in `parameters` should be class mindspore.Parameter. - weight_decay (float): A floating point value for the weight decay. Default: 0.0. - loss_scale (float): A floating point value for the loss scale. Default: 1.0. Should be greater than 0. + weight_decay (float): A floating point value for the weight decay. If the type of `weight_decay` + input is int, it will be convertd to float. Default: 0.0. + loss_scale (float): A floating point value for the loss scale. It should be greater than 0. If the + type of `loss_scale` input is int, it will be convertd to float. Default: 1.0. decay_filter (Function): A function to determine whether to apply weight decay on parameters. Default: lambda x: 'beta' not in x.name and 'gamma' not in x.name. From b812b18c028df16b6ff08e456bb435cf50e10442 Mon Sep 17 00:00:00 2001 From: "wangnan39@huawei.com" Date: Fri, 17 Apr 2020 12:03:50 +0800 Subject: [PATCH 059/142] support update parameter for vm --- mindspore/common/parameter.py | 16 ++-- mindspore/nn/optim/adam.py | 12 --- mindspore/nn/optim/ftrl.py | 2 +- mindspore/nn/optim/lars.py | 17 ----- mindspore/nn/optim/momentum.py | 20 +---- mindspore/nn/optim/optimizer.py | 2 + mindspore/nn/optim/rmsprop.py | 21 +----- mindspore/nn/optim/sgd.py | 20 +---- mindspore/nn/wrap/cell_wrapper.py | 30 +------- mindspore/train/serialization.py | 11 ++- tests/ut/python/nn/test_cell_wrapper.py | 4 - tests/ut/python/nn/test_parameter.py | 74 ------------------- tests/ut/python/ops/test_momentum.py | 2 +- .../python/pynative_mode/test_cell_bprop.py | 2 +- tests/vm_impl/nn_ops_vm_impl.py | 4 +- 15 files changed, 34 insertions(+), 203 deletions(-) diff --git a/mindspore/common/parameter.py b/mindspore/common/parameter.py index c354bcd235..5f56d23956 100644 --- a/mindspore/common/parameter.py +++ b/mindspore/common/parameter.py @@ -15,7 +15,6 @@ """Parameter for cell.""" from copy import copy, deepcopy -import numpy as np from .initializer import initializer from .tensor import Tensor from .._checkparam import _check_str_by_regular @@ -176,14 +175,15 @@ class Parameter: return res def set_parameter_data(self, data): - if isinstance(data, (Tensor, list, int, float, - np.float16, np.float32, np.int32, np.int16, np.ndarray)) and not isinstance(data, bool): - if isinstance(data, Tensor): - # make a copy of Tensor to init the parameter - data = Tensor(data.asnumpy().copy()) - self.default_input = data + """Set `default_input` of current `Parameter`.""" + if isinstance(data, bool): + raise ValueError('Parameter data can not be `bool`') + if isinstance(data, Tensor): + # make a copy of Tensor to init the parameter + data = Tensor(data.asnumpy().copy()) else: - raise ValueError("Parameter data must be tensor or number.") + data = Tensor(data) + self.default_input = data class ParameterTuple(tuple): diff --git a/mindspore/nn/optim/adam.py b/mindspore/nn/optim/adam.py index 65f8ec678b..4e88c3ef93 100755 --- a/mindspore/nn/optim/adam.py +++ b/mindspore/nn/optim/adam.py @@ -101,17 +101,6 @@ def _run_opt_with_one_number(opt, lr, beta1_power, beta2_power, beta1, beta2, ep return success -@adam_opt.register("Function", "Number", "Tensor", "Tensor", "Tensor", "Tensor", "Number", "Tensor", "Tensor", "Tensor", - "Tensor") -def _run_opt_with_two_number(opt, lr, beta1_power, beta2_power, beta1, beta2, eps, gradient, params, moment1, - moment2): - """Apply adam optimizer to the weight parameter using Tensor.""" - success = True - success = F.depend(success, opt(params, moment1, moment2, beta1_power, beta2_power, lr, beta1, beta2, - eps, gradient)) - return success - - class Adam(Optimizer): r""" Updates gradients by Adaptive Moment Estimation (Adam) algorithm. @@ -183,7 +172,6 @@ class Adam(Optimizer): self.moment1 = self.parameters.clone(prefix="moment1", init='zeros') self.moment2 = self.parameters.clone(prefix="moment2", init='zeros') - self.decay_tf = tuple(decay_filter(x) for x in self.parameters) self.hyper_map = C.HyperMap() self.opt = P.Adam(use_locking, use_nesterov) diff --git a/mindspore/nn/optim/ftrl.py b/mindspore/nn/optim/ftrl.py index d08dd6cf4c..2bc329f42d 100644 --- a/mindspore/nn/optim/ftrl.py +++ b/mindspore/nn/optim/ftrl.py @@ -23,7 +23,7 @@ from mindspore._checkparam import Rel from .optimizer import Optimizer, apply_decay, grad_scale ftrl_opt = C.MultitypeFuncGraph("ftrl_opt") -@ftrl_opt.register("Function", "Number", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", "Tensor") +@ftrl_opt.register("Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", "Tensor") def _tensor_run_opt(opt, learning_rate, l1, l2, lr_power, linear, gradient, weight, moment): """Apply ftrl optimizer to the weight parameter.""" success = True diff --git a/mindspore/nn/optim/lars.py b/mindspore/nn/optim/lars.py index 02538aa61a..73451f3bf5 100755 --- a/mindspore/nn/optim/lars.py +++ b/mindspore/nn/optim/lars.py @@ -43,23 +43,6 @@ def _tensor_run_opt(lars, weight_decay, learning_rate, gradient, weight, decay_f return gradient -@lars_opt.register("Function", "Number", "Number", "Tensor", "Tensor", "Bool", "Bool") -def _tensor_run_opt_v2(lars, weight_decay, learning_rate, gradient, weight, decay_flag, lars_flag): - """Apply lars optimizer to the weight parameter.""" - if lars_flag: - op_reduce = P.ReduceSum() - w_square_sum = op_reduce(F.square(weight)) - grad_square_sum = op_reduce(F.square(gradient)) - if decay_flag: - grad_t = lars(weight, gradient, w_square_sum, grad_square_sum, weight_decay, learning_rate) - else: - num_zero = 0.0 - grad_t = lars(weight, gradient, w_square_sum, grad_square_sum, num_zero, learning_rate) - return grad_t - - return gradient - - class LARS(Optimizer): """ Implements the LARS algorithm with LARSUpdate Operator. diff --git a/mindspore/nn/optim/momentum.py b/mindspore/nn/optim/momentum.py index bac8e74a42..c69e226df9 100755 --- a/mindspore/nn/optim/momentum.py +++ b/mindspore/nn/optim/momentum.py @@ -15,19 +15,13 @@ """momentum""" from mindspore.ops import functional as F, composite as C, operations as P from mindspore.common.parameter import Parameter +from mindspore.common.tensor import Tensor +import mindspore.common.dtype as mstype from .optimizer import Optimizer momentum_opt = C.MultitypeFuncGraph("momentum_opt") -@momentum_opt.register("Function", "Number", "Number", "Tensor", "Tensor", "Tensor") -def _tensor_run_opt(opt, learning_rate, momentum, gradient, weight, moment): - """Apply momentum optimizer to the weight parameter.""" - success = True - success = F.depend(success, opt(weight, moment, learning_rate, gradient, momentum)) - return success - - @momentum_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor") def _tensor_run_opt_ext(opt, learning_rate, momentum, gradient, weight, moment): """Apply momentum optimizer to the weight parameter using Tensor.""" @@ -36,14 +30,6 @@ def _tensor_run_opt_ext(opt, learning_rate, momentum, gradient, weight, moment): return success -@momentum_opt.register("Function", "Tensor", "Number", "Tensor", "Tensor", "Tensor") -def _tensor_run_opt_dyn(opt, learning_rate, momentum, gradient, weight, moment): - """Apply momentum optimizer to the weight parameter using dynamic learning rate.""" - success = True - success = F.depend(success, opt(weight, moment, learning_rate, gradient, momentum)) - return success - - class Momentum(Optimizer): """ Implements the Momentum algorithm. @@ -86,7 +72,7 @@ class Momentum(Optimizer): super(Momentum, self).__init__(learning_rate, params, weight_decay, loss_scale, decay_filter) if isinstance(momentum, float) and momentum < 0.0: raise ValueError("momentum should be at least 0.0, but got momentum {}".format(momentum)) - self.momentum = Parameter(momentum, name="momentum") + self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum") self.params = self.parameters self.moments = self.params.clone(prefix="moments", init='zeros') self.hyper_map = C.HyperMap() diff --git a/mindspore/nn/optim/optimizer.py b/mindspore/nn/optim/optimizer.py index 5738044532..8a7c65e5b2 100755 --- a/mindspore/nn/optim/optimizer.py +++ b/mindspore/nn/optim/optimizer.py @@ -22,6 +22,7 @@ from mindspore.ops import functional as F, composite as C, operations as P from mindspore.nn.cell import Cell from mindspore.common.parameter import Parameter, ParameterTuple from mindspore.common.initializer import initializer +import mindspore.common.dtype as mstype from mindspore._checkparam import Validator as validator from mindspore._checkparam import Rel from mindspore.common.tensor import Tensor @@ -64,6 +65,7 @@ class Optimizer(Cell): self.assignadd = None self.global_step = None validator.check_number_range("learning rate", learning_rate, 0.0, float("inf"), Rel.INC_LEFT, self.cls_name) + learning_rate = Tensor(learning_rate, mstype.float32) else: self.dynamic_lr = True self.gather = P.GatherV2() diff --git a/mindspore/nn/optim/rmsprop.py b/mindspore/nn/optim/rmsprop.py index 97d7538a26..a8f118b709 100644 --- a/mindspore/nn/optim/rmsprop.py +++ b/mindspore/nn/optim/rmsprop.py @@ -21,34 +21,17 @@ rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt") centered_rmsprop_opt = C.MultitypeFuncGraph("rmsprop_opt") -@rmsprop_opt.register("Function", "Number", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", "Tensor") -def _rmsprop_opt(opt, learning_rate, decay, epsilon, momentum, weight, ms, mom, grad): - """Apply rmsprop optimizer to the weight parameter.""" - success = True - success = F.depend(success, opt(weight, ms, mom, grad, learning_rate, decay, momentum, epsilon)) - return success - - @rmsprop_opt.register("Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", "Tensor") -def _rmsprop_opt_dynamic_lr(opt, learning_rate, decay, epsilon, momentum, weight, ms, mom, grad): +def _rmsprop_opt(opt, learning_rate, decay, epsilon, momentum, weight, ms, mom, grad): """Apply rmsprop optimizer to the weight parameter using dynamic learning rate.""" success = True success = F.depend(success, opt(weight, ms, mom, grad, learning_rate, decay, momentum, epsilon)) return success -@centered_rmsprop_opt.register("Function", "Number", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", - "Tensor", "Tensor") -def _centered_rmsprop_opt(opt, learning_rate, decay, epsilon, momentum, weight, mg, ms, mom, grad): - """Apply centered rmsprop optimizer to the weight parameter.""" - success = True - success = F.depend(success, opt(weight, mg, ms, mom, grad, learning_rate, decay, momentum, epsilon)) - return success - - @centered_rmsprop_opt.register("Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor") -def _centered_rmsprop_opt_dynamic_lr(opt, learning_rate, decay, epsilon, momentum, weight, mg, ms, mom, grad): +def _centered_rmsprop_opt(opt, learning_rate, decay, epsilon, momentum, weight, mg, ms, mom, grad): """Apply centered rmsprop optimizer to the weight parameter using dynamic learning rate.""" success = True success = F.depend(success, opt(weight, mg, ms, mom, grad, learning_rate, decay, momentum, epsilon)) diff --git a/mindspore/nn/optim/sgd.py b/mindspore/nn/optim/sgd.py index db0775e023..cda5aa904a 100755 --- a/mindspore/nn/optim/sgd.py +++ b/mindspore/nn/optim/sgd.py @@ -15,20 +15,14 @@ """sgd""" from mindspore.ops import functional as F, composite as C, operations as P from mindspore.common.parameter import Parameter +from mindspore.common.tensor import Tensor +import mindspore.common.dtype as mstype from mindspore._checkparam import Validator as validator from .optimizer import Optimizer sgd_opt = C.MultitypeFuncGraph("sgd_opt") -@sgd_opt.register("Function", "Number", "Number", "Tensor", "Tensor", "Tensor", "Tensor") -def _tensor_run_opt(opt, learning_rate, momentum, gradient, weight, accum, stat): - """Apply sgd optimizer to the weight parameter.""" - success = True - success = F.depend(success, opt(weight, gradient, learning_rate, accum, momentum, stat)) - return success - - @sgd_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor") def _tensor_run_opt_ext(opt, learning_rate, momentum, gradient, weight, accum, stat): """Apply sgd optimizer to the weight parameter using Tensor.""" @@ -37,14 +31,6 @@ def _tensor_run_opt_ext(opt, learning_rate, momentum, gradient, weight, accum, s return success -@sgd_opt.register("Function", "Tensor", "Number", "Tensor", "Tensor", "Tensor", "Tensor") -def _tensor_run_opt_dyn(opt, learning_rate, momentum, gradient, weight, accum, stat): - """Apply sgd optimizer to the weight parameter using dynamic learning rate.""" - success = True - success = F.depend(success, opt(weight, gradient, learning_rate, accum, momentum, stat)) - return success - - class SGD(Optimizer): """ Implements stochastic gradient descent (optionally with momentum). @@ -105,7 +91,7 @@ class SGD(Optimizer): self.opt = P.SGD(dampening, weight_decay, nesterov) - self.momentum = Parameter(momentum, name="momentum") + self.momentum = Parameter(Tensor(momentum, mstype.float32), name="momentum") self.accum = self.parameters.clone(prefix="accum", init='zeros') self.stat = self.parameters.clone(prefix="stat", init='ones') self.hyper_map = C.HyperMap() diff --git a/mindspore/nn/wrap/cell_wrapper.py b/mindspore/nn/wrap/cell_wrapper.py index 64c382557a..6c88b7d957 100644 --- a/mindspore/nn/wrap/cell_wrapper.py +++ b/mindspore/nn/wrap/cell_wrapper.py @@ -13,17 +13,10 @@ # limitations under the License. # ============================================================================ """Cell_wrapper.""" -import copy - -import numpy as np - from mindspore.parallel._utils import (_get_device_num, _get_mirror_mean, _get_parallel_mode) from mindspore.train.parallel_utils import ParallelMode - -from ...common import Tensor from ...common import dtype as mstype -from ...common.initializer import initializer from ...common.parameter import Parameter, ParameterTuple from ...ops import composite as C from ...ops import functional as F @@ -348,25 +341,8 @@ class ParameterUpdate(Cell): super(ParameterUpdate, self).__init__(auto_prefix=False) if not isinstance(param, Parameter): raise TypeError("`param` must be `Parameter`, but got {}".format(param)) - - default_input = param.default_input - if isinstance(default_input, Tensor): - shape = default_input.shape() - zero_dtype = default_input.dtype() - elif isinstance(default_input, float): - shape = [1] - zero_dtype = mstype.float32 - elif isinstance(default_input, int): - shape = [1] - zero_dtype = mstype.int32 - else: - raise TypeError("`default_input` in `param` must be Tensor, float or int, but got {}".format(default_input)) - - self._param = Parameter(initializer(copy.deepcopy(default_input), shape), param.name) - self._param.is_init = True - self._zero = Tensor(np.zeros(shape), zero_dtype) + self._param = param def construct(self, x): - zero = self._param + self._zero - F.control_depend(zero, F.assign(self._param, x)) - return zero + self._param = x + return x diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py index ae17bf8116..e933d40666 100644 --- a/mindspore/train/serialization.py +++ b/mindspore/train/serialization.py @@ -36,7 +36,6 @@ tensor_to_ms_type = {"Int8": mstype.int8, "Int16": mstype.int16, "Int32": mstype tensor_to_np_type = {"Int8": np.int8, "Int16": np.int16, "Int32": np.int32, "Int64": np.int64, "Float16": np.float16, "Float32": np.float32, "Float64": np.float64} - def _special_process_par(par, new_par): """ Processes the special condition. @@ -182,8 +181,14 @@ def load_checkpoint(ckpoint_file_name, net=None): param_data = np.fromstring(data, np_type) dims = element.tensor.dims - if dims in [[0], [1]]: - parameter_dict[element.tag] = Parameter(param_data[0], name=element.tag) + if dims == [0]: + if 'Float' in data_type: + param_data = float(param_data[0]) + elif 'Int' in data_type: + param_data = int(param_data[0]) + parameter_dict[element.tag] = Parameter(Tensor(param_data, ms_type), name=element.tag) + elif dims == [1]: + parameter_dict[element.tag] = Parameter(Tensor(param_data, ms_type), name=element.tag) else: param_dim = [] for dim in dims: diff --git a/tests/ut/python/nn/test_cell_wrapper.py b/tests/ut/python/nn/test_cell_wrapper.py index 3e163c9e4f..148d42ab64 100755 --- a/tests/ut/python/nn/test_cell_wrapper.py +++ b/tests/ut/python/nn/test_cell_wrapper.py @@ -94,10 +94,6 @@ def test_parameter_update_float32(): def test_parameter_update_error(): """ test_parameter_update """ input_np = np.array([1]) - input_parameter = Parameter(np.array([1]), 'input_parameter') with pytest.raises(TypeError): ParameterUpdate(input_np) - - with pytest.raises(TypeError): - ParameterUpdate(input_parameter) diff --git a/tests/ut/python/nn/test_parameter.py b/tests/ut/python/nn/test_parameter.py index 49e89e124e..529af532f7 100644 --- a/tests/ut/python/nn/test_parameter.py +++ b/tests/ut/python/nn/test_parameter.py @@ -52,86 +52,12 @@ def test_parameter_tuple_illegal(): def test_parameter_init_illegal(): - import numpy as np - dat = np.array([[1, 2, 3], [2, 3, 4]]) - tensor = Tensor(dat) - data_none = None data_bool = True data_str = "nicai" - data_int = 3 - data_list = [1, "2", True] - data_tuple = (1, 2, 3) - np_arr_int16 = np.ones([1,1], dtype=np.int16) - np_arr_int32 = np.ones([1,1], dtype=np.int32) - np_arr_float16 = np.ones([1,1], dtype=np.float16) - np_arr_float32 = np.ones([1,1], dtype=np.float32) - -# with pytest.raises(ValueError): -# Parameter(np_arr_int16[0][0], name=data_str) - Parameter(np_arr_int32[0], name=data_str) - Parameter(np_arr_float16[0], name=data_str) - Parameter(np_arr_float32[0], name=data_str) - Parameter(np_arr_float32, name=data_str) - - Parameter(tensor, name=data_str) - Parameter(data_int, name=data_str) - Parameter(dat, name=data_str) - with pytest.raises(ValueError): - Parameter(data_none, name=data_str) with pytest.raises(ValueError): Parameter(data_bool, name=data_str) - with pytest.raises(ValueError): - Parameter(data_str, name=data_str) - Parameter(data_list, name=data_str) - with pytest.raises(ValueError): - Parameter(data_tuple, name=data_str) - - Parameter(tensor, name=data_str) - Parameter(tensor, name=data_none) - with pytest.raises(ValueError): - Parameter(tensor, name=dat) - with pytest.raises(ValueError): - Parameter(tensor, name=tensor) - with pytest.raises(ValueError): - Parameter(tensor, name=data_bool) - with pytest.raises(ValueError): - Parameter(tensor, name=data_int) - with pytest.raises(ValueError): - Parameter(tensor, name=data_list) - with pytest.raises(ValueError): - Parameter(tensor, name=data_tuple) - Parameter(tensor, name=data_str, requires_grad=data_bool) - with pytest.raises(TypeError): - Parameter(tensor, name=data_str, requires_grad=data_none) - with pytest.raises(TypeError): - Parameter(tensor, name=data_str, requires_grad=dat) - with pytest.raises(TypeError): - Parameter(tensor, name=data_str, requires_grad=tensor) - with pytest.raises(TypeError): - Parameter(tensor, name=data_str, requires_grad=data_str) - with pytest.raises(TypeError): - Parameter(tensor, name=data_str, requires_grad=data_int) - with pytest.raises(TypeError): - Parameter(tensor, name=data_str, requires_grad=data_list) - with pytest.raises(TypeError): - Parameter(tensor, name=data_str, requires_grad=data_tuple) - Parameter(tensor, name=data_str, requires_grad=data_bool,layerwise_parallel=data_bool) - with pytest.raises(TypeError): - Parameter(tensor, name=data_str, requires_grad=data_bool,layerwise_parallel=dat) - with pytest.raises(TypeError): - Parameter(tensor, name=data_str, requires_grad=data_bool,layerwise_parallel=tensor) - with pytest.raises(TypeError): - Parameter(tensor, name=data_str, requires_grad=data_bool,layerwise_parallel=data_none) - with pytest.raises(TypeError): - Parameter(tensor, name=data_str, requires_grad=data_bool,layerwise_parallel=data_str) - with pytest.raises(TypeError): - Parameter(tensor, name=data_str, requires_grad=data_bool,layerwise_parallel=data_int) - with pytest.raises(TypeError): - Parameter(tensor, name=data_str, requires_grad=data_bool,layerwise_parallel=data_list) - with pytest.raises(TypeError): - Parameter(tensor, name=data_str, requires_grad=data_bool,layerwise_parallel=data_tuple) def test_check_str_by_regular(): diff --git a/tests/ut/python/ops/test_momentum.py b/tests/ut/python/ops/test_momentum.py index 64b5a9af12..3334f1670a 100644 --- a/tests/ut/python/ops/test_momentum.py +++ b/tests/ut/python/ops/test_momentum.py @@ -31,7 +31,7 @@ from ....mindspore_test_framework.pipeline.forward.compile_forward \ run_opt = C.MultitypeFuncGraph("run_opt") -@run_opt.register("Function", "Int", "Number", "Number", +@run_opt.register("Function", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor", "Tensor") def tensor_run_opt(opt, iters, learning_rate, momentum, diff --git a/tests/ut/python/pynative_mode/test_cell_bprop.py b/tests/ut/python/pynative_mode/test_cell_bprop.py index da1e14974f..c69b80412e 100644 --- a/tests/ut/python/pynative_mode/test_cell_bprop.py +++ b/tests/ut/python/pynative_mode/test_cell_bprop.py @@ -51,7 +51,7 @@ class InlineMulADD(nn.Cell): def __init__(self): super(InlineMulADD, self).__init__() self.mul_add = MulAdd() - self.param = Parameter(2, 'param') + self.param = 2 def construct(self, x, y): return self.mul_add(x, y) + x + self.param * y diff --git a/tests/vm_impl/nn_ops_vm_impl.py b/tests/vm_impl/nn_ops_vm_impl.py index 8794acbbd2..0df4b5fbaa 100644 --- a/tests/vm_impl/nn_ops_vm_impl.py +++ b/tests/vm_impl/nn_ops_vm_impl.py @@ -377,8 +377,8 @@ def vm_impl_momentum(self): accumulation = accumulation.asnumpy() variable = variable.asnumpy() shape = accumulation.shape - learning_rate = np.full(shape, learning_rate) - momentum = np.full(shape, momentum) + learning_rate = np.full(shape, learning_rate.asnumpy()) + momentum = np.full(shape, momentum.asnumpy()) accumulation = accumulation * momentum + gradient if use_nesterov is True: variable -= gradient * learning_rate + accumulation * momentum * learning_rate From c2129b9190a552a31499eef178cd01ce116bd0ee Mon Sep 17 00:00:00 2001 From: leilei_snow Date: Tue, 21 Apr 2020 12:38:33 +0000 Subject: [PATCH 060/142] Modify wrong type description. --- mindspore/nn/dynamic_lr.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mindspore/nn/dynamic_lr.py b/mindspore/nn/dynamic_lr.py index 266587c5c3..b88a7b7355 100644 --- a/mindspore/nn/dynamic_lr.py +++ b/mindspore/nn/dynamic_lr.py @@ -32,7 +32,7 @@ def piecewise_constant_lr(milestone, learning_rates): Args: milestone (Union[list[int], tuple[int]]): A list of milestone. This list is a monotone increasing list. - learning_rates (Union[list[float], tuple[int]]): A list of learning rates. + learning_rates (Union[list[float], tuple[float]]): A list of learning rates. Returns: list[float]. The size of list is :math:`M_N`. From 08968c2744c3b71ce950fec76e94b5d200ea2e55 Mon Sep 17 00:00:00 2001 From: dengwentao Date: Tue, 21 Apr 2020 10:49:59 +0800 Subject: [PATCH 061/142] modify tvm build --- CMakeLists.txt | 2 - cmake/external_libs/dmlc_core.cmake | 2 +- cmake/external_libs/tvm_gpu.cmake | 13 +- cmake/package.cmake | 12 +- cmake/utils.cmake | 28 +++-- mindspore/ccsrc/CMakeLists.txt | 111 ------------------ .../patch/incubator-tvm/CMakeLists.txt | 100 ++++++++++++++++ .../patch/incubator-tvm/find_library.patch | 8 +- 8 files changed, 145 insertions(+), 131 deletions(-) create mode 100644 third_party/patch/incubator-tvm/CMakeLists.txt diff --git a/CMakeLists.txt b/CMakeLists.txt index 46804c8dde..7dceca7ad7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,8 +1,6 @@ cmake_minimum_required(VERSION 3.14) project (MindSpore) - include(${CMAKE_SOURCE_DIR}/cmake/options.cmake) - set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_SOURCE_DIR}/cmake/modules/") if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin") diff --git a/cmake/external_libs/dmlc_core.cmake b/cmake/external_libs/dmlc_core.cmake index 386a52429d..e07df83fd6 100644 --- a/cmake/external_libs/dmlc_core.cmake +++ b/cmake/external_libs/dmlc_core.cmake @@ -1,4 +1,4 @@ -mindspore_add_pkg(dmlc_core +mindspore_add_pkg(dmlc-core VER 0.3 HEAD_ONLY ./ URL https://github.com/dmlc/dmlc-core/archive/808f485387f9a03f78fa9f1159f387d0d91b7a28.zip diff --git a/cmake/external_libs/tvm_gpu.cmake b/cmake/external_libs/tvm_gpu.cmake index 57a045cb03..834e2d159d 100644 --- a/cmake/external_libs/tvm_gpu.cmake +++ b/cmake/external_libs/tvm_gpu.cmake @@ -2,7 +2,14 @@ set(incubator_tvm_gpu_CXXFLAGS "-D_FORTIFY_SOURCE=2 -O2") set(incubator_tvm_gpu_CFLAGS "-D_FORTIFY_SOURCE=2 -O2") mindspore_add_pkg(incubator_tvm_gpu VER 0.6.0 - HEAD_ONLY ./ + LIBS tvm URL https://github.com/apache/incubator-tvm/archive/v0.6.0.tar.gz - MD5 9cbbd32545a776023acabbba270449fe) - + MD5 9cbbd32545a776023acabbba270449fe + CUSTOM_CMAKE ${CMAKE_SOURCE_DIR}/third_party/patch/incubator-tvm/ + SUBMODULES ${dlpack_DIRPATH} ${dmlc-core_DIRPATH} ${rang_DIRPATH} + SOURCEMODULES topi/python/topi python/tvm + PATCHES ${CMAKE_SOURCE_DIR}/third_party/patch/incubator-tvm/find_library.patch + ${CMAKE_SOURCE_DIR}/third_party/patch/incubator-tvm/include.patch + ${CMAKE_SOURCE_DIR}/third_party/patch/incubator-tvm/src_pass.patch + CMAKE_OPTION " ") +add_library(mindspore::tvm ALIAS incubator_tvm_gpu::tvm) \ No newline at end of file diff --git a/cmake/package.cmake b/cmake/package.cmake index 531dff29ca..d35ce0463b 100644 --- a/cmake/package.cmake +++ b/cmake/package.cmake @@ -191,11 +191,17 @@ if (ENABLE_GPU) DESTINATION ${INSTALL_PY_DIR}/../ COMPONENT mindspore ) - if (EXISTS ${CMAKE_BINARY_DIR}/incubator-tvm) + if (EXISTS ${incubator_tvm_gpu_ROOT}) + file(GLOB_RECURSE GLOG_LIB_LIST ${incubator_tvm_gpu_LIBPATH}/lib*) + install( + FILES ${GLOG_LIB_LIST} + DESTINATION ${INSTALL_LIB_DIR} + COMPONENT mindspore + ) install( DIRECTORY - ${CMAKE_BINARY_DIR}/incubator-tvm/topi/python/topi - ${CMAKE_BINARY_DIR}/incubator-tvm/python/tvm + ${incubator_tvm_gpu_ROOT}/topi/python/topi + ${incubator_tvm_gpu_ROOT}/python/tvm DESTINATION ${INSTALL_PY_DIR}/../_akg COMPONENT mindspore ) diff --git a/cmake/utils.cmake b/cmake/utils.cmake index 894a0de1b8..f0a5dc594c 100644 --- a/cmake/utils.cmake +++ b/cmake/utils.cmake @@ -205,8 +205,8 @@ set(MS_FIND_NO_DEFAULT_PATH ${MS_FIND_NO_DEFAULT_PATH} PARENT_SCOPE) function(mindspore_add_pkg pkg_name ) set(options ) - set(oneValueArgs URL MD5 GIT_REPOSITORY GIT_TAG VER EXE DIR HEAD_ONLY CMAKE_PATH RELEASE LIB_PATH) - set(multiValueArgs CMAKE_OPTION LIBS PRE_CONFIGURE_COMMAND CONFIGURE_COMMAND BUILD_OPTION INSTALL_INCS INSTALL_LIBS PATCHES) + set(oneValueArgs URL MD5 GIT_REPOSITORY GIT_TAG VER EXE DIR HEAD_ONLY CMAKE_PATH RELEASE LIB_PATH CUSTOM_CMAKE) + set(multiValueArgs CMAKE_OPTION LIBS PRE_CONFIGURE_COMMAND CONFIGURE_COMMAND BUILD_OPTION INSTALL_INCS INSTALL_LIBS PATCHES SUBMODULES SOURCEMODULES) cmake_parse_arguments(PKG "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN} ) if (NOT PKG_LIB_PATH) @@ -270,11 +270,17 @@ function(mindspore_add_pkg pkg_name ) endif () if (NOT PKG_DIR) - if (PKG_GIT_REPOSITORY) - __download_pkg_with_git(${pkg_name} ${PKG_GIT_REPOSITORY} ${PKG_GIT_TAG} ${PKG_MD5}) - else() + if (PKG_GIT_REPOSITORY) + __download_pkg_with_git(${pkg_name} ${PKG_GIT_REPOSITORY} ${PKG_GIT_TAG} ${PKG_MD5}) + else() __download_pkg(${pkg_name} ${PKG_URL} ${PKG_MD5}) - endif() + endif() + foreach(_SUBMODULE_FILE ${PKG_SUBMODULES}) + STRING( REGEX REPLACE "(.+)_(.+)" "\\1" _SUBMODEPATH ${_SUBMODULE_FILE}) + STRING( REGEX REPLACE "(.+)/(.+)" "\\2" _SUBMODENAME ${_SUBMODEPATH}) + file(GLOB ${pkg_name}_INSTALL_SUBMODULE ${_SUBMODULE_FILE}/*) + file(COPY ${${pkg_name}_INSTALL_SUBMODULE} DESTINATION ${${pkg_name}_SOURCE_DIR}/3rdparty/${_SUBMODENAME}) + endforeach (_SUBMODULE_FILE) else() set(${pkg_name}_SOURCE_DIR ${PKG_DIR}) endif () @@ -294,12 +300,20 @@ function(mindspore_add_pkg pkg_name ) message(FATAL_ERROR "Failed patch: ${_LF_PATCH_FILE}") endif() endforeach(_PATCH_FILE) - + foreach(_SOURCE_DIR ${PKG_SOURCEMODULES}) + file(GLOB ${pkg_name}_INSTALL_SOURCE ${${pkg_name}_SOURCE_DIR}/${_SOURCE_DIR}/*) + file(COPY ${${pkg_name}_INSTALL_SOURCE} DESTINATION ${${pkg_name}_BASE_DIR}/${_SOURCE_DIR}/) + endforeach (_SUBMODULE_FILE) file(LOCK ${${pkg_name}_BASE_DIR} DIRECTORY GUARD FUNCTION RESULT_VARIABLE ${pkg_name}_LOCK_RET TIMEOUT 600) if(NOT ${pkg_name}_LOCK_RET EQUAL "0") message(FATAL_ERROR "error! when try lock ${${pkg_name}_BASE_DIR} : ${${pkg_name}_LOCK_RET}") endif() + if (PKG_CUSTOM_CMAKE) + file(GLOB ${pkg_name}_cmake ${PKG_CUSTOM_CMAKE}/CMakeLists.txt) + file(COPY ${${pkg_name}_cmake} DESTINATION ${${pkg_name}_SOURCE_DIR}) + endif () + if(${pkg_name}_SOURCE_DIR) if (PKG_HEAD_ONLY) file(GLOB ${pkg_name}_SOURCE_SUBDIRS ${${pkg_name}_SOURCE_DIR}/*) diff --git a/mindspore/ccsrc/CMakeLists.txt b/mindspore/ccsrc/CMakeLists.txt index 9b615b0dad..eb33de1c4b 100644 --- a/mindspore/ccsrc/CMakeLists.txt +++ b/mindspore/ccsrc/CMakeLists.txt @@ -394,117 +394,6 @@ if(USE_GLOG) target_link_libraries(_c_expression PRIVATE mindspore::glog) endif() -if(ENABLE_GPU) - execute_process(COMMAND bash ${CMAKE_SOURCE_DIR}/third_party/apply_patches.sh - ${CMAKE_BINARY_DIR} - ${dlpack_DIRPATH} - ${dmlc_core_DIRPATH} - ${rang_DIRPATH} - ${incubator_tvm_gpu_DIRPATH}) - set(TVM_DIR "${CMAKE_BINARY_DIR}/incubator-tvm") - # Utility functions - include(${TVM_DIR}/cmake/util/Util.cmake) - include(${TVM_DIR}/cmake/util/FindCUDA.cmake) - - # include directories - include_directories(AFTER "${TVM_DIR}/include") - include_directories(AFTER "${TVM_DIR}/src") - include_directories(AFTER "${TVM_DIR}") - include_directories(AFTER "${TVM_DIR}/src/schedule") - - include_directories(AFTER "${TVM_DIR}/3rdparty/dmlc-core/include") - include_directories(AFTER "${TVM_DIR}/3rdparty/dlpack/include") - include_directories(AFTER "${TVM_DIR}/3rdparty/compiler-rt") - include_directories(AFTER "${TVM_DIR}/3rdparty/rang/include") - - # lib contain dlopen and dlclose - set(TVM_RUNTIME_LINKER_LIBS ${CMAKE_DL_LIBS}) - - # add source group - file(GLOB_RECURSE GROUP_SOURCE "${TVM_DIR}/src/*.cc" "src/*.cc") - file(GLOB_RECURSE GROUP_INCLUDE "${TVM_DIR}/src/*.h" - "${TVM_DIR}/include/*.h" "src/*.h" "include/*.h") - assign_source_group("Source" ${GROUP_SOURCE}) - assign_source_group("Include" ${GROUP_INCLUDE}) - - file(GLOB COMPILER_SRCS - "pre_activate/gpu/*.cc" - ${TVM_DIR}/src/api/*.cc - ${TVM_DIR}/src/arithmetic/*.cc - ${TVM_DIR}/src/autotvm/*.cc - ${TVM_DIR}/src/codegen/*.cc - ${TVM_DIR}/src/lang/*.cc - ${TVM_DIR}/src/pass/*.cc - ${TVM_DIR}/src/op/*.cc - ${TVM_DIR}/src/node/*.cc - ${TVM_DIR}/src/schedule/*.cc - ${TVM_DIR}/src/runtime/*.cc - ${TVM_DIR}/src/runtime/vm/*.cc - ${TVM_DIR}/src/runtime/vm/profiler/*.cc - ${TVM_DIR}/src/codegen/stackvm/*.cc) - - file(GLOB_RECURSE RELAY_SRCS ${TVM_DIR}/src/relay/*.cc) - list(APPEND COMPILER_SRCS ${RELAY_SRCS}) - - file(GLOB DATATYPE_SRCS ${TVM_DIR}/src/codegen/datatype/*.cc) - list(APPEND COMPILER_SRCS ${DATATYPE_SRCS}) - - file(GLOB COMPILER_VERILOG_SRCS ${TVM_DIR}/src/codegen/verilog/*.cc) - list(APPEND COMPILER_SRCS ${COMPILER_VERILOG_SRCS}) - - file(GLOB TOPI_SRCS ${TVM_DIR}/topi/src/*.cc) - - file(GLOB RUNTIME_SRCS - ${TVM_DIR}/src/runtime/*.cc - ${TVM_DIR}/src/runtime/vm/*.cc - ${TVM_DIR}/src/runtime/stub/*.cc - ${TVM_DIR}/src/runtime/stackvm/*.cc) - - - file(GLOB COMPILER_OFF_SRCS - ${TVM_DIR}/src/codegen/opt/build_*_off.cc) - set(USE_CUDA "OFF") - if(ENABLE_GPU) - list(REMOVE_ITEM COMPILER_OFF_SRCS - ${TVM_DIR}/src/codegen/opt/build_cuda_off.cc) - set(USE_CUDA "ON") - endif() - list(APPEND COMPILER_SRCS ${COMPILER_OFF_SRCS}) - # Module rules - include(${TVM_DIR}/cmake/modules/CUDA.cmake) - - set(CMAKE_C_FLAGS_AKG -pipe -Wall -fPIC -fstack-protector-all) - set(CMAKE_C_FLAGS_AKG ${CMAKE_C_FLAGS_AKG} -Wl,-z,relro,-z,now,-z,noexecstack) - - set(CMAKE_CXX_FLAGS_AKG -std=c++11 -pipe -Wall -fPIC -fstack-protector-all) - set(CMAKE_CXX_FLAGS_AKG ${CMAKE_CXX_FLAGS_AKG} -Wl,-z,relro,-z,now,-z,noexecstack) - - if("${CMAKE_BUILD_TYPE}" STREQUAL "Debug") - message("-- Build in Debug mode") - set(CMAKE_C_FLAGS_AKG ${CMAKE_C_FLAGS_AKG} -O0 -g -rdynamic) - set(CMAKE_CXX_FLAGS_AKG ${CMAKE_CXX_FLAGS_AKG} -O0 -g -rdynamic) - else() - message("-- Build in Release mode") - set(CMAKE_C_FLAGS_AKG ${CMAKE_C_FLAGS_AKG} -O2 -Werror) - set(CMAKE_CXX_FLAGS_AKG ${CMAKE_CXX_FLAGS_AKG} -O2 -Werror) - endif() - if(CMAKE_CXX_COMPILER_ID MATCHES "GNU" AND CMAKE_CXX_COMPILER_VERSION - VERSION_GREATER 7.0) - set(CMAKE_CXX_FLAGS_AKG ${CMAKE_CXX_FLAGS_AKG} -faligned-new) - endif() - - add_library(akg OBJECT ${COMPILER_SRCS} ${RUNTIME_SRCS} ${TOPI_SRCS}) - - target_link_libraries(akg ${TVM_LINKER_LIBS} ${TVM_RUNTIME_LINKER_LIBS}) - target_compile_options(akg PRIVATE - $<$:${CMAKE_C_FLAGS_AKG}> - $<$:${CMAKE_CXX_FLAGS_AKG}>) - target_include_directories(akg PRIVATE "${TVM_DIR}/topi/include") - - add_dependencies(_c_expression akg) - target_link_libraries(_c_expression PRIVATE akg) -endif() - if(ENABLE_DUMP_PROTO) target_link_libraries(_c_expression PRIVATE mindspore::protobuf) endif() diff --git a/third_party/patch/incubator-tvm/CMakeLists.txt b/third_party/patch/incubator-tvm/CMakeLists.txt new file mode 100644 index 0000000000..d8964579cd --- /dev/null +++ b/third_party/patch/incubator-tvm/CMakeLists.txt @@ -0,0 +1,100 @@ +cmake_minimum_required(VERSION 3.2) +project(tvm C CXX) +set(TVM_DIR ${CMAKE_CURRENT_SOURCE_DIR}) +# Utility functions +include(${TVM_DIR}/cmake/util/Util.cmake) +include(${TVM_DIR}/cmake/util/FindCUDA.cmake) + +# include directories +include_directories(AFTER "${TVM_DIR}/include") +include_directories(AFTER "${TVM_DIR}/src") +include_directories(AFTER "${TVM_DIR}") +include_directories(AFTER "${TVM_DIR}/src/schedule") + +include_directories(AFTER "${TVM_DIR}/3rdparty/dmlc-core/include") +include_directories(AFTER "${TVM_DIR}/3rdparty/dlpack/include") +include_directories(AFTER "${TVM_DIR}/3rdparty/compiler-rt") +include_directories(AFTER "${TVM_DIR}/3rdparty/rang/include") + +# lib contain dlopen and dlclose +set(TVM_RUNTIME_LINKER_LIBS ${CMAKE_DL_LIBS}) + +# add source group +file(GLOB_RECURSE GROUP_SOURCE "${TVM_DIR}/src/*.cc" "src/*.cc") +file(GLOB_RECURSE GROUP_INCLUDE "${TVM_DIR}/src/*.h" + "${TVM_DIR}/include/*.h" "src/*.h" "include/*.h") +assign_source_group("Source" ${GROUP_SOURCE}) +assign_source_group("Include" ${GROUP_INCLUDE}) + +file(GLOB COMPILER_SRCS + "pre_activate/gpu/*.cc" + ${TVM_DIR}/src/api/*.cc + ${TVM_DIR}/src/arithmetic/*.cc + ${TVM_DIR}/src/autotvm/*.cc + ${TVM_DIR}/src/codegen/*.cc + ${TVM_DIR}/src/lang/*.cc + ${TVM_DIR}/src/pass/*.cc + ${TVM_DIR}/src/op/*.cc + ${TVM_DIR}/src/node/*.cc + ${TVM_DIR}/src/schedule/*.cc + ${TVM_DIR}/src/runtime/*.cc + ${TVM_DIR}/src/runtime/vm/*.cc + ${TVM_DIR}/src/runtime/vm/profiler/*.cc + ${TVM_DIR}/src/codegen/stackvm/*.cc) + +file(GLOB_RECURSE RELAY_SRCS ${TVM_DIR}/src/relay/*.cc) +list(APPEND COMPILER_SRCS ${RELAY_SRCS}) + +file(GLOB DATATYPE_SRCS ${TVM_DIR}/src/codegen/datatype/*.cc) +list(APPEND COMPILER_SRCS ${DATATYPE_SRCS}) + +file(GLOB COMPILER_VERILOG_SRCS ${TVM_DIR}/src/codegen/verilog/*.cc) +list(APPEND COMPILER_SRCS ${COMPILER_VERILOG_SRCS}) + +file(GLOB TOPI_SRCS ${TVM_DIR}/topi/src/*.cc) + +file(GLOB RUNTIME_SRCS + ${TVM_DIR}/src/runtime/*.cc + ${TVM_DIR}/src/runtime/vm/*.cc + ${TVM_DIR}/src/runtime/stub/*.cc + ${TVM_DIR}/src/runtime/stackvm/*.cc) + + +file(GLOB COMPILER_OFF_SRCS + ${TVM_DIR}/src/codegen/opt/build_*_off.cc) + +list(REMOVE_ITEM COMPILER_OFF_SRCS + ${TVM_DIR}/src/codegen/opt/build_cuda_off.cc) +set(USE_CUDA "ON") +list(APPEND COMPILER_SRCS ${COMPILER_OFF_SRCS}) +# Module rules +include(${TVM_DIR}/cmake/modules/CUDA.cmake) + +set(CMAKE_C_FLAGS_AKG -pipe -Wall -fPIC -fstack-protector-all) +set(CMAKE_C_FLAGS_AKG ${CMAKE_C_FLAGS_AKG} -Wl,-z,relro,-z,now,-z,noexecstack) + +set(CMAKE_CXX_FLAGS_AKG -std=c++11 -pipe -Wall -fPIC -fstack-protector-all) +set(CMAKE_CXX_FLAGS_AKG ${CMAKE_CXX_FLAGS_AKG} -Wl,-z,relro,-z,now,-z,noexecstack) + +if("${CMAKE_BUILD_TYPE}" STREQUAL "Debug") + message("-- Build in Debug mode") + set(CMAKE_C_FLAGS_AKG ${CMAKE_C_FLAGS_AKG} -O0 -g -rdynamic) + set(CMAKE_CXX_FLAGS_AKG ${CMAKE_CXX_FLAGS_AKG} -O0 -g -rdynamic) +else() + message("-- Build in Release mode") + set(CMAKE_C_FLAGS_AKG ${CMAKE_C_FLAGS_AKG} -O2 -Werror) + set(CMAKE_CXX_FLAGS_AKG ${CMAKE_CXX_FLAGS_AKG} -O2 -Werror) +endif() +if(CMAKE_CXX_COMPILER_ID MATCHES "GNU" AND CMAKE_CXX_COMPILER_VERSION + VERSION_GREATER 7.0) + set(CMAKE_CXX_FLAGS_AKG ${CMAKE_CXX_FLAGS_AKG} -faligned-new) +endif() + +add_library(tvm SHARED ${COMPILER_SRCS} ${RUNTIME_SRCS} ${TOPI_SRCS}) + +target_link_libraries(tvm ${TVM_LINKER_LIBS} ${TVM_RUNTIME_LINKER_LIBS}) +target_compile_options(tvm PRIVATE + $<$:${CMAKE_C_FLAGS_AKG}> + $<$:${CMAKE_CXX_FLAGS_AKG}>) +target_include_directories(tvm PRIVATE "${TVM_DIR}/topi/include") +install(TARGETS tvm) \ No newline at end of file diff --git a/third_party/patch/incubator-tvm/find_library.patch b/third_party/patch/incubator-tvm/find_library.patch index e54df2c7cf..f7b2f9af0a 100644 --- a/third_party/patch/incubator-tvm/find_library.patch +++ b/third_party/patch/incubator-tvm/find_library.patch @@ -18,11 +18,11 @@ - lib_path = libinfo.find_lib_path() + """Load library by searching possible path.""" + pwd = os.path.dirname(os.path.realpath(__file__)) -+ path = os.path.realpath(pwd+"/../../../mindspore") ++ path = os.path.realpath(pwd+"/../../../mindspore/lib") + lib_path = [] + files = os.listdir(path) + for f in files: -+ if f.startswith("_c_expression.") and f.endswith(".so"): ++ if f.startswith("libtvm.") and f.endswith(".so"): + lib_path.append(path+"/"+f) + break + if not lib_path: @@ -56,11 +56,11 @@ diff -Npur tvm/topi/python/topi/cpp/impl.py tvm_new/topi/python/topi/cpp/impl.py - return None, None + """Load library by searching possible path.""" + pwd = os.path.dirname(os.path.realpath(__file__)) -+ path = os.path.realpath(pwd+"/../../../mindspore") ++ path = os.path.realpath(pwd+"/../../../mindspore/lib") + lib_path = [] + files = os.listdir(path) + for f in files: -+ if f.startswith("_c_expression.") and f.endswith(".so"): ++ if f.startswith("libtvm.") and f.endswith(".so"): + lib_path.append(path+"/"+f) + break + if not lib_path: From f4bae5f364aa99727c2b27677336f121a92af634 Mon Sep 17 00:00:00 2001 From: jonwe Date: Tue, 21 Apr 2020 09:12:13 -0400 Subject: [PATCH 062/142] optimize mindrecord writer performance --- example/convert_to_mindrecord/README.md | 46 +++++ .../imagenet/__init__.py | 0 .../convert_to_mindrecord/imagenet/mr_api.py | 122 ++++++++++++ example/convert_to_mindrecord/run_imagenet.sh | 8 + example/convert_to_mindrecord/run_template.sh | 6 + .../template/__init__.py | 0 .../convert_to_mindrecord/template/mr_api.py | 73 +++++++ example/convert_to_mindrecord/writer.py | 149 ++++++++++++++ .../ccsrc/mindrecord/common/shard_pybind.cc | 9 +- .../ccsrc/mindrecord/include/shard_header.h | 4 + .../ccsrc/mindrecord/include/shard_writer.h | 37 +++- .../mindrecord/io/shard_index_generator.cc | 3 + mindspore/ccsrc/mindrecord/io/shard_writer.cc | 188 ++++++++++++++++-- .../ccsrc/mindrecord/meta/shard_header.cc | 38 ++++ mindspore/mindrecord/filewriter.py | 15 +- mindspore/mindrecord/shardwriter.py | 5 +- 16 files changed, 668 insertions(+), 35 deletions(-) create mode 100644 example/convert_to_mindrecord/README.md create mode 100644 example/convert_to_mindrecord/imagenet/__init__.py create mode 100644 example/convert_to_mindrecord/imagenet/mr_api.py create mode 100644 example/convert_to_mindrecord/run_imagenet.sh create mode 100644 example/convert_to_mindrecord/run_template.sh create mode 100644 example/convert_to_mindrecord/template/__init__.py create mode 100644 example/convert_to_mindrecord/template/mr_api.py create mode 100644 example/convert_to_mindrecord/writer.py diff --git a/example/convert_to_mindrecord/README.md b/example/convert_to_mindrecord/README.md new file mode 100644 index 0000000000..8d3b25e311 --- /dev/null +++ b/example/convert_to_mindrecord/README.md @@ -0,0 +1,46 @@ +# MindRecord generating guidelines + + + +- [MindRecord generating guidelines](#mindrecord-generating-guidelines) + - [Create work space](#create-work-space) + - [Implement data generator](#implement-data-generator) + - [Run data generator](#run-data-generator) + + + +## Create work space + +Assume the dataset name is 'xyz' +* Create work space from template + ```shell + cd ${your_mindspore_home}/example/convert_to_mindrecord + cp -r template xyz + ``` + +## Implement data generator + +Edit dictionary data generator +* Edit file + ```shell + cd ${your_mindspore_home}/example/convert_to_mindrecord + vi xyz/mr_api.py + ``` + + Two API, 'mindrecord_task_number' and 'mindrecord_dict_data', must be implemented +- 'mindrecord_task_number()' returns number of tasks. Return 1 if data row is generated serially. Return N if generator can be split into N parallel-run tasks. +- 'mindrecord_dict_data(task_id)' yields dictionary data row by row. 'task_id' is 0..N-1, if N is return value of mindrecord_task_number() + + +Tricky for parallel run +- For imagenet, one directory can be a task. +- For TFRecord with multiple files, each file can be a task. +- For TFRecord with 1 file only, it could also be split into N tasks. Task_id=K means: data row is picked only if (count % N == K) + + +## Run data generator +* run python script + ```shell + cd ${your_mindspore_home}/example/convert_to_mindrecord + python writer.py --mindrecord_script imagenet [...] + ``` diff --git a/example/convert_to_mindrecord/imagenet/__init__.py b/example/convert_to_mindrecord/imagenet/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/example/convert_to_mindrecord/imagenet/mr_api.py b/example/convert_to_mindrecord/imagenet/mr_api.py new file mode 100644 index 0000000000..e569b489b5 --- /dev/null +++ b/example/convert_to_mindrecord/imagenet/mr_api.py @@ -0,0 +1,122 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +User-defined API for MindRecord writer. +Two API must be implemented, + 1. mindrecord_task_number() + # Return number of parallel tasks. return 1 if no parallel + 2. mindrecord_dict_data(task_id) + # Yield data for one task + # task_id is 0..N-1, if N is return value of mindrecord_task_number() +""" +import argparse +import os +import pickle + +######## mindrecord_schema begin ########## +mindrecord_schema = {"label": {"type": "int64"}, + "data": {"type": "bytes"}, + "file_name": {"type": "string"}} +######## mindrecord_schema end ########## + +######## Frozen code begin ########## +with open('mr_argument.pickle', 'rb') as mindrecord_argument_file_handle: + ARG_LIST = pickle.load(mindrecord_argument_file_handle) +######## Frozen code end ########## + +parser = argparse.ArgumentParser(description='Mind record imagenet example') +parser.add_argument('--label_file', type=str, default="", help='label file') +parser.add_argument('--image_dir', type=str, default="", help='images directory') + +######## Frozen code begin ########## +args = parser.parse_args(ARG_LIST) +print(args) +######## Frozen code end ########## + + +def _user_defined_private_func(): + """ + Internal function for tasks list + + Return: + tasks list + """ + if not os.path.exists(args.label_file): + raise IOError("map file {} not exists".format(args.label_file)) + + label_dict = {} + with open(args.label_file) as file_handle: + line = file_handle.readline() + while line: + labels = line.split(" ") + label_dict[labels[1]] = labels[0] + line = file_handle.readline() + # get all the dir which are n02087046, n02094114, n02109525 + dir_paths = {} + for item in label_dict: + real_path = os.path.join(args.image_dir, label_dict[item]) + if not os.path.isdir(real_path): + print("{} dir is not exist".format(real_path)) + continue + dir_paths[item] = real_path + + if not dir_paths: + print("not valid image dir in {}".format(args.image_dir)) + return {}, {} + + dir_list = [] + for label in dir_paths: + dir_list.append(label) + return dir_list, dir_paths + + +dir_list_global, dir_paths_global = _user_defined_private_func() + +def mindrecord_task_number(): + """ + Get task size. + + Return: + number of tasks + """ + return len(dir_list_global) + + +def mindrecord_dict_data(task_id): + """ + Get data dict. + + Yields: + data (dict): data row which is dict. + """ + + # get the filename, label and image binary as a dict + label = dir_list_global[task_id] + for item in os.listdir(dir_paths_global[label]): + file_name = os.path.join(dir_paths_global[label], item) + if not item.endswith("JPEG") and not item.endswith( + "jpg") and not item.endswith("jpeg"): + print("{} file is not suffix with JPEG/jpg, skip it.".format(file_name)) + continue + data = {} + data["file_name"] = str(file_name) + data["label"] = int(label) + + # get the image data + image_file = open(file_name, "rb") + image_bytes = image_file.read() + image_file.close() + data["data"] = image_bytes + yield data diff --git a/example/convert_to_mindrecord/run_imagenet.sh b/example/convert_to_mindrecord/run_imagenet.sh new file mode 100644 index 0000000000..11f5dcff75 --- /dev/null +++ b/example/convert_to_mindrecord/run_imagenet.sh @@ -0,0 +1,8 @@ +#!/bin/bash +rm /tmp/imagenet/mr/* + +python writer.py --mindrecord_script imagenet \ +--mindrecord_file "/tmp/imagenet/mr/m" \ +--mindrecord_partitions 16 \ +--label_file "/tmp/imagenet/label.txt" \ +--image_dir "/tmp/imagenet/jpeg" diff --git a/example/convert_to_mindrecord/run_template.sh b/example/convert_to_mindrecord/run_template.sh new file mode 100644 index 0000000000..a4c5142c00 --- /dev/null +++ b/example/convert_to_mindrecord/run_template.sh @@ -0,0 +1,6 @@ +#!/bin/bash +rm /tmp/template/* + +python writer.py --mindrecord_script template \ +--mindrecord_file "/tmp/template/m" \ +--mindrecord_partitions 4 diff --git a/example/convert_to_mindrecord/template/__init__.py b/example/convert_to_mindrecord/template/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/example/convert_to_mindrecord/template/mr_api.py b/example/convert_to_mindrecord/template/mr_api.py new file mode 100644 index 0000000000..3f7d7dddf0 --- /dev/null +++ b/example/convert_to_mindrecord/template/mr_api.py @@ -0,0 +1,73 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +User-defined API for MindRecord writer. +Two API must be implemented, + 1. mindrecord_task_number() + # Return number of parallel tasks. return 1 if no parallel + 2. mindrecord_dict_data(task_id) + # Yield data for one task + # task_id is 0..N-1, if N is return value of mindrecord_task_number() +""" +import argparse +import pickle + +# ## Parse argument + +with open('mr_argument.pickle', 'rb') as mindrecord_argument_file_handle: # Do NOT change this line + ARG_LIST = pickle.load(mindrecord_argument_file_handle) # Do NOT change this line +parser = argparse.ArgumentParser(description='Mind record api template') # Do NOT change this line + +# ## Your arguments below +# parser.add_argument(...) + +args = parser.parse_args(ARG_LIST) # Do NOT change this line +print(args) # Do NOT change this line + + +# ## Default mindrecord vars. Comment them unless default value has to be changed. +# mindrecord_index_fields = ['label'] +# mindrecord_header_size = 1 << 24 +# mindrecord_page_size = 1 << 25 + + +# define global vars here if necessary + + +# ####### Your code below ########## +mindrecord_schema = {"label": {"type": "int32"}} + +def mindrecord_task_number(): + """ + Get task size. + + Return: + number of tasks + """ + return 1 + + +def mindrecord_dict_data(task_id): + """ + Get data dict. + + Yields: + data (dict): data row which is dict. + """ + print("task is {}".format(task_id)) + for i in range(256): + data = {} + data['label'] = i + yield data diff --git a/example/convert_to_mindrecord/writer.py b/example/convert_to_mindrecord/writer.py new file mode 100644 index 0000000000..0a9ad5c86a --- /dev/null +++ b/example/convert_to_mindrecord/writer.py @@ -0,0 +1,149 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +######################## write mindrecord example ######################## +Write mindrecord by data dictionary: +python writer.py --mindrecord_script /YourScriptPath ... +""" +import argparse +import os +import pickle +import time +from importlib import import_module +from multiprocessing import Pool + +from mindspore.mindrecord import FileWriter + + +def _exec_task(task_id, parallel_writer=True): + """ + Execute task with specified task id + """ + print("exec task {}, parallel: {} ...".format(task_id, parallel_writer)) + imagenet_iter = mindrecord_dict_data(task_id) + batch_size = 2048 + transform_count = 0 + while True: + data_list = [] + try: + for _ in range(batch_size): + data_list.append(imagenet_iter.__next__()) + transform_count += 1 + writer.write_raw_data(data_list, parallel_writer=parallel_writer) + print("transformed {} record...".format(transform_count)) + except StopIteration: + if data_list: + writer.write_raw_data(data_list, parallel_writer=parallel_writer) + print("transformed {} record...".format(transform_count)) + break + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Mind record writer') + parser.add_argument('--mindrecord_script', type=str, default="template", + help='path where script is saved') + + parser.add_argument('--mindrecord_file', type=str, default="/tmp/mindrecord", + help='written file name prefix') + + parser.add_argument('--mindrecord_partitions', type=int, default=1, + help='number of written files') + + parser.add_argument('--mindrecord_workers', type=int, default=8, + help='number of parallel workers') + + args = parser.parse_known_args() + + args, other_args = parser.parse_known_args() + + print(args) + print(other_args) + + with open('mr_argument.pickle', 'wb') as file_handle: + pickle.dump(other_args, file_handle) + + try: + mr_api = import_module(args.mindrecord_script + '.mr_api') + except ModuleNotFoundError: + raise RuntimeError("Unknown module path: {}".format(args.mindrecord_script + '.mr_api')) + + num_tasks = mr_api.mindrecord_task_number() + + print("Write mindrecord ...") + + mindrecord_dict_data = mr_api.mindrecord_dict_data + + # get number of files + writer = FileWriter(args.mindrecord_file, args.mindrecord_partitions) + + start_time = time.time() + + # set the header size + try: + header_size = mr_api.mindrecord_header_size + writer.set_header_size(header_size) + except AttributeError: + print("Default header size: {}".format(1 << 24)) + + # set the page size + try: + page_size = mr_api.mindrecord_page_size + writer.set_page_size(page_size) + except AttributeError: + print("Default page size: {}".format(1 << 25)) + + # get schema + try: + mindrecord_schema = mr_api.mindrecord_schema + except AttributeError: + raise RuntimeError("mindrecord_schema is not defined in mr_api.py.") + + # create the schema + writer.add_schema(mindrecord_schema, "mindrecord_schema") + + # add the index + try: + index_fields = mr_api.mindrecord_index_fields + writer.add_index(index_fields) + except AttributeError: + print("Default index fields: all simple fields are indexes.") + + writer.open_and_set_header() + + task_list = list(range(num_tasks)) + + # set number of workers + num_workers = args.mindrecord_workers + + if num_tasks < 1: + num_tasks = 1 + + if num_workers > num_tasks: + num_workers = num_tasks + + if num_tasks > 1: + with Pool(num_workers) as p: + p.map(_exec_task, task_list) + else: + _exec_task(0, False) + + ret = writer.commit() + + os.remove("{}".format("mr_argument.pickle")) + + end_time = time.time() + print("--------------------------------------------") + print("END. Total time: {}".format(end_time - start_time)) + print("--------------------------------------------") diff --git a/mindspore/ccsrc/mindrecord/common/shard_pybind.cc b/mindspore/ccsrc/mindrecord/common/shard_pybind.cc index 338a17ac2d..8718e9b871 100644 --- a/mindspore/ccsrc/mindrecord/common/shard_pybind.cc +++ b/mindspore/ccsrc/mindrecord/common/shard_pybind.cc @@ -75,12 +75,9 @@ void BindShardWriter(py::module *m) { .def("set_header_size", &ShardWriter::set_header_size) .def("set_page_size", &ShardWriter::set_page_size) .def("set_shard_header", &ShardWriter::SetShardHeader) - .def("write_raw_data", - (MSRStatus(ShardWriter::*)(std::map> &, vector> &, bool)) & - ShardWriter::WriteRawData) - .def("write_raw_nlp_data", (MSRStatus(ShardWriter::*)(std::map> &, - std::map> &, bool)) & - ShardWriter::WriteRawData) + .def("write_raw_data", (MSRStatus(ShardWriter::*)(std::map> &, + vector> &, bool, bool)) & + ShardWriter::WriteRawData) .def("commit", &ShardWriter::Commit); } diff --git a/mindspore/ccsrc/mindrecord/include/shard_header.h b/mindspore/ccsrc/mindrecord/include/shard_header.h index ca4d3bd66f..70cfcdb6b7 100644 --- a/mindspore/ccsrc/mindrecord/include/shard_header.h +++ b/mindspore/ccsrc/mindrecord/include/shard_header.h @@ -121,6 +121,10 @@ class ShardHeader { std::vector SerializeHeader(); + MSRStatus PagesToFile(const std::string dump_file_name); + + MSRStatus FileToPages(const std::string dump_file_name); + private: MSRStatus InitializeHeader(const std::vector &headers); diff --git a/mindspore/ccsrc/mindrecord/include/shard_writer.h b/mindspore/ccsrc/mindrecord/include/shard_writer.h index 6a22f07700..78a434fc97 100644 --- a/mindspore/ccsrc/mindrecord/include/shard_writer.h +++ b/mindspore/ccsrc/mindrecord/include/shard_writer.h @@ -18,6 +18,7 @@ #define MINDRECORD_INCLUDE_SHARD_WRITER_H_ #include +#include #include #include #include @@ -87,7 +88,7 @@ class ShardWriter { /// \param[in] sign validate data or not /// \return MSRStatus the status of MSRStatus to judge if write successfully MSRStatus WriteRawData(std::map> &raw_data, vector> &blob_data, - bool sign = true); + bool sign = true, bool parallel_writer = false); /// \brief write raw data by group size for call from python /// \param[in] raw_data the vector of raw json data, python-handle format @@ -95,7 +96,7 @@ class ShardWriter { /// \param[in] sign validate data or not /// \return MSRStatus the status of MSRStatus to judge if write successfully MSRStatus WriteRawData(std::map> &raw_data, vector> &blob_data, - bool sign = true); + bool sign = true, bool parallel_writer = false); /// \brief write raw data by group size for call from python /// \param[in] raw_data the vector of raw json data, python-handle format @@ -103,7 +104,8 @@ class ShardWriter { /// \param[in] sign validate data or not /// \return MSRStatus the status of MSRStatus to judge if write successfully MSRStatus WriteRawData(std::map> &raw_data, - std::map> &blob_data, bool sign = true); + std::map> &blob_data, bool sign = true, + bool parallel_writer = false); private: /// \brief write shard header data to disk @@ -201,7 +203,34 @@ class ShardWriter { MSRStatus CheckDataTypeAndValue(const std::string &key, const json &value, const json &data, const int &i, std::map &err_raw_data); + /// \brief Lock writer and save pages info + int LockWriter(bool parallel_writer = false); + + /// \brief Unlock writer and save pages info + MSRStatus UnlockWriter(int fd, bool parallel_writer = false); + + /// \brief Check raw data before writing + MSRStatus WriteRawDataPreCheck(std::map> &raw_data, vector> &blob_data, + bool sign, int *schema_count, int *row_count); + + /// \brief Get full path from file name + MSRStatus GetFullPathFromFileName(const std::vector &paths); + + /// \brief Open files + MSRStatus OpenDataFiles(bool append); + + /// \brief Remove lock file + MSRStatus RemoveLockFile(); + + /// \brief Remove lock file + MSRStatus InitLockFile(); + private: + const std::string kLockFileSuffix = "_Locker"; + const std::string kPageFileSuffix = "_Pages"; + std::string lock_file_; // lock file for parallel run + std::string pages_file_; // temporary file of pages info for parallel run + int shard_count_; // number of files uint64_t header_size_; // header size uint64_t page_size_; // page size @@ -211,7 +240,7 @@ class ShardWriter { std::vector raw_data_size_; // Raw data size std::vector blob_data_size_; // Blob data size - std::vector file_paths_; // file paths + std::vector file_paths_; // file paths std::vector> file_streams_; // file handles std::shared_ptr shard_header_; // shard headers diff --git a/mindspore/ccsrc/mindrecord/io/shard_index_generator.cc b/mindspore/ccsrc/mindrecord/io/shard_index_generator.cc index 5a5cd7cbf3..dc2743cdc7 100644 --- a/mindspore/ccsrc/mindrecord/io/shard_index_generator.cc +++ b/mindspore/ccsrc/mindrecord/io/shard_index_generator.cc @@ -520,13 +520,16 @@ MSRStatus ShardIndexGenerator::ExecuteTransaction(const int &shard_no, const std for (int raw_page_id : raw_page_ids) { auto sql = GenerateRawSQL(fields_); if (sql.first != SUCCESS) { + MS_LOG(ERROR) << "Generate raw SQL failed"; return FAILED; } auto data = GenerateRowData(shard_no, blob_id_to_page_id, raw_page_id, in); if (data.first != SUCCESS) { + MS_LOG(ERROR) << "Generate raw data failed"; return FAILED; } if (BindParameterExecuteSQL(db.second, sql.second, data.second) == FAILED) { + MS_LOG(ERROR) << "Execute SQL failed"; return FAILED; } MS_LOG(INFO) << "Insert " << data.second.size() << " rows to index db."; diff --git a/mindspore/ccsrc/mindrecord/io/shard_writer.cc b/mindspore/ccsrc/mindrecord/io/shard_writer.cc index 864e6697d0..ac95e622c9 100644 --- a/mindspore/ccsrc/mindrecord/io/shard_writer.cc +++ b/mindspore/ccsrc/mindrecord/io/shard_writer.cc @@ -40,17 +40,7 @@ ShardWriter::~ShardWriter() { } } -MSRStatus ShardWriter::Open(const std::vector &paths, bool append) { - shard_count_ = paths.size(); - if (shard_count_ > kMaxShardCount || shard_count_ == 0) { - MS_LOG(ERROR) << "The Shard Count greater than max value or equal to 0."; - return FAILED; - } - if (schema_count_ > kMaxSchemaCount) { - MS_LOG(ERROR) << "The schema Count greater than max value."; - return FAILED; - } - +MSRStatus ShardWriter::GetFullPathFromFileName(const std::vector &paths) { // Get full path from file name for (const auto &path : paths) { if (!CheckIsValidUtf8(path)) { @@ -60,7 +50,7 @@ MSRStatus ShardWriter::Open(const std::vector &paths, bool append) char resolved_path[PATH_MAX] = {0}; char buf[PATH_MAX] = {0}; if (strncpy_s(buf, PATH_MAX, common::SafeCStr(path), path.length()) != EOK) { - MS_LOG(ERROR) << "Securec func failed"; + MS_LOG(ERROR) << "Secure func failed"; return FAILED; } #if defined(_WIN32) || defined(_WIN64) @@ -82,7 +72,10 @@ MSRStatus ShardWriter::Open(const std::vector &paths, bool append) #endif file_paths_.emplace_back(string(resolved_path)); } + return SUCCESS; +} +MSRStatus ShardWriter::OpenDataFiles(bool append) { // Open files for (const auto &file : file_paths_) { std::shared_ptr fs = std::make_shared(); @@ -116,6 +109,67 @@ MSRStatus ShardWriter::Open(const std::vector &paths, bool append) return SUCCESS; } +MSRStatus ShardWriter::RemoveLockFile() { + // Remove temporary file + int ret = std::remove(pages_file_.c_str()); + if (ret == 0) { + MS_LOG(DEBUG) << "Remove page file."; + } + + ret = std::remove(lock_file_.c_str()); + if (ret == 0) { + MS_LOG(DEBUG) << "Remove lock file."; + } + return SUCCESS; +} + +MSRStatus ShardWriter::InitLockFile() { + if (file_paths_.size() == 0) { + MS_LOG(ERROR) << "File path not initialized."; + return FAILED; + } + + lock_file_ = file_paths_[0] + kLockFileSuffix; + pages_file_ = file_paths_[0] + kPageFileSuffix; + + if (RemoveLockFile() == FAILED) { + MS_LOG(ERROR) << "Remove file failed."; + return FAILED; + } + return SUCCESS; +} + +MSRStatus ShardWriter::Open(const std::vector &paths, bool append) { + shard_count_ = paths.size(); + if (shard_count_ > kMaxShardCount || shard_count_ == 0) { + MS_LOG(ERROR) << "The Shard Count greater than max value or equal to 0."; + return FAILED; + } + if (schema_count_ > kMaxSchemaCount) { + MS_LOG(ERROR) << "The schema Count greater than max value."; + return FAILED; + } + + // Get full path from file name + if (GetFullPathFromFileName(paths) == FAILED) { + MS_LOG(ERROR) << "Get full path from file name failed."; + return FAILED; + } + + // Open files + if (OpenDataFiles(append) == FAILED) { + MS_LOG(ERROR) << "Open data files failed."; + return FAILED; + } + + // Init lock file + if (InitLockFile() == FAILED) { + MS_LOG(ERROR) << "Init lock file failed."; + return FAILED; + } + return SUCCESS; +} + MSRStatus ShardWriter::OpenForAppend(const std::string &path) { if (!IsLegalFile(path)) { return FAILED; @@ -143,11 +197,28 @@ MSRStatus ShardWriter::OpenForAppend(const std::string &path) { } MSRStatus ShardWriter::Commit() { + // Read pages file + std::ifstream page_file(pages_file_.c_str()); + if (page_file.good()) { + page_file.close(); + if (shard_header_->FileToPages(pages_file_) == FAILED) { + MS_LOG(ERROR) << "Read pages from file failed"; + return FAILED; + } + } + if (WriteShardHeader() == FAILED) { MS_LOG(ERROR) << "Write metadata failed"; return FAILED; } MS_LOG(INFO) << "Write metadata successfully."; + + // Remove lock file + if (RemoveLockFile() == FAILED) { + MS_LOG(ERROR) << "Remove lock file failed."; + return FAILED; + } + return SUCCESS; } @@ -455,15 +526,65 @@ void ShardWriter::FillArray(int start, int end, std::map> } } -MSRStatus ShardWriter::WriteRawData(std::map> &raw_data, - std::vector> &blob_data, bool sign) { +int ShardWriter::LockWriter(bool parallel_writer) { + if (!parallel_writer) { + return 0; + } + const int fd = open(lock_file_.c_str(), O_WRONLY | O_CREAT, 0666); + if (fd >= 0) { + flock(fd, LOCK_EX); + } else { + MS_LOG(ERROR) << "Shard writer failed when locking file"; + return -1; + } + + // Open files + file_streams_.clear(); + for (const auto &file : file_paths_) { + std::shared_ptr fs = std::make_shared(); + fs->open(common::SafeCStr(file), std::ios::in | std::ios::out | std::ios::binary); + if (fs->fail()) { + MS_LOG(ERROR) << "File could not opened"; + return -1; + } + file_streams_.push_back(fs); + } + + if (shard_header_->FileToPages(pages_file_) == FAILED) { + MS_LOG(ERROR) << "Read pages from file failed"; + return -1; + } + return fd; +} + +MSRStatus ShardWriter::UnlockWriter(int fd, bool parallel_writer) { + if (!parallel_writer) { + return SUCCESS; + } + + if (shard_header_->PagesToFile(pages_file_) == FAILED) { + MS_LOG(ERROR) << "Write pages to file failed"; + return FAILED; + } + + for (int i = static_cast(file_streams_.size()) - 1; i >= 0; i--) { + file_streams_[i]->close(); + } + + flock(fd, LOCK_UN); + close(fd); + return SUCCESS; +} + +MSRStatus ShardWriter::WriteRawDataPreCheck(std::map> &raw_data, + std::vector> &blob_data, bool sign, int *schema_count, + int *row_count) { // check the free disk size auto st_space = GetDiskSize(file_paths_[0], kFreeSize); if (st_space.first != SUCCESS || st_space.second < kMinFreeDiskSize) { MS_LOG(ERROR) << "IO error / there is no free disk to be used"; return FAILED; } - // Add 4-bytes dummy blob data if no any blob fields if (blob_data.size() == 0 && raw_data.size() > 0) { blob_data = std::vector>(raw_data[0].size(), std::vector(kUnsignedInt4, 0)); @@ -479,10 +600,29 @@ MSRStatus ShardWriter::WriteRawData(std::map> &raw_d MS_LOG(ERROR) << "Validate raw data failed"; return FAILED; } + *schema_count = std::get<1>(v); + *row_count = std::get<2>(v); + return SUCCESS; +} + +MSRStatus ShardWriter::WriteRawData(std::map> &raw_data, + std::vector> &blob_data, bool sign, bool parallel_writer) { + // Lock Writer if loading data parallel + int fd = LockWriter(parallel_writer); + if (fd < 0) { + MS_LOG(ERROR) << "Lock writer failed"; + return FAILED; + } // Get the count of schemas and rows - int schema_count = std::get<1>(v); - int row_count = std::get<2>(v); + int schema_count = 0; + int row_count = 0; + + // Serialize raw data + if (WriteRawDataPreCheck(raw_data, blob_data, sign, &schema_count, &row_count) == FAILED) { + MS_LOG(ERROR) << "Check raw data failed"; + return FAILED; + } if (row_count == kInt0) { MS_LOG(INFO) << "Raw data size is 0."; @@ -516,11 +656,17 @@ MSRStatus ShardWriter::WriteRawData(std::map> &raw_d } MS_LOG(INFO) << "Write " << bin_raw_data.size() << " records successfully."; + if (UnlockWriter(fd, parallel_writer) == FAILED) { + MS_LOG(ERROR) << "Unlock writer failed"; + return FAILED; + } + return SUCCESS; } MSRStatus ShardWriter::WriteRawData(std::map> &raw_data, - std::map> &blob_data, bool sign) { + std::map> &blob_data, bool sign, + bool parallel_writer) { std::map> raw_data_json; std::map> blob_data_json; @@ -554,11 +700,11 @@ MSRStatus ShardWriter::WriteRawData(std::map> MS_LOG(ERROR) << "Serialize raw data failed in write raw data"; return FAILED; } - return WriteRawData(raw_data_json, bin_blob_data, sign); + return WriteRawData(raw_data_json, bin_blob_data, sign, parallel_writer); } MSRStatus ShardWriter::WriteRawData(std::map> &raw_data, - vector> &blob_data, bool sign) { + vector> &blob_data, bool sign, bool parallel_writer) { std::map> raw_data_json; (void)std::transform(raw_data.begin(), raw_data.end(), std::inserter(raw_data_json, raw_data_json.end()), [](const std::pair> &pair) { @@ -568,7 +714,7 @@ MSRStatus ShardWriter::WriteRawData(std::map> [](const py::handle &obj) { return nlohmann::detail::ToJsonImpl(obj); }); return std::make_pair(pair.first, std::move(json_raw_data)); }); - return WriteRawData(raw_data_json, blob_data, sign); + return WriteRawData(raw_data_json, blob_data, sign, parallel_writer); } MSRStatus ShardWriter::ParallelWriteData(const std::vector> &blob_data, diff --git a/mindspore/ccsrc/mindrecord/meta/shard_header.cc b/mindspore/ccsrc/mindrecord/meta/shard_header.cc index 57b2e5fa9e..26008e3ca9 100644 --- a/mindspore/ccsrc/mindrecord/meta/shard_header.cc +++ b/mindspore/ccsrc/mindrecord/meta/shard_header.cc @@ -677,5 +677,43 @@ std::pair, MSRStatus> ShardHeader::GetStatisticByID( } return std::make_pair(statistics_.at(statistic_id), SUCCESS); } + +MSRStatus ShardHeader::PagesToFile(const std::string dump_file_name) { + // write header content to file, dump whatever is in the file before + std::ofstream page_out_handle(dump_file_name.c_str(), std::ios_base::trunc | std::ios_base::out); + if (page_out_handle.fail()) { + MS_LOG(ERROR) << "Failed in opening page file"; + return FAILED; + } + + auto pages = SerializePage(); + for (const auto &shard_pages : pages) { + page_out_handle << shard_pages << "\n"; + } + + page_out_handle.close(); + return SUCCESS; +} + +MSRStatus ShardHeader::FileToPages(const std::string dump_file_name) { + for (auto &v : pages_) { // clean pages + v.clear(); + } + // attempt to open the file contains the page in json + std::ifstream page_in_handle(dump_file_name.c_str()); + + if (!page_in_handle.good()) { + MS_LOG(INFO) << "No page file exists."; + return SUCCESS; + } + + std::string line; + while (std::getline(page_in_handle, line)) { + ParsePage(json::parse(line)); + } + + page_in_handle.close(); + return SUCCESS; +} } // namespace mindrecord } // namespace mindspore diff --git a/mindspore/mindrecord/filewriter.py b/mindspore/mindrecord/filewriter.py index 90bca48038..62bcc2df79 100644 --- a/mindspore/mindrecord/filewriter.py +++ b/mindspore/mindrecord/filewriter.py @@ -200,13 +200,24 @@ class FileWriter: raw_data.pop(i) logger.warning(v) - def write_raw_data(self, raw_data): + def open_and_set_header(self): + """ + Open writer and set header + + """ + if not self._writer.is_open: + self._writer.open(self._paths) + if not self._writer.get_shard_header(): + self._writer.set_shard_header(self._header) + + def write_raw_data(self, raw_data, parallel_writer=False): """ Write raw data and generate sequential pair of MindRecord File and \ validate data based on predefined schema by default. Args: raw_data (list[dict]): List of raw data. + parallel_writer (bool, optional): Load data parallel if it equals to True (default=False). Raises: ParamTypeError: If index field is invalid. @@ -225,7 +236,7 @@ class FileWriter: if not isinstance(each_raw, dict): raise ParamTypeError('raw_data item', 'dict') self._verify_based_on_schema(raw_data) - return self._writer.write_raw_data(raw_data, True) + return self._writer.write_raw_data(raw_data, True, parallel_writer) def set_header_size(self, header_size): """ diff --git a/mindspore/mindrecord/shardwriter.py b/mindspore/mindrecord/shardwriter.py index 0ef23d4ce6..0913201861 100644 --- a/mindspore/mindrecord/shardwriter.py +++ b/mindspore/mindrecord/shardwriter.py @@ -135,7 +135,7 @@ class ShardWriter: def get_shard_header(self): return self._header - def write_raw_data(self, data, validate=True): + def write_raw_data(self, data, validate=True, parallel_writer=False): """ Write raw data of cv dataset. @@ -145,6 +145,7 @@ class ShardWriter: Args: data (list[dict]): List of raw data. validate (bool, optional): verify data according schema if it equals to True. + parallel_writer (bool, optional): Load data parallel if it equals to True. Returns: MSRStatus, SUCCESS or FAILED. @@ -165,7 +166,7 @@ class ShardWriter: if row_raw: raw_data.append(row_raw) raw_data = {0: raw_data} if raw_data else {} - ret = self._writer.write_raw_data(raw_data, blob_data, validate) + ret = self._writer.write_raw_data(raw_data, blob_data, validate, parallel_writer) if ret != ms.MSRStatus.SUCCESS: logger.error("Failed to write dataset.") raise MRMWriteDatasetError From 371b3338b6cb8789294e0e60a07efc9b06dadc39 Mon Sep 17 00:00:00 2001 From: yao_yf Date: Tue, 21 Apr 2020 20:22:01 +0800 Subject: [PATCH 063/142] support one node communication group --- mindspore/communication/_comm_helper.py | 4 ++-- tests/ut/python/communication/test_management_api.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mindspore/communication/_comm_helper.py b/mindspore/communication/_comm_helper.py index 099c8cfc2d..508aa2e7a9 100644 --- a/mindspore/communication/_comm_helper.py +++ b/mindspore/communication/_comm_helper.py @@ -334,8 +334,8 @@ def _create_group_helper(group, rank_ids, backend): if not isinstance(rank_ids, list): raise TypeError("Rank_ids {} should be list".format(rank_ids)) rank_size = len(rank_ids) - if rank_size < 2: - raise ValueError("Rank_ids size {} should be large than 1".format(rank_size)) + if rank_size < 1: + raise ValueError("Rank_ids size {} should be large than 0".format(rank_size)) if len(rank_ids) - len(list(set(rank_ids))) > 0: raise ValueError("List rank_ids in Group {} has duplicate data!".format(group)) hccl.create_group(group, rank_size, rank_ids) diff --git a/tests/ut/python/communication/test_management_api.py b/tests/ut/python/communication/test_management_api.py index c455c5491b..d624c5ab59 100644 --- a/tests/ut/python/communication/test_management_api.py +++ b/tests/ut/python/communication/test_management_api.py @@ -99,7 +99,7 @@ def test_raise_error_funcs(): assert has_raise_error(create_backend, 'nccl') is False assert has_raise_error(get_group_size_int, 123) is True assert has_raise_error(create_group0, (0,1)) is True - assert has_raise_error(create_group1, [0]) is True + assert has_raise_error(create_group1, [0]) is False assert has_raise_error(create_group2, [0,0,1]) is True assert has_raise_error(create_group3, [0,1]) is True assert has_raise_error(create_group4, [0,1]) is False From c2b3360d690035237a94f5780ab65ae8e74f7afe Mon Sep 17 00:00:00 2001 From: zhoufeng Date: Tue, 21 Apr 2020 21:21:19 +0800 Subject: [PATCH 064/142] update clang format rule --- .clang-format | 2 +- mindspore/ccsrc/common/utils.cc | 2 +- mindspore/ccsrc/common/utils.h | 12 +- .../ccsrc/dataset/kernels/image/decode_op.h | 8 +- .../image/distort_bounding_box_crop_op.cc | 10 +- .../image/distort_bounding_box_crop_op.h | 10 +- .../image/random_crop_and_resize_op.cc | 6 +- mindspore/ccsrc/debug/anf_ir_dump.h | 2 +- mindspore/ccsrc/debug/anf_ir_utils.cc | 216 +++++++------- mindspore/ccsrc/debug/anf_ir_utils.h | 66 ++--- mindspore/ccsrc/debug/draw.cc | 64 ++--- mindspore/ccsrc/debug/draw.h | 36 +-- mindspore/ccsrc/debug/dump_proto.cc | 170 +++++------ mindspore/ccsrc/debug/e2e_dump.cc | 20 +- mindspore/ccsrc/debug/e2e_dump.h | 14 +- mindspore/ccsrc/debug/info.cc | 18 +- mindspore/ccsrc/debug/info.h | 52 ++-- mindspore/ccsrc/debug/label.cc | 16 +- mindspore/ccsrc/debug/label.h | 2 +- mindspore/ccsrc/debug/trace.cc | 48 ++-- mindspore/ccsrc/debug/trace.h | 20 +- mindspore/ccsrc/debug/trace_info.cc | 2 +- mindspore/ccsrc/debug/trace_info.h | 82 +++--- .../ccsrc/device/ascend/ascend_memory_pool.cc | 6 +- .../ccsrc/device/ascend/ascend_memory_pool.h | 14 +- .../device/ascend/ascend_stream_assign.h | 64 ++--- .../device/ascend/profiling/plugin_impl.h | 8 +- .../ascend/profiling/profiling_engine_impl.cc | 4 +- .../ascend/profiling/profiling_engine_impl.h | 4 +- .../ascend/profiling/profiling_manager.cc | 16 +- mindspore/ccsrc/device/gpu/blocking_queue.h | 28 +- .../gpu/distribution/collective_init.cc | 6 +- .../ccsrc/device/gpu/gpu_device_manager.cc | 14 +- .../ccsrc/device/gpu/gpu_device_manager.h | 20 +- .../ccsrc/device/gpu/gpu_memory_allocator.cc | 6 +- .../ccsrc/device/gpu/gpu_memory_allocator.h | 12 +- .../ccsrc/device/gpu/kernel_info_setter.cc | 14 +- .../ccsrc/device/gpu/kernel_info_setter.h | 14 +- mindspore/ccsrc/gvar/typeid_manager.cc | 2 +- mindspore/ccsrc/ir/anf.cc | 26 +- mindspore/ccsrc/ir/base.h | 18 +- mindspore/ccsrc/ir/dtype.cc | 96 +++---- mindspore/ccsrc/ir/dtype.h | 44 +-- mindspore/ccsrc/ir/dtype/container.cc | 32 +-- mindspore/ccsrc/ir/dtype/container.h | 24 +- mindspore/ccsrc/ir/dtype/number.cc | 4 +- mindspore/ccsrc/ir/dtype/number.h | 4 +- mindspore/ccsrc/ir/dtype/ref.h | 2 +- mindspore/ccsrc/ir/dtype/type.cc | 26 +- mindspore/ccsrc/ir/dtype/type.h | 22 +- mindspore/ccsrc/ir/func_graph.cc | 142 ++++----- mindspore/ccsrc/ir/func_graph_cloner.cc | 170 +++++------ mindspore/ccsrc/ir/func_graph_cloner.h | 84 +++--- mindspore/ccsrc/ir/manager.cc | 230 +++++++-------- mindspore/ccsrc/ir/manager.h | 210 +++++++------- mindspore/ccsrc/ir/meta_func_graph.h | 20 +- mindspore/ccsrc/ir/meta_tensor.cc | 56 ++-- mindspore/ccsrc/ir/meta_tensor.h | 62 ++-- mindspore/ccsrc/ir/named.cc | 4 +- mindspore/ccsrc/ir/named.h | 14 +- mindspore/ccsrc/ir/primitive.cc | 26 +- mindspore/ccsrc/ir/primitive.h | 40 +-- mindspore/ccsrc/ir/scalar.h | 48 ++-- mindspore/ccsrc/ir/signature.cc | 8 +- mindspore/ccsrc/ir/signature.h | 6 +- mindspore/ccsrc/ir/value.cc | 122 ++++---- mindspore/ccsrc/ir/value.h | 90 +++--- mindspore/ccsrc/ir/visitor.h | 14 +- mindspore/ccsrc/kernel/kernel_query.cc | 8 +- mindspore/ccsrc/kernel/oplib/opinfo.h | 44 +-- mindspore/ccsrc/kernel/oplib/oplib.cc | 52 ++-- mindspore/ccsrc/kernel/oplib/oplib.h | 22 +- mindspore/ccsrc/mindspore.cc | 2 +- mindspore/ccsrc/onnx/onnx_exporter.cc | 272 +++++++++--------- .../ccsrc/operator/cc_implementations.cc | 24 +- mindspore/ccsrc/operator/cc_implementations.h | 44 +-- .../ccsrc/operator/composite/composite.cc | 168 +++++------ .../ccsrc/operator/composite/composite.h | 98 +++---- .../ccsrc/operator/composite/do_signature.cc | 32 +-- .../ccsrc/operator/composite/do_signature.h | 10 +- .../composite/list_append_operation.cc | 6 +- .../composite/list_append_operation.h | 8 +- .../ccsrc/operator/composite/unpack_call.cc | 8 +- .../ccsrc/operator/composite/unpack_call.h | 6 +- .../ccsrc/operator/composite/zip_operation.cc | 10 +- .../ccsrc/operator/composite/zip_operation.h | 8 +- mindspore/ccsrc/operator/ops.cc | 2 +- mindspore/ccsrc/operator/ops.h | 8 +- mindspore/ccsrc/operator/prim_to_function.cc | 4 +- mindspore/ccsrc/operator/prim_to_function.h | 10 +- mindspore/ccsrc/optimizer/ad/adjoint.cc | 14 +- mindspore/ccsrc/optimizer/ad/adjoint.h | 10 +- mindspore/ccsrc/optimizer/clean.cc | 84 +++--- mindspore/ccsrc/optimizer/control_depend.h | 2 +- .../optimizer/irpass/grad_var_prepare.cc | 10 +- mindspore/ccsrc/optimizer/opt.cc | 34 +-- .../allreduce_fusion/allreduce_fusion.cc | 58 ++-- .../allreduce_fusion/allreduce_fusion.h | 8 +- .../allreduce_fusion/allreduce_graph.cc | 26 +- .../allreduce_fusion/allreduce_graph.h | 10 +- .../allreduce_fusion/allreduce_node.cc | 16 +- .../allreduce_fusion/allreduce_node.h | 18 +- .../ccsrc/parallel/auto_parallel/costmodel.cc | 6 +- .../ccsrc/parallel/auto_parallel/costmodel.h | 12 +- .../auto_parallel/dp_algo_costmodel.cc | 4 +- .../auto_parallel/dp_algo_costmodel.h | 4 +- .../parallel/auto_parallel/edge_costmodel.cc | 60 ++-- .../parallel/auto_parallel/edge_costmodel.h | 54 ++-- .../parallel/auto_parallel/graph_costmodel.cc | 238 +++++++-------- .../parallel/auto_parallel/graph_costmodel.h | 74 ++--- .../auto_parallel/operator_costmodel.cc | 156 +++++----- .../auto_parallel/operator_costmodel.h | 234 +++++++-------- .../auto_parallel/rec_core/rec_parse_graph.cc | 18 +- .../auto_parallel/rec_core/rec_parse_graph.h | 16 +- mindspore/ccsrc/parallel/context.cc | 6 +- mindspore/ccsrc/parallel/context.h | 10 +- mindspore/ccsrc/parallel/costmodel_context.h | 4 +- mindspore/ccsrc/parallel/device_manager.cc | 32 +-- mindspore/ccsrc/parallel/device_manager.h | 18 +- mindspore/ccsrc/parallel/device_matrix.cc | 24 +- mindspore/ccsrc/parallel/device_matrix.h | 8 +- mindspore/ccsrc/parallel/dynamic_creator.h | 14 +- .../parallel/graph_util/generate_graph.cc | 20 +- .../parallel/graph_util/generate_graph.h | 14 +- .../parallel/graph_util/get_parallel_info.cc | 6 +- .../parallel/graph_util/get_parallel_info.h | 6 +- .../ccsrc/parallel/graph_util/graph_info.cc | 6 +- .../ccsrc/parallel/graph_util/graph_info.h | 4 +- .../ccsrc/parallel/graph_util/node_info.cc | 4 +- .../ccsrc/parallel/graph_util/node_info.h | 4 +- mindspore/ccsrc/parallel/group_manager.h | 16 +- mindspore/ccsrc/parallel/node_check.cc | 2 +- mindspore/ccsrc/parallel/node_check.h | 2 +- .../parallel/ops_info/activation_info.cc | 30 +- .../ccsrc/parallel/ops_info/activation_info.h | 80 +++--- .../ccsrc/parallel/ops_info/arithmetic_info.h | 46 +-- .../parallel/ops_info/batch_parallel_info.cc | 8 +- .../parallel/ops_info/batch_parallel_info.h | 20 +- .../ccsrc/parallel/ops_info/bias_add_info.h | 14 +- .../ops_info/comparison_function_info.h | 16 +- .../parallel/ops_info/dropout_do_mask_info.cc | 14 +- .../parallel/ops_info/dropout_do_mask_info.h | 14 +- .../ops_info/elementary_function_info.h | 14 +- .../ccsrc/parallel/ops_info/gather_v2_info.cc | 10 +- .../ccsrc/parallel/ops_info/gather_v2_info.h | 12 +- .../ccsrc/parallel/ops_info/get_next_info.cc | 16 +- .../parallel/ops_info/l2_normalize_info.cc | 4 +- .../parallel/ops_info/l2_normalize_info.h | 6 +- .../ccsrc/parallel/ops_info/layer_norm_info.h | 14 +- .../ccsrc/parallel/ops_info/loss_info.cc | 10 +- mindspore/ccsrc/parallel/ops_info/loss_info.h | 12 +- .../ccsrc/parallel/ops_info/matmul_info.cc | 26 +- .../ccsrc/parallel/ops_info/matmul_info.h | 30 +- .../ccsrc/parallel/ops_info/onehot_info.cc | 14 +- .../ccsrc/parallel/ops_info/onehot_info.h | 16 +- .../ccsrc/parallel/ops_info/operator_info.cc | 118 ++++---- .../ccsrc/parallel/ops_info/operator_info.h | 110 +++---- .../ccsrc/parallel/ops_info/prelu_info.cc | 12 +- .../ccsrc/parallel/ops_info/prelu_info.h | 14 +- .../ccsrc/parallel/ops_info/reshape_info.cc | 18 +- .../ccsrc/parallel/ops_info/reshape_info.h | 22 +- .../parallel/ops_info/tmp_identity_info.h | 12 +- .../ccsrc/parallel/ops_info/transpose_info.cc | 16 +- .../ccsrc/parallel/ops_info/transpose_info.h | 14 +- .../parallel/ops_info/virtual_dataset_info.cc | 10 +- .../parallel/ops_info/virtual_dataset_info.h | 12 +- mindspore/ccsrc/parallel/step_parallel.cc | 240 ++++++++-------- mindspore/ccsrc/parallel/step_parallel.h | 102 +++---- mindspore/ccsrc/parallel/strategy.h | 4 +- .../parallel_strategy_checkpoint.cc | 16 +- .../parallel_strategy_checkpoint.h | 10 +- .../parallel/tensor_layout/arrangement.cc | 16 +- .../parallel/tensor_layout/arrangement.h | 12 +- .../ccsrc/parallel/tensor_layout/array.cc | 6 +- .../ccsrc/parallel/tensor_layout/array.h | 4 +- .../tensor_layout/construct_operator.cc | 6 +- .../tensor_layout/construct_operator.h | 6 +- .../parallel/tensor_layout/layout_transfer.cc | 2 +- .../parallel/tensor_layout/layout_transfer.h | 2 +- mindspore/ccsrc/parallel/tensor_layout/map.cc | 12 +- mindspore/ccsrc/parallel/tensor_layout/map.h | 8 +- .../redistribution_operator_infer.cc | 6 +- .../redistribution_operator_infer.h | 2 +- .../tensor_layout/reshape_layout_transfer.cc | 2 +- .../tensor_layout/reshape_layout_transfer.h | 2 +- .../parallel/tensor_layout/shape_util.cc | 22 +- .../ccsrc/parallel/tensor_layout/shape_util.h | 22 +- .../parallel/tensor_layout/tensor_info.h | 6 +- .../parallel/tensor_layout/tensor_layout.cc | 24 +- .../parallel/tensor_layout/tensor_layout.h | 26 +- .../tensor_layout/tensor_redistribution.cc | 10 +- .../tensor_layout/tensor_redistribution.h | 6 +- mindspore/ccsrc/pipeline/action.cc | 64 ++--- mindspore/ccsrc/pipeline/action.h | 26 +- mindspore/ccsrc/pipeline/base.h | 4 +- mindspore/ccsrc/pipeline/init.cc | 6 +- .../ccsrc/pipeline/parse/data_converter.cc | 54 ++-- .../ccsrc/pipeline/parse/data_converter.h | 24 +- .../ccsrc/pipeline/parse/function_block.cc | 54 ++-- .../ccsrc/pipeline/parse/function_block.h | 44 +-- mindspore/ccsrc/pipeline/parse/parse_base.h | 10 +- .../ccsrc/pipeline/parse/python_adapter.cc | 8 +- .../ccsrc/pipeline/parse/python_adapter.h | 14 +- mindspore/ccsrc/pipeline/parse/resolve.cc | 28 +- mindspore/ccsrc/pipeline/parse/resolve.h | 24 +- mindspore/ccsrc/pipeline/pass.cc | 34 +-- mindspore/ccsrc/pipeline/pass.h | 10 +- mindspore/ccsrc/pipeline/pipeline.cc | 84 +++--- mindspore/ccsrc/pipeline/pipeline.h | 64 ++--- mindspore/ccsrc/pipeline/pipeline_ge.cc | 48 ++-- mindspore/ccsrc/pipeline/pipeline_ge.h | 20 +- .../ccsrc/pipeline/remove_value_node_dup.cc | 12 +- .../ccsrc/pipeline/remove_value_node_dup.h | 2 +- mindspore/ccsrc/pipeline/resource.cc | 14 +- mindspore/ccsrc/pipeline/resource.h | 26 +- .../ccsrc/pipeline/static_analysis/dshape.cc | 22 +- .../ccsrc/pipeline/static_analysis/dshape.h | 26 +- mindspore/ccsrc/pipeline/validator.cc | 10 +- mindspore/ccsrc/pipeline/validator.h | 6 +- .../mem_reuse/mem_dynamic_allocator.cc | 4 +- .../mem_reuse/mem_dynamic_allocator.h | 10 +- .../ccsrc/predict/generator/ir/ir_model.h | 2 +- mindspore/ccsrc/pybind_api/api_register.h | 16 +- mindspore/ccsrc/pynative/base.h | 2 +- mindspore/ccsrc/pynative/pynative_execute.cc | 20 +- mindspore/ccsrc/pynative/pynative_execute.h | 4 +- .../ccsrc/pynative/pynative_execute_ge.cc | 32 +-- .../ccsrc/pynative/pynative_execute_ge.h | 6 +- mindspore/ccsrc/transform/convert.h | 96 +++---- mindspore/ccsrc/transform/df_graph_manager.cc | 18 +- mindspore/ccsrc/transform/df_graph_manager.h | 18 +- mindspore/ccsrc/transform/graph_builder.cc | 4 +- mindspore/ccsrc/transform/graph_builder.h | 2 +- mindspore/ccsrc/transform/graph_runner.cc | 24 +- mindspore/ccsrc/transform/graph_runner.h | 10 +- mindspore/ccsrc/transform/op_adapter.h | 132 ++++----- mindspore/ccsrc/transform/op_adapter_base.h | 52 ++-- mindspore/ccsrc/transform/op_adapter_util.cc | 40 +-- mindspore/ccsrc/transform/op_adapter_util.h | 24 +- mindspore/ccsrc/transform/util.cc | 44 +-- mindspore/ccsrc/transform/util.h | 42 +-- mindspore/ccsrc/utils/any.cc | 10 +- mindspore/ccsrc/utils/any.h | 48 ++-- mindspore/ccsrc/utils/base_ref.cc | 34 +-- mindspore/ccsrc/utils/base_ref.h | 116 ++++---- mindspore/ccsrc/utils/callbacks.cc | 4 +- mindspore/ccsrc/utils/callbacks.h | 6 +- mindspore/ccsrc/utils/callbacks_ge.cc | 20 +- mindspore/ccsrc/utils/callbacks_ge.h | 4 +- mindspore/ccsrc/utils/config_manager.cc | 4 +- mindspore/ccsrc/utils/config_manager.h | 20 +- mindspore/ccsrc/utils/context/ms_context.cc | 18 +- mindspore/ccsrc/utils/context/ms_context.h | 26 +- mindspore/ccsrc/utils/counter.h | 12 +- mindspore/ccsrc/utils/graph_utils.cc | 102 +++---- mindspore/ccsrc/utils/graph_utils.h | 52 ++-- mindspore/ccsrc/utils/hashing.h | 2 +- mindspore/ccsrc/utils/misc.cc | 4 +- mindspore/ccsrc/utils/misc.h | 2 +- mindspore/ccsrc/utils/ordered_set.h | 86 +++--- mindspore/ccsrc/utils/profile.cc | 34 +-- mindspore/ccsrc/utils/profile.h | 72 ++--- mindspore/ccsrc/utils/signal.h | 12 +- mindspore/ccsrc/utils/symbolic.cc | 12 +- mindspore/ccsrc/utils/symbolic.h | 30 +- mindspore/ccsrc/utils/system/base.h | 8 +- mindspore/ccsrc/utils/system/crc32c.h | 4 +- mindspore/ccsrc/utils/system/file_system.cc | 10 +- mindspore/ccsrc/utils/system/file_system.h | 36 +-- mindspore/ccsrc/utils/utils.h | 2 +- mindspore/ccsrc/vm/segment_runner.cc | 18 +- mindspore/ccsrc/vm/segment_runner.h | 4 +- mindspore/ccsrc/vm/transform.cc | 90 +++--- mindspore/ccsrc/vm/transform.h | 66 ++--- mindspore/ccsrc/vm/vm.cc | 54 ++-- mindspore/ccsrc/vm/vm.h | 76 ++--- mindspore/ccsrc/vm/vmimpl.cc | 72 ++--- mindspore/ccsrc/vm/vmimpl.h | 68 ++--- 278 files changed, 4454 insertions(+), 4448 deletions(-) diff --git a/.clang-format b/.clang-format index 3b26784000..c6488cb358 100644 --- a/.clang-format +++ b/.clang-format @@ -94,7 +94,7 @@ PenaltyBreakString: 1000 PenaltyBreakTemplateDeclaration: 10 PenaltyExcessCharacter: 1000000 PenaltyReturnTypeOnItsOwnLine: 200 -PointerAlignment: Left +PointerAlignment: Right RawStringFormats: - Language: Cpp Delimiters: diff --git a/mindspore/ccsrc/common/utils.cc b/mindspore/ccsrc/common/utils.cc index 328a059113..7109c121e5 100644 --- a/mindspore/ccsrc/common/utils.cc +++ b/mindspore/ccsrc/common/utils.cc @@ -23,7 +23,7 @@ namespace common { const int CACHED_STR_NUM = 1 << 8; const int CACHED_STR_MASK = CACHED_STR_NUM - 1; std::vector STR_HOLDER(CACHED_STR_NUM); -const char* SafeCStr(const std::string&& str) { +const char *SafeCStr(const std::string &&str) { static std::atomic index{0}; uint32_t cur_index = index++; cur_index = cur_index & CACHED_STR_MASK; diff --git a/mindspore/ccsrc/common/utils.h b/mindspore/ccsrc/common/utils.h index 7cee933ac8..8f6e8f7c0c 100644 --- a/mindspore/ccsrc/common/utils.h +++ b/mindspore/ccsrc/common/utils.h @@ -21,16 +21,16 @@ #include #define DISABLE_COPY_AND_ASSIGN(ClassType) \ - ClassType(const ClassType&) = delete; \ - ClassType& operator=(const ClassType&) = delete; + ClassType(const ClassType &) = delete; \ + ClassType &operator=(const ClassType &) = delete; namespace mindspore { namespace common { -inline const char* SafeCStr(const std::string& str) { return str.c_str(); } -const char* SafeCStr(const std::string&& str); +inline const char *SafeCStr(const std::string &str) { return str.c_str(); } +const char *SafeCStr(const std::string &&str); -static inline std::string GetEnv(const std::string& envvar) { - const char* value = ::getenv(envvar.c_str()); +static inline std::string GetEnv(const std::string &envvar) { + const char *value = ::getenv(envvar.c_str()); if (value == nullptr) { return std::string(); diff --git a/mindspore/ccsrc/dataset/kernels/image/decode_op.h b/mindspore/ccsrc/dataset/kernels/image/decode_op.h index 50d2d3cb68..6e7180958a 100644 --- a/mindspore/ccsrc/dataset/kernels/image/decode_op.h +++ b/mindspore/ccsrc/dataset/kernels/image/decode_op.h @@ -34,11 +34,11 @@ class DecodeOp : public TensorOp { ~DecodeOp() = default; - Status Compute(const std::shared_ptr& input, std::shared_ptr* output) override; + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; - void Print(std::ostream& out) const override { out << "DecodeOp"; } - Status OutputShape(const std::vector& inputs, std::vector& outputs) override; - Status OutputType(const std::vector& inputs, std::vector& outputs) override; + void Print(std::ostream &out) const override { out << "DecodeOp"; } + Status OutputShape(const std::vector &inputs, std::vector &outputs) override; + Status OutputType(const std::vector &inputs, std::vector &outputs) override; private: bool is_rgb_format_ = true; diff --git a/mindspore/ccsrc/dataset/kernels/image/distort_bounding_box_crop_op.cc b/mindspore/ccsrc/dataset/kernels/image/distort_bounding_box_crop_op.cc index e7a8cc3496..a28f2bb6fd 100644 --- a/mindspore/ccsrc/dataset/kernels/image/distort_bounding_box_crop_op.cc +++ b/mindspore/ccsrc/dataset/kernels/image/distort_bounding_box_crop_op.cc @@ -37,8 +37,8 @@ DistortBoundingBoxCropOp::DistortBoundingBoxCropOp(float aspect_ratio, float int rnd_.seed(seed_); } -Status DistortBoundingBoxCropOp::Compute(const std::vector>& input, - std::vector>* output) { +Status DistortBoundingBoxCropOp::Compute(const std::vector> &input, + std::vector> *output) { IO_CHECK_VECTOR(input, output); if (input.size() != NumInput()) return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Number of inputs is not 5"); @@ -98,8 +98,8 @@ Status DistortBoundingBoxCropOp::Compute(const std::vector& inputs, - std::vector& outputs) { +Status DistortBoundingBoxCropOp::OutputShape(const std::vector &inputs, + std::vector &outputs) { RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); outputs.clear(); TensorShape out = TensorShape{-1, -1}; @@ -108,7 +108,7 @@ Status DistortBoundingBoxCropOp::OutputShape(const std::vector& inp if (!outputs.empty()) return Status::OK(); return Status(StatusCode::kUnexpectedError, "Input has a wrong shape"); } -Status DistortBoundingBoxCropOp::OutputType(const std::vector& inputs, std::vector& outputs) { +Status DistortBoundingBoxCropOp::OutputType(const std::vector &inputs, std::vector &outputs) { RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs)); outputs[0] = inputs[0]; return Status::OK(); diff --git a/mindspore/ccsrc/dataset/kernels/image/distort_bounding_box_crop_op.h b/mindspore/ccsrc/dataset/kernels/image/distort_bounding_box_crop_op.h index 6d5dca99fb..749c166d59 100644 --- a/mindspore/ccsrc/dataset/kernels/image/distort_bounding_box_crop_op.h +++ b/mindspore/ccsrc/dataset/kernels/image/distort_bounding_box_crop_op.h @@ -45,16 +45,16 @@ class DistortBoundingBoxCropOp : public TensorOp { ~DistortBoundingBoxCropOp() override = default; - void Print(std::ostream& out) const override { + void Print(std::ostream &out) const override { out << "DistortBoundingBoxCropOp: " << max_attempts_ << " " << intersect_ratio_; } - Status Compute(const std::vector>& input, - std::vector>* output) override; + Status Compute(const std::vector> &input, + std::vector> *output) override; uint32_t NumInput() override { return 5; } - Status OutputShape(const std::vector& inputs, std::vector& outputs) override; - Status OutputType(const std::vector& inputs, std::vector& outputs) override; + Status OutputShape(const std::vector &inputs, std::vector &outputs) override; + Status OutputType(const std::vector &inputs, std::vector &outputs) override; private: int32_t max_attempts_; diff --git a/mindspore/ccsrc/dataset/kernels/image/random_crop_and_resize_op.cc b/mindspore/ccsrc/dataset/kernels/image/random_crop_and_resize_op.cc index 3cf6065659..a3cf8cefb5 100644 --- a/mindspore/ccsrc/dataset/kernels/image/random_crop_and_resize_op.cc +++ b/mindspore/ccsrc/dataset/kernels/image/random_crop_and_resize_op.cc @@ -41,7 +41,7 @@ RandomCropAndResizeOp::RandomCropAndResizeOp(int32_t target_height, int32_t targ rnd_.seed(GetSeed()); } -Status RandomCropAndResizeOp::Compute(const std::shared_ptr& input, std::shared_ptr* output) { +Status RandomCropAndResizeOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { IO_CHECK(input, output); CHECK_FAIL_RETURN_UNEXPECTED(input->shape().Size() >= 2, "The shape of input is abnormal"); @@ -54,7 +54,7 @@ Status RandomCropAndResizeOp::Compute(const std::shared_ptr& input, std: (void)GetCropBox(h_in, w_in, &x, &y, &crop_height, &crop_width); return CropAndResize(input, output, x, y, crop_height, crop_width, target_height_, target_width_, interpolation_); } -Status RandomCropAndResizeOp::OutputShape(const std::vector& inputs, std::vector& outputs) { +Status RandomCropAndResizeOp::OutputShape(const std::vector &inputs, std::vector &outputs) { RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); outputs.clear(); TensorShape out = TensorShape{target_height_, target_width_}; @@ -63,7 +63,7 @@ Status RandomCropAndResizeOp::OutputShape(const std::vector& inputs if (!outputs.empty()) return Status::OK(); return Status(StatusCode::kUnexpectedError, "Input has a wrong shape"); } -Status RandomCropAndResizeOp::GetCropBox(int h_in, int w_in, int* x, int* y, int* crop_height, int* crop_width) { +Status RandomCropAndResizeOp::GetCropBox(int h_in, int w_in, int *x, int *y, int *crop_height, int *crop_width) { double scale, aspect; *crop_width = w_in; *crop_height = h_in; diff --git a/mindspore/ccsrc/debug/anf_ir_dump.h b/mindspore/ccsrc/debug/anf_ir_dump.h index 5c4bc5eacd..a53888348d 100644 --- a/mindspore/ccsrc/debug/anf_ir_dump.h +++ b/mindspore/ccsrc/debug/anf_ir_dump.h @@ -22,7 +22,7 @@ namespace mindspore { constexpr char PARALLEL_STRATEGY[] = "strategy"; -void DumpIR(const std::string& filename, const FuncGraphPtr& func_graph, bool dump_full_name = false); +void DumpIR(const std::string &filename, const FuncGraphPtr &func_graph, bool dump_full_name = false); } // namespace mindspore diff --git a/mindspore/ccsrc/debug/anf_ir_utils.cc b/mindspore/ccsrc/debug/anf_ir_utils.cc index 8e626d6f9a..6ebe3ad43f 100644 --- a/mindspore/ccsrc/debug/anf_ir_utils.cc +++ b/mindspore/ccsrc/debug/anf_ir_utils.cc @@ -44,7 +44,7 @@ const int NUM_MAX_SEQUENCE_ELEMS = 0x00FFFFFF; // get MindSpore Intermediate Representation Path std::string GetMsIrPath(void) { std::string path; - const char* path_ptr = getenv("MS_IR_PATH"); + const char *path_ptr = getenv("MS_IR_PATH"); if (path_ptr != nullptr) { path = path_ptr; char real_path[PATH_MAX] = {0}; @@ -62,13 +62,13 @@ std::string GetMsIrPath(void) { return path; } -std::string dump_obj(const py::object& obj, const std::string& path) { +std::string dump_obj(const py::object &obj, const std::string &path) { py::module mod = parse::python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE); py::object name = parse::python_adapter::CallPyModFn(mod, "dump_obj", obj, py::str(path)); return py::str(name); } -py::object load_obj(const std::string& path) { +py::object load_obj(const std::string &path) { py::module mod = parse::python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE); py::object obj = parse::python_adapter::CallPyModFn(mod, "load_obj", py::str(path)); return obj; @@ -76,7 +76,7 @@ py::object load_obj(const std::string& path) { // ============================================= MindSpore IR Exporter ============================================= -std::string AnfExporter::GetNodeType(const AnfNodePtr& nd) { +std::string AnfExporter::GetNodeType(const AnfNodePtr &nd) { abstract::ShapePtr shape = nd->Shape() == nullptr ? nullptr : dyn_cast(nd->Shape()); TypePtr type = dyn_cast(nd->Type()); std::ostringstream oss; @@ -90,7 +90,7 @@ std::string AnfExporter::GetNodeType(const AnfNodePtr& nd) { return oss.str(); } -std::string AnfExporter::DumpObject(const py::object& obj, const std::string& category) const { +std::string AnfExporter::DumpObject(const py::object &obj, const std::string &category) const { std::string pkl_path = GetMsIrPath(); // if not specified env 'MS_IR_PATH', do not create any files if (pkl_path.empty() || (getenv("MS_IR_FILE") != nullptr)) { @@ -101,7 +101,7 @@ std::string AnfExporter::DumpObject(const py::object& obj, const std::string& ca return file_prefix + file_name; } -int AnfExporter::GetParamIndex(const FuncGraphPtr& func_graph, const AnfNodePtr& param, bool throw_excp) { +int AnfExporter::GetParamIndex(const FuncGraphPtr &func_graph, const AnfNodePtr ¶m, bool throw_excp) { if (func_graph == nullptr || param == nullptr) { return -1; } @@ -129,13 +129,13 @@ int AnfExporter::GetParamIndex(const FuncGraphPtr& func_graph, const AnfNodePtr& // try to find index of parameter for SymbolicKeyInstance from all exported graphs // NOTICE: Suppose name of all parameters in SymbolicKeyInstance are different -int AnfExporter::GetParamIndexFromExported(const AnfNodePtr& param) { +int AnfExporter::GetParamIndexFromExported(const AnfNodePtr ¶m) { if (param == nullptr) { return -1; } int ret = -1; - for (const auto& item : exported) { + for (const auto &item : exported) { auto pram_iter = item.second.find(param); if (pram_iter != item.second.end()) { return pram_iter->second; @@ -144,12 +144,12 @@ int AnfExporter::GetParamIndexFromExported(const AnfNodePtr& param) { return ret; } -std::string AnfExporter::GetValueNodeText(const FuncGraphPtr& fg, const ValueNodePtr& node) { +std::string AnfExporter::GetValueNodeText(const FuncGraphPtr &fg, const ValueNodePtr &node) { MS_EXCEPTION_IF_NULL(node); return GetValueText(fg, node->value()); } -std::string AnfExporter::GetMultitypeFuncGraphText(const prim::MultitypeFuncGraphPtr& mt_func_graph) { +std::string AnfExporter::GetMultitypeFuncGraphText(const prim::MultitypeFuncGraphPtr &mt_func_graph) { auto py_funcs = mt_func_graph->GetPyFunctions(); if (py_funcs.empty()) { return ""; @@ -159,7 +159,7 @@ std::string AnfExporter::GetMultitypeFuncGraphText(const prim::MultitypeFuncGrap oss << "{"; bool is_first = true; - for (const auto& py_func : py_funcs) { + for (const auto &py_func : py_funcs) { if (is_first) { is_first = false; } else { @@ -193,7 +193,7 @@ std::string AnfExporter::GetMultitypeFuncGraphText(const prim::MultitypeFuncGrap * ├── GradOperation * └── TupleAdd */ -std::string AnfExporter::GetMetaFuncGraphText(const MetaFuncGraphPtr& meta_func_graph) { +std::string AnfExporter::GetMetaFuncGraphText(const MetaFuncGraphPtr &meta_func_graph) { if (meta_func_graph == nullptr) { return ""; } @@ -244,7 +244,7 @@ std::string AnfExporter::GetMetaFuncGraphText(const MetaFuncGraphPtr& meta_func_ return oss.str(); } -std::string AnfExporter::GetPrimitiveText(const PrimitivePtr& prim) { +std::string AnfExporter::GetPrimitiveText(const PrimitivePtr &prim) { std::ostringstream oss; if (prim == nullptr) { return oss.str(); @@ -266,7 +266,7 @@ std::string AnfExporter::GetPrimitiveText(const PrimitivePtr& prim) { if (prim->isa()) { auto do_signature = dyn_cast(prim); - auto& func = do_signature->function(); + auto &func = do_signature->function(); if (func->isa()) { auto sig_prim = dyn_cast(func); oss << sig_prim->GetAttrsText(); @@ -276,7 +276,7 @@ std::string AnfExporter::GetPrimitiveText(const PrimitivePtr& prim) { return oss.str(); } -std::string AnfExporter::GetNameSpaceText(const parse::NameSpacePtr& ns) { +std::string AnfExporter::GetNameSpaceText(const parse::NameSpacePtr &ns) { std::ostringstream oss; if (ns == nullptr) { return oss.str(); @@ -288,8 +288,8 @@ std::string AnfExporter::GetNameSpaceText(const parse::NameSpacePtr& ns) { return oss.str(); } -std::string AnfExporter::GetSymbolicKeyInstanceText(const FuncGraphPtr& func_graph, - const SymbolicKeyInstancePtr& sym_inst) { +std::string AnfExporter::GetSymbolicKeyInstanceText(const FuncGraphPtr &func_graph, + const SymbolicKeyInstancePtr &sym_inst) { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(sym_inst); AnfNodePtr sym_node = sym_inst->node(); @@ -317,7 +317,7 @@ std::string AnfExporter::GetSymbolicKeyInstanceText(const FuncGraphPtr& func_gra return oss.str(); } -std::string AnfExporter::GetSequenceText(const FuncGraphPtr& func_graph, const ValuePtr& value) { +std::string AnfExporter::GetSequenceText(const FuncGraphPtr &func_graph, const ValuePtr &value) { std::ostringstream oss; // output ValueList, ValueTuple ValueSequeuePtr seq = dyn_cast(value); @@ -338,12 +338,12 @@ std::string AnfExporter::GetSequenceText(const FuncGraphPtr& func_graph, const V return oss.str(); } -std::string AnfExporter::GetDictText(const FuncGraphPtr& func_graph, const ValuePtr& value) { +std::string AnfExporter::GetDictText(const FuncGraphPtr &func_graph, const ValuePtr &value) { std::ostringstream oss; ValueDictionaryPtr dict = value->cast(); oss << "{"; bool first_flag = true; - for (const auto& elem : dict->value()) { + for (const auto &elem : dict->value()) { if (first_flag) { first_flag = false; } else { @@ -355,7 +355,7 @@ std::string AnfExporter::GetDictText(const FuncGraphPtr& func_graph, const Value return oss.str(); } -std::string AnfExporter::GetOtherValueText(const FuncGraphPtr&, const ValuePtr& value) { +std::string AnfExporter::GetOtherValueText(const FuncGraphPtr &, const ValuePtr &value) { std::ostringstream oss; if (check_integrity_) { @@ -366,7 +366,7 @@ std::string AnfExporter::GetOtherValueText(const FuncGraphPtr&, const ValuePtr& return oss.str(); } -std::string AnfExporter::GetValueText(const FuncGraphPtr& func_graph, const ValuePtr& value) { +std::string AnfExporter::GetValueText(const FuncGraphPtr &func_graph, const ValuePtr &value) { std::ostringstream oss; bool is_null_ptr = (func_graph == nullptr || value == nullptr); if (is_null_ptr) { @@ -413,8 +413,8 @@ std::string AnfExporter::GetValueText(const FuncGraphPtr& func_graph, const Valu } // this function is used to output node in CNode's inputs -std::string AnfExporter::GetAnfNodeText(const FuncGraphPtr& func_graph, const AnfNodePtr& node, - const std::map& apply_map) { +std::string AnfExporter::GetAnfNodeText(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const std::map &apply_map) { std::ostringstream oss; if (func_graph == nullptr || node == nullptr) { return oss.str(); @@ -444,10 +444,10 @@ std::string AnfExporter::GetAnfNodeText(const FuncGraphPtr& func_graph, const An return oss.str(); } -void AnfExporter::OutputParameters(std::ofstream& ofs, const std::vector& parameters, - OrderedMap* param_map) { +void AnfExporter::OutputParameters(std::ofstream &ofs, const std::vector ¶meters, + OrderedMap *param_map) { bool first_flag = true; - for (const AnfNodePtr& param : parameters) { + for (const AnfNodePtr ¶m : parameters) { if (first_flag) { first_flag = false; ofs << " "; @@ -479,13 +479,13 @@ void AnfExporter::OutputParameters(std::ofstream& ofs, const std::vectorinputs(); + auto &inputs = node->inputs(); if (inputs.size() > 1) { ofs << " #("; for (size_t i = 1; i < inputs.size(); ++i) { @@ -521,15 +521,15 @@ void AnfExporter::OutputStatementComment(std::ofstream& ofs, const CNodePtr& nod ofs << " #scope: " << node->scope()->name(); } -void AnfExporter::OutputCNodes(std::ofstream& ofs, const std::vector& nodes, - const FuncGraphPtr& func_graph) { +void AnfExporter::OutputCNodes(std::ofstream &ofs, const std::vector &nodes, + const FuncGraphPtr &func_graph) { if (func_graph == nullptr) { return; } int idx = 1; std::map apply_map; - for (const AnfNodePtr& node : nodes) { + for (const AnfNodePtr &node : nodes) { MS_EXCEPTION_IF_NULL(node); if (!node->isa()) { continue; @@ -541,7 +541,7 @@ void AnfExporter::OutputCNodes(std::ofstream& ofs, const std::vector } auto cnode = node->cast(); - auto& inputs = cnode->inputs(); + auto &inputs = cnode->inputs(); std::string op_text = GetAnfNodeText(func_graph, inputs[0], apply_map); // non-return node if (node != func_graph->get_return()) { @@ -578,7 +578,7 @@ void AnfExporter::OutputCNodes(std::ofstream& ofs, const std::vector } } -void AnfExporter::ExportOneFuncGraph(std::ofstream& ofs, const FuncGraphPtr& func_graph) { +void AnfExporter::ExportOneFuncGraph(std::ofstream &ofs, const FuncGraphPtr &func_graph) { if (func_graph == nullptr) { return; } @@ -612,7 +612,7 @@ void AnfExporter::ExportOneFuncGraph(std::ofstream& ofs, const FuncGraphPtr& fun ofs << "}\n"; } -void AnfExporter::ExportFuncGraph(const std::string& filename, const FuncGraphPtr& func_graph) { +void AnfExporter::ExportFuncGraph(const std::string &filename, const FuncGraphPtr &func_graph) { if (func_graph == nullptr) { return; } @@ -637,7 +637,7 @@ void AnfExporter::ExportFuncGraph(const std::string& filename, const FuncGraphPt ofs.close(); } -void AnfExporter::ExportFuncGraph(const std::string& filename, const std::vector& graphs) { +void AnfExporter::ExportFuncGraph(const std::string &filename, const std::vector &graphs) { if (graphs.empty()) { return; } @@ -650,7 +650,7 @@ void AnfExporter::ExportFuncGraph(const std::string& filename, const std::vector param_index = 1; - for (const auto& tagged_graph : graphs) { + for (const auto &tagged_graph : graphs) { tagged_cnodes_ = tagged_graph.second; ExportOneFuncGraph(ofs, tagged_graph.first); tagged_cnodes_.clear(); @@ -663,7 +663,7 @@ void AnfExporter::ExportFuncGraph(const std::string& filename, const std::vector } #ifdef ENABLE_DUMP_IR -void ExportIR(const std::string& filename, const std::string& id, const FuncGraphPtr& func_graph) { +void ExportIR(const std::string &filename, const std::string &id, const FuncGraphPtr &func_graph) { if (func_graph == nullptr) { return; } @@ -675,7 +675,7 @@ void ExportIR(const std::string& filename, const std::string& id, const FuncGrap ChangeFileMode(filename, S_IRUSR); } -void ExportIR(const std::string& filename, const std::vector& graphs) { +void ExportIR(const std::string &filename, const std::vector &graphs) { AnfExporter exporter("", false); ChangeFileMode(filename, S_IRWXU); exporter.ExportFuncGraph(filename, graphs); @@ -683,7 +683,7 @@ void ExportIR(const std::string& filename, const std::vector& graph ChangeFileMode(filename, S_IRUSR); } #else -void ExportIR(const std::string&, const std::string&, const FuncGraphPtr&) { +void ExportIR(const std::string &, const std::string &, const FuncGraphPtr &) { static bool already_printed = false; if (already_printed) { return; @@ -693,7 +693,7 @@ void ExportIR(const std::string&, const std::string&, const FuncGraphPtr&) { << "please recompile source to enable it. See help of building script."; } -void ExportIR(const std::string& filename, const std::vector& graphs) { +void ExportIR(const std::string &filename, const std::vector &graphs) { static bool already_printed = false; if (already_printed) { return; @@ -732,7 +732,7 @@ enum Token : int { TOK_ERROR // file read error }; -std::map token_text = { +std::map token_text = { {TOK_INVALID, "invalid"}, // invalid token {TOK_LPARENTHESIS, "("}, // ( left parenthesis {TOK_RPARENTHESIS, ")"}, // ) right parenthesis @@ -761,14 +761,14 @@ std::map token_text = { class Lexer { public: // filename is checked in ImportIR; - explicit Lexer(const char* filename) : fin(filename) {} + explicit Lexer(const char *filename) : fin(filename) {} ~Lexer() { try { if (fin.is_open()) { fin.close(); } - } catch (const std::exception& e) { + } catch (const std::exception &e) { MS_LOG(ERROR) << "Exception when closing file"; } catch (...) { std::string exName(abi::__cxa_current_exception_type()->name()); @@ -776,7 +776,7 @@ class Lexer { } } - bool IsSingleCharToken(char ch, Token* token_ptr) { + bool IsSingleCharToken(char ch, Token *token_ptr) { // clang-format off std::unordered_map char_to_token = { {'(', TOK_LPARENTHESIS}, @@ -806,7 +806,7 @@ class Lexer { Token GetNextToken() { #ifdef DEBUG Token token = GetNextTokenInner(); - const char* str = token_text[token]; + const char *str = token_text[token]; std::string text = (str == nullptr ? GetTokenText() : str); MS_LOG(DEBUG) << "------Parse token] " << text; return token; @@ -1064,11 +1064,11 @@ const unsigned Lexer::BUF_SIZE; class IrParser { public: - explicit IrParser(const char* filename) : lexer_(filename) {} + explicit IrParser(const char *filename) : lexer_(filename) {} ~IrParser() {} - py::object LoadObject(const std::string& file_name) const { + py::object LoadObject(const std::string &file_name) const { std::string pkl_path = GetMsIrPath(); py::object default_obj = load_obj(pkl_path + "/" + file_name); return default_obj; @@ -1087,7 +1087,7 @@ class IrParser { MS_LOG(INFO) << "Total graphs: " << func_graphs_.size(); } - Token ParseParent(FuncGraphPtr* const parent_ptr) { + Token ParseParent(FuncGraphPtr *const parent_ptr) { if (lexer_.GetNextToken() != TOK_IDENTIFIER) { return TOK_ERROR; } @@ -1168,7 +1168,7 @@ class IrParser { return func_graph; } - FuncGraphPtr ParseStatements(const FuncGraphPtr& func_graph) { + FuncGraphPtr ParseStatements(const FuncGraphPtr &func_graph) { Token tok = lexer_.SkipWhiteToken(); while (tok == TOK_VARIABLE) { if (ParseStatement(func_graph) == nullptr) { @@ -1264,56 +1264,56 @@ class IrParser { return func_graph; } - void SetBasicType(TypePtr* ptr, const TypePtr& dtype) const { + void SetBasicType(TypePtr *ptr, const TypePtr &dtype) const { if (ptr == nullptr) { return; } *ptr = dtype; } - void SetTupleType(TypePtr* ptr) { + void SetTupleType(TypePtr *ptr) { if (ptr == nullptr) { return; } *ptr = std::make_shared(); } - void SetTupleType(TypePtr* ptr, const TypePtrList& elems) { + void SetTupleType(TypePtr *ptr, const TypePtrList &elems) { if (ptr == nullptr) { return; } *ptr = std::make_shared(elems); } - void SetArrayType(TypePtr* const ptr, const TypePtr& elem_type, const std::vector&) { + void SetArrayType(TypePtr *const ptr, const TypePtr &elem_type, const std::vector &) { if (ptr == nullptr) { return; } *ptr = std::make_shared(elem_type); } - void SetListType(TypePtr* ptr) { + void SetListType(TypePtr *ptr) { if (ptr == nullptr) { return; } *ptr = std::make_shared(); } - void SetListType(TypePtr* ptr, const TypePtrList& elems) { + void SetListType(TypePtr *ptr, const TypePtrList &elems) { if (ptr == nullptr) { return; } *ptr = std::make_shared(elems); } - void SetJTaggedType(TypePtr* ptr, const TypePtr& elem) { + void SetJTaggedType(TypePtr *ptr, const TypePtr &elem) { if (ptr == nullptr) { return; } *ptr = std::make_shared(elem); } - void SetBasicType(AbstractBasePtr* ptr, const TypePtr& dtype) const { + void SetBasicType(AbstractBasePtr *ptr, const TypePtr &dtype) const { if (ptr == nullptr) { return; } @@ -1321,45 +1321,45 @@ class IrParser { } // void SetBasicType(AbstractBasePtr *ptr, const SymbolicKeyTypePtr& dtype) {} - void SetBasicType(AbstractBasePtr* const ptr, const TypeNonePtr&) const { + void SetBasicType(AbstractBasePtr *const ptr, const TypeNonePtr &) const { if (ptr == nullptr) { return; } *ptr = std::make_shared(); } - void SetBasicType(AbstractBasePtr*, const FunctionPtr&) const {} - void SetBasicType(AbstractBasePtr*, const TensorTypePtr&) const {} + void SetBasicType(AbstractBasePtr *, const FunctionPtr &) const {} + void SetBasicType(AbstractBasePtr *, const TensorTypePtr &) const {} - void SetTupleType(AbstractBasePtr* const ptr, const AbstractBasePtrList& elems) { + void SetTupleType(AbstractBasePtr *const ptr, const AbstractBasePtrList &elems) { if (ptr == nullptr) { return; } // if one of elems is nullptr, just return - if (std::any_of(std::begin(elems), std::end(elems), [](const AbstractBasePtr& elem) { return elem == nullptr; })) { + if (std::any_of(std::begin(elems), std::end(elems), [](const AbstractBasePtr &elem) { return elem == nullptr; })) { return; } *ptr = std::make_shared(elems); } - void SetArrayType(AbstractBasePtr* const ptr, const TypePtr& elem_type, const std::vector& shape) { + void SetArrayType(AbstractBasePtr *const ptr, const TypePtr &elem_type, const std::vector &shape) { if (ptr == nullptr) { return; } *ptr = std::make_shared(elem_type, shape); } - void SetListType(AbstractBasePtr* const ptr, const AbstractBasePtrList& elems) { + void SetListType(AbstractBasePtr *const ptr, const AbstractBasePtrList &elems) { if (ptr == nullptr) { return; } - if (std::any_of(std::begin(elems), std::end(elems), [](const AbstractBasePtr& elem) { return elem == nullptr; })) { + if (std::any_of(std::begin(elems), std::end(elems), [](const AbstractBasePtr &elem) { return elem == nullptr; })) { return; } *ptr = std::make_shared(elems); } - void SetJTaggedType(AbstractBasePtr* const ptr, const AbstractBasePtr& elem) { + void SetJTaggedType(AbstractBasePtr *const ptr, const AbstractBasePtr &elem) { if (ptr == nullptr) { return; } @@ -1367,7 +1367,7 @@ class IrParser { } template - Token ParseTypeVector(const FuncGraphPtr& func_graph, Token tok, const std::string& type, T* const ptr = nullptr) { + Token ParseTypeVector(const FuncGraphPtr &func_graph, Token tok, const std::string &type, T *const ptr = nullptr) { if (tok != TOK_LBRACKET) { MS_LOG(EXCEPTION) << "Illegal case, , wrong token start symbol."; return tok; @@ -1415,7 +1415,7 @@ class IrParser { } template - Token ParseTypeArray(const FuncGraphPtr& func_graph, Token tok, T* const ptr = nullptr) { + Token ParseTypeArray(const FuncGraphPtr &func_graph, Token tok, T *const ptr = nullptr) { if (tok != TOK_LPARENTHESIS) { if (ptr != nullptr) { SetBasicType(ptr, std::make_shared()); @@ -1454,7 +1454,7 @@ class IrParser { return lexer_.GetNextToken(); } - bool IsNumberType(const std::string& type, TypeId* typeid_ptr) { + bool IsNumberType(const std::string &type, TypeId *typeid_ptr) { // clang-format off static std::unordered_map basic_types = { {"Bool", kNumberTypeBool}, @@ -1486,7 +1486,7 @@ class IrParser { } template - void ParseNumberType(const std::string& type, TypeId typeId, T* const ptr = nullptr) { + void ParseNumberType(const std::string &type, TypeId typeId, T *const ptr = nullptr) { TypePtr dtype = nullptr; std::unordered_map type_map = { @@ -1519,7 +1519,7 @@ class IrParser { } template - Token ParseTrivalType(const std::string& type, T* const ptr = nullptr) { + Token ParseTrivalType(const std::string &type, T *const ptr = nullptr) { if (type == "NoneType") { SetBasicType(ptr, std::make_shared()); return lexer_.GetNextToken(); @@ -1541,7 +1541,7 @@ class IrParser { } template - Token ParseOneType(const FuncGraphPtr& func_graph, Token tok, T* const ptr = nullptr) { + Token ParseOneType(const FuncGraphPtr &func_graph, Token tok, T *const ptr = nullptr) { if (tok != TOK_IDENTIFIER) { return TOK_ERROR; } @@ -1588,11 +1588,11 @@ class IrParser { } } - Token ParseType(const FuncGraphPtr& func_graph, AbstractBasePtr* const abstract = nullptr) { + Token ParseType(const FuncGraphPtr &func_graph, AbstractBasePtr *const abstract = nullptr) { return ParseOneType(func_graph, lexer_.GetNextToken(), abstract); } - Token ParseAttributes(const FuncGraphPtr& func_graph, const PrimitivePtr& prim) { + Token ParseAttributes(const FuncGraphPtr &func_graph, const PrimitivePtr &prim) { Token tok = ParseAttribute(func_graph, prim); while (tok == TOK_COMMA) { tok = ParseAttribute(func_graph, prim); @@ -1603,7 +1603,7 @@ class IrParser { return lexer_.GetNextToken(); } - Token ParseAttribute(const FuncGraphPtr& func_graph, const PrimitivePtr& prim) { + Token ParseAttribute(const FuncGraphPtr &func_graph, const PrimitivePtr &prim) { Token tok = lexer_.GetNextToken(); if (tok != TOK_IDENTIFIER) { return TOK_ERROR; @@ -1670,7 +1670,7 @@ class IrParser { return tok == TOK_RPARENTHESIS ? func_graph : nullptr; } - FuncGraphPtr ParseArguments(FuncGraphPtr func_graph, std::vector* const inputs_ptr) { + FuncGraphPtr ParseArguments(FuncGraphPtr func_graph, std::vector *const inputs_ptr) { Token tok = ParseArgument(func_graph, inputs_ptr); while (tok == TOK_COMMA) { tok = ParseArgument(func_graph, inputs_ptr); @@ -1681,9 +1681,9 @@ class IrParser { return func_graph; } - AnfNodePtr FindParameter(FuncGraphPtr func_graph, const std::string& param_name) { + AnfNodePtr FindParameter(FuncGraphPtr func_graph, const std::string ¶m_name) { while (func_graph != nullptr) { - for (auto& ptr : func_graph->parameters()) { + for (auto &ptr : func_graph->parameters()) { MS_EXCEPTION_IF_NULL(ptr); ParameterPtr param = ptr->cast(); MS_EXCEPTION_IF_NULL(param); @@ -1701,12 +1701,12 @@ class IrParser { return nullptr; } - bool Match(const std::string& str, const std::string& pattern) const { + bool Match(const std::string &str, const std::string &pattern) const { return strncmp(str.c_str(), pattern.c_str(), pattern.length()) == 0; } template - Token ParseScalar(ValuePtr* const val_ptr) { + Token ParseScalar(ValuePtr *const val_ptr) { if (lexer_.GetNextToken() != TOK_NUMBER) { return TOK_ERROR; } @@ -1725,7 +1725,7 @@ class IrParser { } template - Token ParseScalar(ValuePtr* const val_ptr, Token tok) { + Token ParseScalar(ValuePtr *const val_ptr, Token tok) { if (tok != TOK_LPARENTHESIS) { *val_ptr = std::make_shared(); return tok; @@ -1735,7 +1735,7 @@ class IrParser { } template - Token ParseScalar(ValuePtr* const val_ptr, Token tok) { + Token ParseScalar(ValuePtr *const val_ptr, Token tok) { if (tok != TOK_LPARENTHESIS) { *val_ptr = std::make_shared(nbits); return tok; @@ -1745,7 +1745,7 @@ class IrParser { } template - T StringToScalar(const std::string& text) { + T StringToScalar(const std::string &text) { std::stringstream ss; T value; ss << text; @@ -1753,7 +1753,7 @@ class IrParser { return value; } - Token ParseTensor(ValuePtr* const val_ptr) { + Token ParseTensor(ValuePtr *const val_ptr) { // parse type TypeId type; if (lexer_.GetNextToken() != TOK_LPARENTHESIS) { @@ -1803,7 +1803,7 @@ class IrParser { return lexer_.GetNextToken(); } - Token ParsePrimType(Token tok, PrimType* prim_type_ptr) { + Token ParsePrimType(Token tok, PrimType *prim_type_ptr) { if (tok != TOK_LBRACE) { return tok; } @@ -1830,7 +1830,7 @@ class IrParser { return lexer_.GetNextToken(); } - Token ParseMultitypeFuncGraphItem(const prim::MultitypeFuncGraphPtr& mt_func_graph, Token tok) { + Token ParseMultitypeFuncGraphItem(const prim::MultitypeFuncGraphPtr &mt_func_graph, Token tok) { if (tok != TOK_LPARENTHESIS) { return TOK_ERROR; } @@ -1855,7 +1855,7 @@ class IrParser { return lexer_.GetNextToken(); } - Token ParseMultitypeFuncGraph(const prim::MultitypeFuncGraphPtr& mt_func_graph, Token tok) { + Token ParseMultitypeFuncGraph(const prim::MultitypeFuncGraphPtr &mt_func_graph, Token tok) { if (tok != TOK_LBRACE) { return tok; } @@ -1868,7 +1868,7 @@ class IrParser { return lexer_.GetNextToken(); } - Token ParseBoolValue(const std::string& key, bool* val_ptr) { + Token ParseBoolValue(const std::string &key, bool *val_ptr) { if (lexer_.GetNextToken() != TOK_IDENTIFIER || lexer_.GetTokenText() != key) { return TOK_ERROR; } @@ -1892,7 +1892,7 @@ class IrParser { return lexer_.GetNextToken(); } - Token ParseValueGradOperation(const std::string& name, ValuePtr* const val_ptr) { + Token ParseValueGradOperation(const std::string &name, ValuePtr *const val_ptr) { if (lexer_.GetNextToken() != TOK_LBRACE) { return TOK_ERROR; } @@ -1920,7 +1920,7 @@ class IrParser { return lexer_.GetNextToken(); } - Token ParseSymbolicKeyInstance(const FuncGraphPtr& func_graph, AnfNodePtr* const node_ptr = nullptr) { + Token ParseSymbolicKeyInstance(const FuncGraphPtr &func_graph, AnfNodePtr *const node_ptr = nullptr) { if (lexer_.GetNextToken() != TOK_LPARENTHESIS) { return TOK_ERROR; } @@ -1951,7 +1951,7 @@ class IrParser { return lexer_.GetNextToken(); } - Token ParsePrimitivePy(const FuncGraphPtr& func_graph, const std::string& id, ValuePtr* const val_ptr) { + Token ParsePrimitivePy(const FuncGraphPtr &func_graph, const std::string &id, ValuePtr *const val_ptr) { if (lexer_.GetNextToken() != TOK_AT_FILE) { return TOK_ERROR; } @@ -1984,7 +1984,7 @@ class IrParser { return next; } - Token ParseValueGraphAndNamespace(const std::string& id, ValuePtr* val_ptr) { + Token ParseValueGraphAndNamespace(const std::string &id, ValuePtr *val_ptr) { if (Match(id, "MultitypeFuncGraph::")) { std::string name = id.substr(strlen("MultitypeFuncGraph::")); auto mt_func_graph = std::make_shared(name); @@ -2024,8 +2024,8 @@ class IrParser { } } - Token ParseValueBasic(const FuncGraphPtr& func_graph, const std::string& id, ValuePtr* val_ptr, - AnfNodePtr* const node_ptr = nullptr) { + Token ParseValueBasic(const FuncGraphPtr &func_graph, const std::string &id, ValuePtr *val_ptr, + AnfNodePtr *const node_ptr = nullptr) { if (id == "None") { *val_ptr = std::make_shared(); return lexer_.GetNextToken(); @@ -2075,9 +2075,9 @@ class IrParser { } } - Token SetListOrTupleValue(const FuncGraphPtr& func_graph, Token left_tok, Token next, bool node_is_valid, - const std::vector& elems, const std::vector& nodes, - ValuePtr* const val_ptr, AnfNodePtr* node_ptr) { + Token SetListOrTupleValue(const FuncGraphPtr &func_graph, Token left_tok, Token next, bool node_is_valid, + const std::vector &elems, const std::vector &nodes, + ValuePtr *const val_ptr, AnfNodePtr *node_ptr) { if (left_tok == TOK_LPARENTHESIS && next == TOK_RPARENTHESIS) { if (node_is_valid && node_ptr != nullptr) { MS_EXCEPTION_IF_NULL(func_graph); @@ -2097,8 +2097,8 @@ class IrParser { } } - Token ParseListOrTupleValue(const FuncGraphPtr& func_graph, Token tok, ValuePtr* const val_ptr, - AnfNodePtr* node_ptr = nullptr) { + Token ParseListOrTupleValue(const FuncGraphPtr &func_graph, Token tok, ValuePtr *const val_ptr, + AnfNodePtr *node_ptr = nullptr) { Token left_tok = tok; std::vector elems; @@ -2138,7 +2138,7 @@ class IrParser { return SetListOrTupleValue(func_graph, left_tok, next, node_is_valid, elems, nodes, val_ptr, node_ptr); } - Token ParseValue(const FuncGraphPtr& func_graph, Token tok, ValuePtr* const val_ptr, AnfNodePtr* node_ptr = nullptr) { + Token ParseValue(const FuncGraphPtr &func_graph, Token tok, ValuePtr *const val_ptr, AnfNodePtr *node_ptr = nullptr) { // tuple or list if (tok == TOK_LPARENTHESIS || tok == TOK_LBRACKET) { return ParseListOrTupleValue(func_graph, tok, val_ptr, node_ptr); @@ -2152,7 +2152,7 @@ class IrParser { return TOK_ERROR; } - Token ParseItem(const FuncGraphPtr& func_graph, AnfNodePtr* node_ptr, ValuePtr* const val_ptr, + Token ParseItem(const FuncGraphPtr &func_graph, AnfNodePtr *node_ptr, ValuePtr *const val_ptr, Token tok = TOK_INVALID) { if (tok == TOK_INVALID) { tok = lexer_.GetNextToken(); @@ -2193,7 +2193,7 @@ class IrParser { return lexer_.GetNextToken(); } - Token ParseArgument(const FuncGraphPtr& func_graph, std::vector* const inputs_ptr) { + Token ParseArgument(const FuncGraphPtr &func_graph, std::vector *const inputs_ptr) { Token tok = lexer_.GetNextToken(); if (tok == TOK_RPARENTHESIS) { return tok; @@ -2208,7 +2208,7 @@ class IrParser { return tok; } - const std::vector& GetFuncGraphs() const { return func_graphs_; } + const std::vector &GetFuncGraphs() const { return func_graphs_; } private: Lexer lexer_; @@ -2226,14 +2226,14 @@ class IrParser { std::map param_nodes_; // map parameter name to parameter }; -std::vector ImportIR(const std::string& filename) { +std::vector ImportIR(const std::string &filename) { IrParser parser(filename.c_str()); parser.ParseFile(); return parser.GetFuncGraphs(); } #ifdef ENABLE_DUMP_IR -void DumpIRProto(const FuncGraphPtr& func_graph, const std::string& suffix) { +void DumpIRProto(const FuncGraphPtr &func_graph, const std::string &suffix) { if (func_graph == nullptr) { MS_LOG(ERROR) << "Func graph is nullptr"; return; @@ -2253,7 +2253,7 @@ void DumpIRProto(const FuncGraphPtr& func_graph, const std::string& suffix) { return; } char real_path[PATH_MAX] = {0}; - char* real_path_ret = nullptr; + char *real_path_ret = nullptr; #if defined(_WIN32) || defined(_WIN64) real_path_ret = _fullpath(real_path, file_path.c_str(), PATH_MAX); #else @@ -2281,7 +2281,7 @@ void DumpIRProto(const FuncGraphPtr& func_graph, const std::string& suffix) { ChangeFileMode(file_path, S_IRUSR); } #else -void DumpIRProto(const FuncGraphPtr&, const std::string&) { +void DumpIRProto(const FuncGraphPtr &, const std::string &) { static bool already_printed = false; if (already_printed) { return; diff --git a/mindspore/ccsrc/debug/anf_ir_utils.h b/mindspore/ccsrc/debug/anf_ir_utils.h index 5342c1ab96..6c8601c4af 100644 --- a/mindspore/ccsrc/debug/anf_ir_utils.h +++ b/mindspore/ccsrc/debug/anf_ir_utils.h @@ -39,7 +39,7 @@ namespace mindspore { struct ParamPtrEqual { - bool operator()(AnfNodePtr const& t1, AnfNodePtr const& t2) const { + bool operator()(AnfNodePtr const &t1, AnfNodePtr const &t2) const { const ParameterPtr param1 = dyn_cast(t1); const ParameterPtr param2 = dyn_cast(t2); @@ -52,7 +52,7 @@ struct ParamPtrEqual { }; struct ParamPtrHasher { - std::size_t operator()(AnfNodePtr const& param) const { + std::size_t operator()(AnfNodePtr const ¶m) const { const ParameterPtr parameter = dyn_cast(param); if (parameter == nullptr) { return 0; @@ -64,39 +64,39 @@ struct ParamPtrHasher { class AnfExporter { public: - explicit AnfExporter(const std::string& id, bool export_used = true, bool check_integrity = false) + explicit AnfExporter(const std::string &id, bool export_used = true, bool check_integrity = false) : param_index(-1), id_(id), export_used_(export_used), check_integrity_(check_integrity) { func_graph_set.clear(); exported.clear(); } virtual ~AnfExporter() {} - void ExportFuncGraph(const std::string& filename, const FuncGraphPtr& func_graph); - void ExportFuncGraph(const std::string& filename, const std::vector& graphs); + void ExportFuncGraph(const std::string &filename, const FuncGraphPtr &func_graph); + void ExportFuncGraph(const std::string &filename, const std::vector &graphs); protected: - virtual std::string GetNodeType(const AnfNodePtr& nd); - int GetParamIndex(const FuncGraphPtr& func_graph, const AnfNodePtr& param, bool throw_excp = true); - int GetParamIndexFromExported(const AnfNodePtr& param); - std::string DumpObject(const py::object& obj, const std::string& category) const; - std::string GetValueNodeText(const FuncGraphPtr& func_graph, const ValueNodePtr& node); - std::string GetMultitypeFuncGraphText(const prim::MultitypeFuncGraphPtr& mt_func_graph); - std::string GetSymbolicKeyInstanceText(const FuncGraphPtr& func_graph, const SymbolicKeyInstancePtr& sym_inst); - std::string GetSequenceText(const FuncGraphPtr& func_graph, const ValuePtr& value); - std::string GetValueText(const FuncGraphPtr& func_graph, const ValuePtr& value); - std::string GetOtherValueText(const FuncGraphPtr& func_graph, const ValuePtr& value); - std::string GetPrimitiveText(const PrimitivePtr& prim); - std::string GetDictText(const FuncGraphPtr& func_graph, const ValuePtr& value); - std::string GetNameSpaceText(const parse::NameSpacePtr& ns); - std::string GetMetaFuncGraphText(const MetaFuncGraphPtr& meta_func_graph); - std::string GetAnfNodeText(const FuncGraphPtr& func_graph, const AnfNodePtr& node, - const std::map& apply_map); - void ExportOneFuncGraph(std::ofstream& ofs, const FuncGraphPtr& func_graph); - void OutputParameters(std::ofstream& ofs, const std::vector& parameters, - OrderedMap* param_map); - - void OutputStatementComment(std::ofstream& ofs, const CNodePtr& node); - void OutputCNodes(std::ofstream& ofs, const std::vector& nodes, const FuncGraphPtr& func_graph); + virtual std::string GetNodeType(const AnfNodePtr &nd); + int GetParamIndex(const FuncGraphPtr &func_graph, const AnfNodePtr ¶m, bool throw_excp = true); + int GetParamIndexFromExported(const AnfNodePtr ¶m); + std::string DumpObject(const py::object &obj, const std::string &category) const; + std::string GetValueNodeText(const FuncGraphPtr &func_graph, const ValueNodePtr &node); + std::string GetMultitypeFuncGraphText(const prim::MultitypeFuncGraphPtr &mt_func_graph); + std::string GetSymbolicKeyInstanceText(const FuncGraphPtr &func_graph, const SymbolicKeyInstancePtr &sym_inst); + std::string GetSequenceText(const FuncGraphPtr &func_graph, const ValuePtr &value); + std::string GetValueText(const FuncGraphPtr &func_graph, const ValuePtr &value); + std::string GetOtherValueText(const FuncGraphPtr &func_graph, const ValuePtr &value); + std::string GetPrimitiveText(const PrimitivePtr &prim); + std::string GetDictText(const FuncGraphPtr &func_graph, const ValuePtr &value); + std::string GetNameSpaceText(const parse::NameSpacePtr &ns); + std::string GetMetaFuncGraphText(const MetaFuncGraphPtr &meta_func_graph); + std::string GetAnfNodeText(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const std::map &apply_map); + void ExportOneFuncGraph(std::ofstream &ofs, const FuncGraphPtr &func_graph); + void OutputParameters(std::ofstream &ofs, const std::vector ¶meters, + OrderedMap *param_map); + + void OutputStatementComment(std::ofstream &ofs, const CNodePtr &node); + void OutputCNodes(std::ofstream &ofs, const std::vector &nodes, const FuncGraphPtr &func_graph); int param_index; OrderedSet func_graph_set{}; @@ -108,16 +108,16 @@ class AnfExporter { abstract::AnfNodeConfigPtr node_cfg_ = nullptr; }; -void ExportIR(const std::string& filename, const std::string& id, const FuncGraphPtr& func_graph); -void ExportIR(const std::string& filename, const std::vector& graphs); +void ExportIR(const std::string &filename, const std::string &id, const FuncGraphPtr &func_graph); +void ExportIR(const std::string &filename, const std::vector &graphs); -std::vector ImportIR(const std::string& filename); +std::vector ImportIR(const std::string &filename); -std::string GetFuncGraphProtoString(const FuncGraphPtr& func_graph); +std::string GetFuncGraphProtoString(const FuncGraphPtr &func_graph); -void DumpIRProto(const FuncGraphPtr& func_graph, const std::string& suffix); +void DumpIRProto(const FuncGraphPtr &func_graph, const std::string &suffix); -std::string GetOnnxProtoString(const FuncGraphPtr& func_graph); +std::string GetOnnxProtoString(const FuncGraphPtr &func_graph); } // namespace mindspore #endif // MINDSPORE_CCSRC_DEBUG_ANF_IR_UTILS_H_ diff --git a/mindspore/ccsrc/debug/draw.cc b/mindspore/ccsrc/debug/draw.cc index 3e8cbfba19..d3b92532fa 100644 --- a/mindspore/ccsrc/debug/draw.cc +++ b/mindspore/ccsrc/debug/draw.cc @@ -34,7 +34,7 @@ namespace draw { namespace { // Only for ValueNode -std::string ValueType(const ValueNodePtr& node) { +std::string ValueType(const ValueNodePtr &node) { if (node == nullptr) { return ""; } @@ -43,7 +43,7 @@ std::string ValueType(const ValueNodePtr& node) { return v->type_name(); } -std::string ReplaceSpecialChar(const std::string& str) { +std::string ReplaceSpecialChar(const std::string &str) { std::ostringstream oss; for (size_t i = 0; i < str.size(); i++) { if (str[i] == '<') { @@ -59,12 +59,12 @@ std::string ReplaceSpecialChar(const std::string& str) { } // namespace // API of debug utils -void DrawNodes(const std::vector& nodes, OrderedMap>* sub_graphs, +void DrawNodes(const std::vector &nodes, OrderedMap> *sub_graphs, bool is_user) { if (sub_graphs == nullptr) { return; } - for (auto& nd : nodes) { + for (auto &nd : nodes) { MS_EXCEPTION_IF_NULL(nd); auto sub_graph = nd->func_graph(); if (sub_graph != nullptr) { @@ -84,16 +84,16 @@ void DrawNodes(const std::vector& nodes, OrderedMap& nodes, - OrderedMap>* sub_graphs) { +void DrawValueNodes(const std::vector &nodes, + OrderedMap> *sub_graphs) { if (sub_graphs == nullptr) { return; } int dup_idx = 0; - for (auto& nd : nodes) { - for (auto& t : SuccIncoming(nd)) { + for (auto &nd : nodes) { + for (auto &t : SuccIncoming(nd)) { MS_EXCEPTION_IF_NULL(t); MS_EXCEPTION_IF_NULL(nd); if (t->isa() && (*sub_graphs).find(nd->func_graph()) != (*sub_graphs).end()) { @@ -107,7 +107,7 @@ void DrawValueNodes(const std::vector& nodes, } } -void DrawEdges(const std::vector& nodes, const std::shared_ptr& digraph, bool is_user) { +void DrawEdges(const std::vector &nodes, const std::shared_ptr &digraph, bool is_user) { if (digraph == nullptr) { return; } @@ -120,11 +120,11 @@ void DrawEdges(const std::vector& nodes, const std::shared_ptrisa() || t->isa()) { if ((!is_user) || (i != 0)) { @@ -143,7 +143,7 @@ void DrawEdges(const std::vector& nodes, const std::shared_ptrSubGraph(gsub.first, gsub.second); } @@ -182,18 +182,18 @@ void DrawByOpt(std::string filename, const FuncGraphPtr& func_graph, bool is_use } #ifdef ENABLE_DUMP_IR -void Draw(const std::string& filename, const FuncGraphPtr& func_graph) { +void Draw(const std::string &filename, const FuncGraphPtr &func_graph) { const std::string dot_suffix = ".dot"; std::string filename_with_suffix = (filename.rfind(dot_suffix) != (filename.size() - dot_suffix.size())) ? (filename + dot_suffix) : filename; DrawByOpt(filename_with_suffix, func_graph, false); } -void DrawUserFuncGraph(const std::string& filename, const FuncGraphPtr& func_graph) { +void DrawUserFuncGraph(const std::string &filename, const FuncGraphPtr &func_graph) { DrawByOpt(filename, func_graph, true); } #else -void Draw(const std::string&, const FuncGraphPtr&) { +void Draw(const std::string &, const FuncGraphPtr &) { static bool already_printed = false; if (already_printed) { return; @@ -203,7 +203,7 @@ void Draw(const std::string&, const FuncGraphPtr&) { << "please recompile source to enable it. See help of building script."; } -void DrawUserFuncGraph(const std::string&, const FuncGraphPtr&) { +void DrawUserFuncGraph(const std::string &, const FuncGraphPtr &) { static bool already_printed = false; if (already_printed) { return; @@ -234,7 +234,7 @@ std::string Graphviz::Shape(AnfNodePtr node) { return "plaintext"; } -std::string Graphviz::Color(const AnfNodePtr& node) { +std::string Graphviz::Color(const AnfNodePtr &node) { if (node == nullptr) { return ""; } @@ -259,7 +259,7 @@ void BaseDigraph::Start() { buffer_ << "compound=true" << std::endl; } -void BaseDigraph::Head(const AnfNodePtr& node, int id) { +void BaseDigraph::Head(const AnfNodePtr &node, int id) { if (node == nullptr) { return; } @@ -270,7 +270,7 @@ void BaseDigraph::Head(const AnfNodePtr& node, int id) { } } -void BaseDigraph::Tail(const AnfNodePtr& node, int idx, int id) { +void BaseDigraph::Tail(const AnfNodePtr &node, int idx, int id) { if (node == nullptr) { return; } @@ -279,7 +279,7 @@ void BaseDigraph::Tail(const AnfNodePtr& node, int idx, int id) { buffer_ << ":" << idx; } -void BaseDigraph::Tail(const FuncGraphPtr& func_graph) { +void BaseDigraph::Tail(const FuncGraphPtr &func_graph) { if (func_graph == nullptr) { return; } @@ -304,12 +304,12 @@ void BaseDigraph::End() { } } -void BaseDigraph::FuncGraphParameters(const FuncGraphPtr& key) { +void BaseDigraph::FuncGraphParameters(const FuncGraphPtr &key) { buffer_ << "parameters_" << key << "[shape=plaintext "; buffer_ << "label=<"; buffer_ << ""; int count = 0; - for (auto& parameter : key->parameters()) { + for (auto ¶meter : key->parameters()) { buffer_ << "
parameters
"; buffer_ << parameter->ToString(); auto py_p = dyn_cast(parameter)->default_param(); @@ -331,7 +331,7 @@ void BaseDigraph::FuncGraphParameters(const FuncGraphPtr& key) { buffer_ << "
>,];"; } -void BaseDigraph::SubGraph(const FuncGraphPtr& key, const std::shared_ptr& gsub) { +void BaseDigraph::SubGraph(const FuncGraphPtr &key, const std::shared_ptr &gsub) { if (key == nullptr || gsub == nullptr) { return; } @@ -361,12 +361,12 @@ Digraph::~Digraph() { if (fout_.is_open()) { fout_.close(); } - } catch (const std::exception& e) { + } catch (const std::exception &e) { MS_LOG(ERROR) << "Exception when closing file " << filename_; } } -static std::string ReplaceAll(std::string str, const std::string& from, const std::string& to) { +static std::string ReplaceAll(std::string str, const std::string &from, const std::string &to) { size_t start_pos = 0; while ((start_pos = str.find(from, start_pos)) != std::string::npos) { (void)str.replace(start_pos, from.length(), to); @@ -375,7 +375,7 @@ static std::string ReplaceAll(std::string str, const std::string& from, const st return str; } -static void DrawValueNode(Graphviz* const graph_obj, const ValueNodePtr& node) { +static void DrawValueNode(Graphviz *const graph_obj, const ValueNodePtr &node) { MS_EXCEPTION_IF_NULL(graph_obj); graph_obj->buffer() << "label=<"; @@ -410,7 +410,7 @@ static void DrawValueNode(Graphviz* const graph_obj, const ValueNodePtr& node) { graph_obj->buffer() << ""; graph_obj->buffer() << "
"; int i = 0; - for (const auto& attr : attrs) { + for (const auto &attr : attrs) { if (i != 0) { graph_obj->buffer() << "
"; } @@ -425,7 +425,7 @@ static void DrawValueNode(Graphviz* const graph_obj, const ValueNodePtr& node) { graph_obj->buffer() << "
>,"; } -static void DrawParallelInfo(Graphviz* const graph_obj, const CNodePtr& node) { +static void DrawParallelInfo(Graphviz *const graph_obj, const CNodePtr &node) { if (graph_obj == nullptr || node == nullptr) { return; } @@ -444,7 +444,7 @@ static void DrawParallelInfo(Graphviz* const graph_obj, const CNodePtr& node) { } } -static void DrawCNode(Graphviz* const graph_obj, const CNodePtr& node) { +static void DrawCNode(Graphviz *const graph_obj, const CNodePtr &node) { if (graph_obj == nullptr || node == nullptr || node->size() == 0) { return; } @@ -484,7 +484,7 @@ static void DrawCNode(Graphviz* const graph_obj, const CNodePtr& node) { } graph_obj->buffer() << ">"; int i = 0; - for (auto& attr : attrs) { + for (auto &attr : attrs) { if (i != 0) { graph_obj->buffer() << "
"; } @@ -567,7 +567,7 @@ ModelDigraph::~ModelDigraph() { if (fout_.is_open()) { fout_.close(); } - } catch (const std::exception& e) { + } catch (const std::exception &e) { MS_LOG(ERROR) << "exception when closing file " << filename_; } } diff --git a/mindspore/ccsrc/debug/draw.h b/mindspore/ccsrc/debug/draw.h index 4781a6c231..7804c6e94a 100644 --- a/mindspore/ccsrc/debug/draw.h +++ b/mindspore/ccsrc/debug/draw.h @@ -31,9 +31,9 @@ namespace parse = mindspore::parse; class Graphviz { public: - Graphviz(const std::string& name, const std::string& filename) : name_(name), filename_(filename), fout_(filename_) {} + Graphviz(const std::string &name, const std::string &filename) : name_(name), filename_(filename), fout_(filename_) {} - explicit Graphviz(const std::string& name) : name_(name) {} + explicit Graphviz(const std::string &name) : name_(name) {} virtual ~Graphviz() {} @@ -41,8 +41,8 @@ class Graphviz { virtual void End() {} virtual std::string Shape(AnfNodePtr node); - std::string Color(const AnfNodePtr& node); - std::ostringstream& buffer() { return buffer_; } + std::string Color(const AnfNodePtr &node); + std::ostringstream &buffer() { return buffer_; } std::ostringstream buffer_; protected: @@ -53,8 +53,8 @@ class Graphviz { class BaseDigraph : public Graphviz { public: - BaseDigraph(const std::string& name, const std::string& filename) : Graphviz(name, filename) {} - explicit BaseDigraph(const std::string& name) : Graphviz(name) {} + BaseDigraph(const std::string &name, const std::string &filename) : Graphviz(name, filename) {} + explicit BaseDigraph(const std::string &name) : Graphviz(name) {} ~BaseDigraph() override = default; virtual void Node(AnfNodePtr node, int id = 0) = 0; @@ -63,21 +63,21 @@ class BaseDigraph : public Graphviz { void Start() override; void End() override; virtual void Edge(AnfNodePtr start, FuncGraphPtr end, int id_start); - void FuncGraphParameters(const FuncGraphPtr& key); - void SubGraph(const FuncGraphPtr& key, const std::shared_ptr& gsub); + void FuncGraphParameters(const FuncGraphPtr &key); + void SubGraph(const FuncGraphPtr &key, const std::shared_ptr &gsub); - const std::string& name() const { return name_; } + const std::string &name() const { return name_; } protected: - void Head(const AnfNodePtr& node, int id = 0); - void Tail(const AnfNodePtr& node, int idx, int id = 0); - void Tail(const FuncGraphPtr& func_graph); + void Head(const AnfNodePtr &node, int id = 0); + void Tail(const AnfNodePtr &node, int idx, int id = 0); + void Tail(const FuncGraphPtr &func_graph); }; class Digraph : public BaseDigraph { public: - Digraph(const std::string& name, const std::string& filename) : BaseDigraph(name, filename) {} - explicit Digraph(const std::string& name) : BaseDigraph(name) {} + Digraph(const std::string &name, const std::string &filename) : BaseDigraph(name, filename) {} + explicit Digraph(const std::string &name) : BaseDigraph(name) {} ~Digraph() override; void Node(AnfNodePtr node, int id = 0) override; @@ -86,8 +86,8 @@ class Digraph : public BaseDigraph { class ModelDigraph : public BaseDigraph { public: - ModelDigraph(const std::string& name, const std::string& filename) : BaseDigraph(name, filename) {} - explicit ModelDigraph(const std::string& name) : BaseDigraph(name) {} + ModelDigraph(const std::string &name, const std::string &filename) : BaseDigraph(name, filename) {} + explicit ModelDigraph(const std::string &name) : BaseDigraph(name) {} ~ModelDigraph() override; std::string Shape(AnfNodePtr node) override; @@ -96,8 +96,8 @@ class ModelDigraph : public BaseDigraph { }; // API to draw -void Draw(const std::string& filename, const FuncGraphPtr& func_graph); -void DrawUserFuncGraph(const std::string& filename, const FuncGraphPtr& func_graph); +void Draw(const std::string &filename, const FuncGraphPtr &func_graph); +void DrawUserFuncGraph(const std::string &filename, const FuncGraphPtr &func_graph); } // namespace draw } // namespace mindspore diff --git a/mindspore/ccsrc/debug/dump_proto.cc b/mindspore/ccsrc/debug/dump_proto.cc index a7a1e208a4..83ab1e4505 100644 --- a/mindspore/ccsrc/debug/dump_proto.cc +++ b/mindspore/ccsrc/debug/dump_proto.cc @@ -33,38 +33,38 @@ class ProtoExporter { ProtoExporter() {} ~ProtoExporter() {} - std::string GetFuncGraphProtoString(const FuncGraphPtr& func_graph); + std::string GetFuncGraphProtoString(const FuncGraphPtr &func_graph); private: void InitModelInfo(); - void GetOpNodeTypeAndAttrs(const FuncGraphPtr& func_graph, const AnfNodePtr& node, irpb::NodeProto* node_proto); - std::string GetOpNodeInputId(const FuncGraphPtr& func_graph, const AnfNodePtr& node, - const std::map& apply_map, - std::map* const_map_ptr); - void SetValueToProto(const ValuePtr& attr_value, irpb::ValueProto* value_proto); - void SetScalarToProto(const ScalarPtr& val, irpb::ValueProto* value_proto); - void SetSequenceToProto(const ValueSequeuePtr& val, irpb::ValueProto* value_proto); - void SetDictionaryToProto(const ValueDictionaryPtr& val, irpb::ValueProto* value_proto); - void SetNodeOutputType(const AnfNodePtr& node, irpb::TypeProto* type_proto); - void SetNodeOutputType(const TypePtr& node, const BaseShapePtr& shape, irpb::TypeProto* type_proto); - - void ExportFuncGraph(const FuncGraphPtr& func_graph, irpb::GraphProto* graph_proto); - void ExportParameters(const FuncGraphPtr& func_graph, irpb::GraphProto* graph_proto); - void ExportCNodes(const FuncGraphPtr& func_graph, irpb::GraphProto* graph_proto, - std::map* const_map_ptr); - void ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr& node, std::map* apply_map_ptr, - std::map* const_map_ptr, irpb::GraphProto* graph_proto); - void ExportFuncGraphOutput(const FuncGraphPtr& func_graph, const CNodePtr& ret_node, - const std::map& apply_map, std::map* const_map_ptr, - irpb::GraphProto* graph_proto); - void ExportValueNodes(const std::map& const_map, irpb::GraphProto* graph_proto); + void GetOpNodeTypeAndAttrs(const FuncGraphPtr &func_graph, const AnfNodePtr &node, irpb::NodeProto *node_proto); + std::string GetOpNodeInputId(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const std::map &apply_map, + std::map *const_map_ptr); + void SetValueToProto(const ValuePtr &attr_value, irpb::ValueProto *value_proto); + void SetScalarToProto(const ScalarPtr &val, irpb::ValueProto *value_proto); + void SetSequenceToProto(const ValueSequeuePtr &val, irpb::ValueProto *value_proto); + void SetDictionaryToProto(const ValueDictionaryPtr &val, irpb::ValueProto *value_proto); + void SetNodeOutputType(const AnfNodePtr &node, irpb::TypeProto *type_proto); + void SetNodeOutputType(const TypePtr &node, const BaseShapePtr &shape, irpb::TypeProto *type_proto); + + void ExportFuncGraph(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto); + void ExportParameters(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto); + void ExportCNodes(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto, + std::map *const_map_ptr); + void ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *apply_map_ptr, + std::map *const_map_ptr, irpb::GraphProto *graph_proto); + void ExportFuncGraphOutput(const FuncGraphPtr &func_graph, const CNodePtr &ret_node, + const std::map &apply_map, std::map *const_map_ptr, + irpb::GraphProto *graph_proto); + void ExportValueNodes(const std::map &const_map, irpb::GraphProto *graph_proto); static std::string GetConstNodeId(size_t idx) { return std::string("cst") + std::to_string(idx); } irpb::ModelProto model_; }; -static irpb::DataType GetNumberDataType(const TypePtr& type) { +static irpb::DataType GetNumberDataType(const TypePtr &type) { switch (type->type_id()) { case kNumberTypeBool: return irpb::DT_BOOL; @@ -101,7 +101,7 @@ static irpb::DataType GetNumberDataType(const TypePtr& type) { } } -void ProtoExporter::SetNodeOutputType(const TypePtr& type, const BaseShapePtr& shape, irpb::TypeProto* type_proto) { +void ProtoExporter::SetNodeOutputType(const TypePtr &type, const BaseShapePtr &shape, irpb::TypeProto *type_proto) { if (type_proto == nullptr) { return; } @@ -116,14 +116,14 @@ void ProtoExporter::SetNodeOutputType(const TypePtr& type, const BaseShapePtr& s type_proto->set_data_type(irpb::DT_TENSOR); if (shape != nullptr && shape->isa()) { abstract::ShapePtr shape_info = dyn_cast(shape); - for (const auto& elem : shape_info->shape()) { + for (const auto &elem : shape_info->shape()) { type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_size(elem); } } } else if (type->isa()) { TuplePtr tuple_type = dyn_cast(type); type_proto->set_data_type(irpb::DT_TUPLE); - for (const auto& elem_type : tuple_type->elements()) { + for (const auto &elem_type : tuple_type->elements()) { SetNodeOutputType(elem_type, nullptr, type_proto->mutable_sequence_type()->add_elem_types()); } } else if (type->isa()) { @@ -131,7 +131,7 @@ void ProtoExporter::SetNodeOutputType(const TypePtr& type, const BaseShapePtr& s } else if (type->isa()) { ListPtr list_type = dyn_cast(type); type_proto->set_data_type(irpb::DT_LIST); - for (const auto& elem_type : list_type->elements()) { + for (const auto &elem_type : list_type->elements()) { SetNodeOutputType(elem_type, nullptr, type_proto->mutable_sequence_type()->add_elem_types()); } } else if (type->isa()) { @@ -153,20 +153,20 @@ void ProtoExporter::SetNodeOutputType(const TypePtr& type, const BaseShapePtr& s } } -void ProtoExporter::SetNodeOutputType(const AnfNodePtr& node, irpb::TypeProto* type_proto) { +void ProtoExporter::SetNodeOutputType(const AnfNodePtr &node, irpb::TypeProto *type_proto) { if (node == nullptr || type_proto == nullptr) { return; } SetNodeOutputType(node->Type(), node->Shape(), type_proto); } -void ProtoExporter::SetValueToProto(const ValuePtr& val, irpb::ValueProto* value_proto) { +void ProtoExporter::SetValueToProto(const ValuePtr &val, irpb::ValueProto *value_proto) { if (val == nullptr || value_proto == nullptr) { return; } if (val->isa()) { - const StringImmPtr& value = dyn_cast(val); + const StringImmPtr &value = dyn_cast(val); value_proto->set_dtype(irpb::DT_STRING); value_proto->set_str_val(value->value()); } else if (val->isa()) { @@ -195,15 +195,15 @@ void ProtoExporter::SetValueToProto(const ValuePtr& val, irpb::ValueProto* value } else if (val->isa()) { tensor::TensorPtr tensor_ptr = dyn_cast(val); value_proto->set_dtype(irpb::DT_TENSOR); - irpb::TensorProto* tensor_proto = value_proto->mutable_tensor_val(); + irpb::TensorProto *tensor_proto = value_proto->mutable_tensor_val(); tensor_proto->set_data_type(GetNumberDataType(tensor_ptr->Dtype())); - for (auto& elem : tensor_ptr->shape()) { + for (auto &elem : tensor_ptr->shape()) { tensor_proto->add_dims(elem); } } else if (val->isa()) { value_proto->set_dtype(irpb::DT_TYPE); - irpb::TypeProto* type_proto = value_proto->mutable_type_val(); + irpb::TypeProto *type_proto = value_proto->mutable_type_val(); type_proto->set_data_type(irpb::DT_TENSOR); TypePtr elem_type = dyn_cast(val)->element(); type_proto->mutable_tensor_type()->set_elem_type(GetNumberDataType(elem_type)); @@ -212,53 +212,53 @@ void ProtoExporter::SetValueToProto(const ValuePtr& val, irpb::ValueProto* value } } -void ProtoExporter::SetScalarToProto(const ScalarPtr& val, irpb::ValueProto* value_proto) { +void ProtoExporter::SetScalarToProto(const ScalarPtr &val, irpb::ValueProto *value_proto) { if (val == nullptr || value_proto == nullptr) { return; } if (val->isa()) { - const BoolImmPtr& value = dyn_cast(val); + const BoolImmPtr &value = dyn_cast(val); value_proto->set_dtype(irpb::DT_BOOL); value_proto->set_bool_val(value->value()); } else if (val->isa()) { - const Int8ImmPtr& value = dyn_cast(val); + const Int8ImmPtr &value = dyn_cast(val); value_proto->set_dtype(irpb::DT_INT8); value_proto->set_int_val(value->value()); } else if (val->isa()) { - const Int16ImmPtr& value = dyn_cast(val); + const Int16ImmPtr &value = dyn_cast(val); value_proto->set_dtype(irpb::DT_INT16); value_proto->set_int_val(value->value()); } else if (val->isa()) { - const Int32ImmPtr& value = dyn_cast(val); + const Int32ImmPtr &value = dyn_cast(val); value_proto->set_dtype(irpb::DT_INT32); value_proto->set_int_val(value->value()); } else if (val->isa()) { - const Int64ImmPtr& value = dyn_cast(val); + const Int64ImmPtr &value = dyn_cast(val); value_proto->set_dtype(irpb::DT_INT64); value_proto->set_int_val(value->value()); } else if (val->isa()) { - const UInt8ImmPtr& value = dyn_cast(val); + const UInt8ImmPtr &value = dyn_cast(val); value_proto->set_dtype(irpb::DT_UINT8); value_proto->set_uint_val(value->value()); } else if (val->isa()) { - const UInt16ImmPtr& value = dyn_cast(val); + const UInt16ImmPtr &value = dyn_cast(val); value_proto->set_dtype(irpb::DT_UINT16); value_proto->set_uint_val(value->value()); } else if (val->isa()) { - const UInt32ImmPtr& value = dyn_cast(val); + const UInt32ImmPtr &value = dyn_cast(val); value_proto->set_dtype(irpb::DT_UINT32); value_proto->set_uint_val(value->value()); } else if (val->isa()) { - const UInt64ImmPtr& value = dyn_cast(val); + const UInt64ImmPtr &value = dyn_cast(val); value_proto->set_dtype(irpb::DT_UINT64); value_proto->set_uint_val(value->value()); } else if (val->isa()) { - const FP32ImmPtr& value = dyn_cast(val); + const FP32ImmPtr &value = dyn_cast(val); value_proto->set_dtype(irpb::DT_FLOAT32); value_proto->set_float_val(value->value()); } else if (val->isa()) { - const FP64ImmPtr& value = dyn_cast(val); + const FP64ImmPtr &value = dyn_cast(val); value_proto->set_dtype(irpb::DT_FLOAT64); value_proto->set_double_val(value->value()); } else { @@ -266,40 +266,40 @@ void ProtoExporter::SetScalarToProto(const ScalarPtr& val, irpb::ValueProto* val } } -void ProtoExporter::SetSequenceToProto(const ValueSequeuePtr& val, irpb::ValueProto* value_proto) { +void ProtoExporter::SetSequenceToProto(const ValueSequeuePtr &val, irpb::ValueProto *value_proto) { if (val == nullptr || value_proto == nullptr) { return; } if (val->isa()) { - const ValueTuplePtr& value = dyn_cast(val); + const ValueTuplePtr &value = dyn_cast(val); value_proto->set_dtype(irpb::DT_TUPLE); - for (const auto& item : value->value()) { + for (const auto &item : value->value()) { SetValueToProto(item, value_proto->add_values()); } } else if (val->isa()) { - const ValueListPtr& value = dyn_cast(val); + const ValueListPtr &value = dyn_cast(val); value_proto->set_dtype(irpb::DT_LIST); - for (const auto& item : value->value()) { + for (const auto &item : value->value()) { SetValueToProto(item, value_proto->add_values()); } } } -void ProtoExporter::SetDictionaryToProto(const ValueDictionaryPtr& val, irpb::ValueProto* value_proto) { +void ProtoExporter::SetDictionaryToProto(const ValueDictionaryPtr &val, irpb::ValueProto *value_proto) { if (val == nullptr || value_proto == nullptr) { return; } value_proto->set_dtype(irpb::DT_DICT); - for (const auto& item : val->value()) { - irpb::NamedValueProto* named_val = value_proto->add_dict_val(); + for (const auto &item : val->value()) { + irpb::NamedValueProto *named_val = value_proto->add_dict_val(); named_val->set_key(item.first); SetValueToProto(item.second, named_val->mutable_value()); } } -void ProtoExporter::GetOpNodeTypeAndAttrs(const FuncGraphPtr&, const AnfNodePtr& node, irpb::NodeProto* node_proto) { +void ProtoExporter::GetOpNodeTypeAndAttrs(const FuncGraphPtr &, const AnfNodePtr &node, irpb::NodeProto *node_proto) { if (node == nullptr || node_proto == nullptr) { return; } @@ -312,19 +312,19 @@ void ProtoExporter::GetOpNodeTypeAndAttrs(const FuncGraphPtr&, const AnfNodePtr& MS_LOG(EXCEPTION) << "Op node is not primitive: " << node->ToString(); } - const PrimitivePtr& prim = GetValueNode(node); + const PrimitivePtr &prim = GetValueNode(node); node_proto->set_op_type(prim->name()); - for (const auto& attr : prim->attrs()) { - irpb::AttributeProto* attr_proto = node_proto->add_attribute(); + for (const auto &attr : prim->attrs()) { + irpb::AttributeProto *attr_proto = node_proto->add_attribute(); attr_proto->set_name(attr.first); SetValueToProto(attr.second, attr_proto->mutable_value()); } node_proto->set_scope(node->scope()->name()); } -std::string ProtoExporter::GetOpNodeInputId(const FuncGraphPtr&, const AnfNodePtr& node, - const std::map& apply_map, - std::map* const_map_ptr) { +std::string ProtoExporter::GetOpNodeInputId(const FuncGraphPtr &, const AnfNodePtr &node, + const std::map &apply_map, + std::map *const_map_ptr) { if (node == nullptr || const_map_ptr == nullptr) { return ""; } @@ -354,18 +354,18 @@ std::string ProtoExporter::GetOpNodeInputId(const FuncGraphPtr&, const AnfNodePt MS_LOG(EXCEPTION) << "Unknown node type. node is '" << node->ToString() << "'"; } -std::string ProtoExporter::GetFuncGraphProtoString(const FuncGraphPtr& func_graph) { +std::string ProtoExporter::GetFuncGraphProtoString(const FuncGraphPtr &func_graph) { if (func_graph == nullptr) { return ""; } InitModelInfo(); - irpb::GraphProto* graph_proto = model_.mutable_graph(); + irpb::GraphProto *graph_proto = model_.mutable_graph(); ExportFuncGraph(func_graph, graph_proto); return model_.SerializeAsString(); } -void ProtoExporter::ExportFuncGraph(const FuncGraphPtr& func_graph, irpb::GraphProto* graph_proto) { +void ProtoExporter::ExportFuncGraph(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto) { if (func_graph == nullptr || graph_proto == nullptr) { return; } @@ -383,14 +383,14 @@ void ProtoExporter::ExportFuncGraph(const FuncGraphPtr& func_graph, irpb::GraphP ExportValueNodes(const_map, graph_proto); } -void ProtoExporter::ExportParameters(const FuncGraphPtr& func_graph, irpb::GraphProto* graph_proto) { +void ProtoExporter::ExportParameters(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto) { if (func_graph == nullptr || graph_proto == nullptr) { return; } std::vector parameters = func_graph->parameters(); - for (auto& param : parameters) { - irpb::ParameterProto* param_proto = graph_proto->add_parameters(); + for (auto ¶m : parameters) { + irpb::ParameterProto *param_proto = graph_proto->add_parameters(); param_proto->set_name(param->ToString()); SetNodeOutputType(param, param_proto->mutable_type()); @@ -402,15 +402,15 @@ void ProtoExporter::ExportParameters(const FuncGraphPtr& func_graph, irpb::Graph } } -void ProtoExporter::ExportCNodes(const FuncGraphPtr& func_graph, irpb::GraphProto* graph_proto, - std::map* const_map_ptr) { +void ProtoExporter::ExportCNodes(const FuncGraphPtr &func_graph, irpb::GraphProto *graph_proto, + std::map *const_map_ptr) { if (func_graph == nullptr || graph_proto == nullptr || const_map_ptr == nullptr) { return; } // topo sort nodes std::vector nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude); std::map apply_map; - for (const AnfNodePtr& node : nodes) { + for (const AnfNodePtr &node : nodes) { MS_EXCEPTION_IF_NULL(node); if (!node->isa()) { continue; @@ -424,9 +424,9 @@ void ProtoExporter::ExportCNodes(const FuncGraphPtr& func_graph, irpb::GraphProt } } -void ProtoExporter::ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr& node, - std::map* apply_map_ptr, - std::map* const_map_ptr, irpb::GraphProto* graph_proto) { +void ProtoExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *apply_map_ptr, + std::map *const_map_ptr, irpb::GraphProto *graph_proto) { if (func_graph == nullptr || node == nullptr || apply_map_ptr == nullptr || const_map_ptr == nullptr || graph_proto == nullptr) { return; @@ -435,12 +435,12 @@ void ProtoExporter::ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr& auto apply_idx = apply_map_ptr->size() + 1; (*apply_map_ptr)[node] = apply_idx; - auto& inputs = node->inputs(); + auto &inputs = node->inputs(); if (inputs.size() < 1) { MS_LOG(EXCEPTION) << "Inputs of apply node is empty"; } AnfNodePtr op = inputs[0]; - irpb::NodeProto* node_proto = graph_proto->add_node(); + irpb::NodeProto *node_proto = graph_proto->add_node(); // CNode/ConstGraph/Const/Parameter if (op->isa() || IsValueNode(op) || op->isa()) { @@ -452,7 +452,7 @@ void ProtoExporter::ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr& // process OP inputs for (size_t i = 1; i < inputs.size(); ++i) { - irpb::InputProto* input_proto = node_proto->add_input(); + irpb::InputProto *input_proto = node_proto->add_input(); input_proto->set_type(irpb::InputProto_EdgeType_DATA_EDGE); std::string id = GetOpNodeInputId(func_graph, inputs[i], *apply_map_ptr, const_map_ptr); input_proto->set_name(id); @@ -463,9 +463,9 @@ void ProtoExporter::ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr& } } -void ProtoExporter::ExportFuncGraphOutput(const FuncGraphPtr& func_graph, const CNodePtr& ret_node, - const std::map& apply_map, - std::map* const_map_ptr, irpb::GraphProto* graph_proto) { +void ProtoExporter::ExportFuncGraphOutput(const FuncGraphPtr &func_graph, const CNodePtr &ret_node, + const std::map &apply_map, + std::map *const_map_ptr, irpb::GraphProto *graph_proto) { if (ret_node == nullptr || !ret_node->isa()) { MS_LOG(EXCEPTION) << "Graph return node is illegal"; } @@ -473,7 +473,7 @@ void ProtoExporter::ExportFuncGraphOutput(const FuncGraphPtr& func_graph, const if (graph_proto == nullptr) { MS_LOG(EXCEPTION) << "graph_proto is nullptr"; } - irpb::OutputProto* output_proto = graph_proto->add_outputs(); + irpb::OutputProto *output_proto = graph_proto->add_outputs(); if (output_proto == nullptr) { MS_LOG(EXCEPTION) << "output_proto is nullptr"; } @@ -482,22 +482,22 @@ void ProtoExporter::ExportFuncGraphOutput(const FuncGraphPtr& func_graph, const SetNodeOutputType(arg, output_proto->mutable_type()); } -static bool CompareValue(const std::pair& x, const std::pair& y) { +static bool CompareValue(const std::pair &x, const std::pair &y) { return x.second < y.second; } -void ProtoExporter::ExportValueNodes(const std::map& const_map, irpb::GraphProto* graph_proto) { +void ProtoExporter::ExportValueNodes(const std::map &const_map, irpb::GraphProto *graph_proto) { std::vector> nodes; (void)std::transform(const_map.cbegin(), const_map.cend(), std::back_inserter(nodes), - [](const std::pair& item) { return item; }); + [](const std::pair &item) { return item; }); sort(nodes.begin(), nodes.end(), CompareValue); - for (auto& item : nodes) { + for (auto &item : nodes) { if (graph_proto == nullptr) { MS_LOG(EXCEPTION) << "graph_proto is nullptr"; } - irpb::NamedValueProto* named_value = graph_proto->add_const_vals(); + irpb::NamedValueProto *named_value = graph_proto->add_const_vals(); MS_EXCEPTION_IF_NULL(named_value); named_value->set_key(GetConstNodeId(item.second)); SetValueToProto(GetValueNode(item.first), named_value->mutable_value()); @@ -506,7 +506,7 @@ void ProtoExporter::ExportValueNodes(const std::map& const_m void ProtoExporter::InitModelInfo() { model_.set_ir_version(irpb::IR_VERSION); } -std::string GetFuncGraphProtoString(const FuncGraphPtr& func_graph) { +std::string GetFuncGraphProtoString(const FuncGraphPtr &func_graph) { ProtoExporter exporter; return exporter.GetFuncGraphProtoString(func_graph); } diff --git a/mindspore/ccsrc/debug/e2e_dump.cc b/mindspore/ccsrc/debug/e2e_dump.cc index fbe76cdc47..34d401191a 100644 --- a/mindspore/ccsrc/debug/e2e_dump.cc +++ b/mindspore/ccsrc/debug/e2e_dump.cc @@ -36,7 +36,7 @@ Dump::Dump() dump_iter_(0), cur_iter_(0) {} -bool Dump::IsKernelNeedDump(const std::string& kernel_name) { +bool Dump::IsKernelNeedDump(const std::string &kernel_name) { if (dump_mode_ == 0) { // Dump All Kernels mode return true; @@ -49,7 +49,7 @@ bool Dump::IsKernelNeedDump(const std::string& kernel_name) { return false; } -bool Dump::ParseDumpConfig(const std::string& dump_config_file) { +bool Dump::ParseDumpConfig(const std::string &dump_config_file) { std::ifstream jsonFile(dump_config_file); if (!jsonFile.is_open()) { MS_LOG(ERROR) << dump_config_file << " open failed."; @@ -79,7 +79,7 @@ bool Dump::ParseDumpConfig(const std::string& dump_config_file) { return true; } -bool Dump::IsConfigExist(const nlohmann::json& dumpSettings) { +bool Dump::IsConfigExist(const nlohmann::json &dumpSettings) { if (dumpSettings.find("trans_flag") == dumpSettings.end() || dumpSettings.find("enable") == dumpSettings.end() || dumpSettings.find("mode") == dumpSettings.end() || dumpSettings.find("path") == dumpSettings.end() || dumpSettings.find("net_name") == dumpSettings.end() || dumpSettings.find("iteration") == dumpSettings.end() || @@ -91,7 +91,7 @@ bool Dump::IsConfigExist(const nlohmann::json& dumpSettings) { return true; } -bool Dump::IsConfigValid(const nlohmann::json& dumpSettings) { +bool Dump::IsConfigValid(const nlohmann::json &dumpSettings) { auto trans_flag = dumpSettings.at("trans_flag"); auto enable = dumpSettings.at("enable"); auto mode = dumpSettings.at("mode"); @@ -112,14 +112,14 @@ bool Dump::IsConfigValid(const nlohmann::json& dumpSettings) { dump_path_ = path; dump_net_name_ = net_name; dump_iter_ = iteration; - for (const auto& kernel : kernels) { + for (const auto &kernel : kernels) { dump_kernels_.push_back(kernel); } return true; } bool Dump::SetDumpConfFromJsonFile() { - const char* config_path_str = std::getenv("MINDSPORE_CONFIG_PATH"); + const char *config_path_str = std::getenv("MINDSPORE_CONFIG_PATH"); if (config_path_str != nullptr) { MS_LOG(INFO) << "Getenv MINDSPORE_CONFIG_PATH :" << config_path_str; } else { @@ -148,7 +148,7 @@ bool Dump::SetDumpConfFromJsonFile() { return ParseDumpConfig(dump_config_file); } -bool Dump::DumpToFile(const std::string& filename, const void* data, size_t len) { +bool Dump::DumpToFile(const std::string &filename, const void *data, size_t len) { if (filename.empty() || data == nullptr || len == 0) { MS_LOG(ERROR) << "Incorrect parameter."; return false; @@ -166,12 +166,12 @@ bool Dump::DumpToFile(const std::string& filename, const void* data, size_t len) MS_LOG(ERROR) << "Open file " << realpath << " fail."; return false; } - (void)fd.write(reinterpret_cast(data), SizeToLong(len)); + (void)fd.write(reinterpret_cast(data), SizeToLong(len)); fd.close(); return true; } -bool Dump::GetRealPath(const std::string& inpath, std::string* outpath) { +bool Dump::GetRealPath(const std::string &inpath, std::string *outpath) { MS_EXCEPTION_IF_NULL(outpath); auto path_split_pos = inpath.find_last_of('/'); if (path_split_pos == std::string::npos) { @@ -213,7 +213,7 @@ bool Dump::GetRealPath(const std::string& inpath, std::string* outpath) { return true; } -bool Dump::CreateNotExistDirs(const std::string& path) { +bool Dump::CreateNotExistDirs(const std::string &path) { std::shared_ptr fs = system::Env::GetFileSystem(); MS_EXCEPTION_IF_NULL(fs); char temp_path[PATH_MAX] = {0}; diff --git a/mindspore/ccsrc/debug/e2e_dump.h b/mindspore/ccsrc/debug/e2e_dump.h index 2410dfb09a..4c3e8308da 100644 --- a/mindspore/ccsrc/debug/e2e_dump.h +++ b/mindspore/ccsrc/debug/e2e_dump.h @@ -43,11 +43,11 @@ class Dump { uint32_t cur_iter() const { return cur_iter_; } - bool IsKernelNeedDump(const std::string& kernel_name); + bool IsKernelNeedDump(const std::string &kernel_name); bool SetDumpConfFromJsonFile(); - static bool DumpToFile(const std::string& filename, const void* data, size_t len); + static bool DumpToFile(const std::string &filename, const void *data, size_t len); protected: bool dump_enable_; @@ -59,14 +59,14 @@ class Dump { uint32_t cur_iter_; std::vector dump_kernels_; - static bool GetRealPath(const std::string& inpath, std::string* outpath); + static bool GetRealPath(const std::string &inpath, std::string *outpath); - static bool CreateNotExistDirs(const std::string& path); + static bool CreateNotExistDirs(const std::string &path); private: - bool ParseDumpConfig(const std::string& dump_config_file); - bool IsConfigExist(const nlohmann::json& dumpSettings); - bool IsConfigValid(const nlohmann::json& dumpSettings); + bool ParseDumpConfig(const std::string &dump_config_file); + bool IsConfigExist(const nlohmann::json &dumpSettings); + bool IsConfigValid(const nlohmann::json &dumpSettings); }; using DumpConfPtr = std::shared_ptr; diff --git a/mindspore/ccsrc/debug/info.cc b/mindspore/ccsrc/debug/info.cc index 3c43bfa9b1..7903e554d9 100644 --- a/mindspore/ccsrc/debug/info.cc +++ b/mindspore/ccsrc/debug/info.cc @@ -23,7 +23,7 @@ #include "pipeline/parse/python_adapter.h" namespace mindspore { -std::string HighLightLine(const std::string& line, int col_begin, int col_end, SourceLineTip tip) { +std::string HighLightLine(const std::string &line, int col_begin, int col_end, SourceLineTip tip) { std::string temp_line = line; if (col_begin < col_end && col_begin != -1 && col_end <= SizeToInt(temp_line.length()) && tip != kSourceLineTipDiscard) { @@ -101,14 +101,14 @@ DebugInfo::DebugInfo() { name_ = ""; } -DebugInfo::DebugInfo(const std::string& name) { +DebugInfo::DebugInfo(const std::string &name) { InitValueFromContext(); unique_id_ = gen_unique_id(); debug_id_ = -1; name_ = name; } -DebugInfo::DebugInfo(const LocationPtr& loc) { +DebugInfo::DebugInfo(const LocationPtr &loc) { InitValueFromContext(); unique_id_ = gen_unique_id(); debug_id_ = -1; @@ -126,7 +126,7 @@ int64_t DebugInfo::debug_id() { } int64_t DebugInfo::unique_id_through_copy() const { - TraceInfoPtr trace_info = const_cast(this)->trace_info(); + TraceInfoPtr trace_info = const_cast(this)->trace_info(); if (trace_info != nullptr) { if (trace_info->isa() && trace_info->debug_info() != nullptr) { return trace_info->debug_info()->unique_id_through_copy(); @@ -172,7 +172,7 @@ LocationPtr GraphDebugInfo::location() { } return DebugInfo::location(); } -void GraphDebugInfo::set_deco_location(const LocationPtr& deco_list_loc) { deco_loc_ = deco_list_loc; } +void GraphDebugInfo::set_deco_location(const LocationPtr &deco_list_loc) { deco_loc_ = deco_list_loc; } TraceContextPtr TraceManager::CurrentContextInfo() { if (!TraceManager::trace_context_stack_.empty()) { @@ -181,18 +181,18 @@ TraceContextPtr TraceManager::CurrentContextInfo() { return nullptr; } -void TraceManager::DebugTrace(const std::string& func_name, const LocationPtr& location) { +void TraceManager::DebugTrace(const std::string &func_name, const LocationPtr &location) { TraceContextPtr context = std::make_shared(location); context->set_func_name(func_name); TraceManager::trace_context_stack_.push(context); } -void TraceManager::DebugTrace(const LocationPtr& location) { +void TraceManager::DebugTrace(const LocationPtr &location) { TraceContextPtr context = std::make_shared(location); TraceManager::trace_context_stack_.push(context); } -void TraceManager::DebugTrace(const TraceInfoPtr& trace_info) { +void TraceManager::DebugTrace(const TraceInfoPtr &trace_info) { if (trace_info == nullptr) { MS_LOG(EXCEPTION) << "DebugTrace wrong traced info is null"; } @@ -203,7 +203,7 @@ void TraceManager::DebugTrace(const TraceInfoPtr& trace_info) { TraceManager::trace_context_stack_.push(context); } -void TraceManager::DebugTrace(const DebugInfoPtr& debug_info, const TraceInfoPtr& trace_info) { +void TraceManager::DebugTrace(const DebugInfoPtr &debug_info, const TraceInfoPtr &trace_info) { if (trace_info == nullptr) { MS_LOG(EXCEPTION) << "DebugTrace wrong traced info is null"; } diff --git a/mindspore/ccsrc/debug/info.h b/mindspore/ccsrc/debug/info.h index da641ab74b..a34d6e3df5 100644 --- a/mindspore/ccsrc/debug/info.h +++ b/mindspore/ccsrc/debug/info.h @@ -37,9 +37,9 @@ enum SourceLineTip { kSourceLineTipDiscard = 0, kSourceLineTipNextLine = 1, kSou // Location class record the location in source code. class Location { public: - Location(const std::string& file_name, int line, int column, int line_end, int column_end) + Location(const std::string &file_name, int line, int column, int line_end, int column_end) : file_name_(file_name), line_(line), column_(column), line_end_(line_end), column_end_(column_end) {} - Location(const Location& loc) + Location(const Location &loc) : file_name_(loc.file_name_), line_(loc.line_), column_(loc.column_), @@ -77,21 +77,21 @@ class TraceManager { TraceManager() = default; ~TraceManager() = default; static TraceContextPtr CurrentContextInfo(); - static void DebugTrace(const std::string& func_name, const LocationPtr& location); - static void DebugTrace(const LocationPtr& location); - static void DebugTrace(const TraceInfoPtr& trace_info); + static void DebugTrace(const std::string &func_name, const LocationPtr &location); + static void DebugTrace(const LocationPtr &location); + static void DebugTrace(const TraceInfoPtr &trace_info); // debug trace with a cloned trace info with debug_info - static void DebugTrace(const DebugInfoPtr& debug_info, const TraceInfoPtr& trace_info); + static void DebugTrace(const DebugInfoPtr &debug_info, const TraceInfoPtr &trace_info); static void EndTrace(); static std::stack trace_context_stack_; }; class TraceGuard { public: - explicit TraceGuard(const std::string func_name, const LocationPtr& location) { + explicit TraceGuard(const std::string func_name, const LocationPtr &location) { TraceManager::DebugTrace(func_name, location); } - explicit TraceGuard(const LocationPtr& location) { TraceManager::DebugTrace(location); } + explicit TraceGuard(const LocationPtr &location) { TraceManager::DebugTrace(location); } ~TraceGuard() { TraceManager::EndTrace(); } }; @@ -106,23 +106,23 @@ class TraceContext { public: ~TraceContext() = default; - explicit TraceContext(const LocationPtr& loc) { + explicit TraceContext(const LocationPtr &loc) { ProcessAttributeFromContext(); location_ = loc; } - explicit TraceContext(const std::string& func_name) { + explicit TraceContext(const std::string &func_name) { ProcessAttributeFromContext(); func_name_ = func_name; } - explicit TraceContext(const TraceInfoPtr& trace_info) { + explicit TraceContext(const TraceInfoPtr &trace_info) { ProcessAttributeFromContext(); trace_info_ = trace_info; } - void set_location(const LocationPtr& loc) { location_ = loc; } + void set_location(const LocationPtr &loc) { location_ = loc; } LocationPtr location() { return location_; } - void set_trace_info(const TraceInfoPtr& trace_info) { trace_info_ = trace_info; } + void set_trace_info(const TraceInfoPtr &trace_info) { trace_info_ = trace_info; } TraceInfoPtr trace_info() { return trace_info_; } - void set_func_name(const std::string& func_name) { func_name_ = func_name; } + void set_func_name(const std::string &func_name) { func_name_ = func_name; } std::string func_name() { return func_name_; } }; @@ -130,9 +130,9 @@ class DebugInfo : public Base { public: DebugInfo(); - explicit DebugInfo(const std::string& name); + explicit DebugInfo(const std::string &name); - explicit DebugInfo(const LocationPtr& loc); + explicit DebugInfo(const LocationPtr &loc); virtual ~DebugInfo() = default; MS_DECLARE_PARENT(DebugInfo, Base); @@ -141,12 +141,12 @@ class DebugInfo : public Base { int64_t unique_id_through_copy() const; std::string get_id() { return std::to_string(debug_id()); } - void set_trace_info(const TraceInfoPtr& trace_info) { trace_info_ = trace_info; } + void set_trace_info(const TraceInfoPtr &trace_info) { trace_info_ = trace_info; } TraceInfoPtr trace_info() { return trace_info_; } - void set_location(const LocationPtr& loc) { location_ = loc; } + void set_location(const LocationPtr &loc) { location_ = loc; } virtual LocationPtr location() { return location_; } std::string name() { return name_; } - void set_name(const std::string& name) { name_ = name; } + void set_name(const std::string &name) { name_ = name; } virtual std::string debug_name(); virtual std::string get_python_func_belonged() { return ""; } @@ -186,7 +186,7 @@ class NodeDebugInfo : public DebugInfo { py_func_belonged_ = context_info->func_name(); } } - explicit NodeDebugInfo(const std::string& name) : DebugInfo(name) { + explicit NodeDebugInfo(const std::string &name) : DebugInfo(name) { if (TraceManager::CurrentContextInfo() != nullptr) { auto context_info = TraceManager::CurrentContextInfo(); py_func_belonged_ = context_info->func_name(); @@ -195,9 +195,9 @@ class NodeDebugInfo : public DebugInfo { ~NodeDebugInfo() override = default; std::string debug_name() override; - void set_node(const std::shared_ptr& node) { node_ = AnfNodeWeakPtr(node); } + void set_node(const std::shared_ptr &node) { node_ = AnfNodeWeakPtr(node); } std::shared_ptr get_node() const { return node_.lock(); } - void set_py_func_belonged(const std::string& name) { py_func_belonged_ = name; } + void set_py_func_belonged(const std::string &name) { py_func_belonged_ = name; } std::string get_python_func_belonged() override { return py_func_belonged_; } AnfNodeWeakPtr node_; std::string py_func_belonged_; @@ -214,7 +214,7 @@ class GraphDebugInfo : public DebugInfo { } } - explicit GraphDebugInfo(const std::string& name) : DebugInfo(name) { + explicit GraphDebugInfo(const std::string &name) : DebugInfo(name) { if (TraceManager::CurrentContextInfo() != nullptr) { auto context_info = TraceManager::CurrentContextInfo(); py_func_name_ = context_info->func_name(); @@ -225,11 +225,11 @@ class GraphDebugInfo : public DebugInfo { std::string debug_name() override; LocationPtr location() override; LocationPtr deco_location() { return deco_loc_; } - void set_graph(const FuncGraphPtr& func_graph) { func_graph_ = FuncGraphWeakPtr(func_graph); } + void set_graph(const FuncGraphPtr &func_graph) { func_graph_ = FuncGraphWeakPtr(func_graph); } FuncGraphPtr get_graph() const { return func_graph_.lock(); } - void set_full_name(const std::string& name) { full_name_ = name; } + void set_full_name(const std::string &name) { full_name_ = name; } std::string get_full_name() { return full_name_; } - void set_deco_location(const LocationPtr& deco_list_loc); + void set_deco_location(const LocationPtr &deco_list_loc); std::string get_python_func_belonged() override { return py_func_name_; } FuncGraphWeakPtr func_graph_; LocationPtr deco_loc_; diff --git a/mindspore/ccsrc/debug/label.cc b/mindspore/ccsrc/debug/label.cc index f0e16e831e..d8c4986482 100644 --- a/mindspore/ccsrc/debug/label.cc +++ b/mindspore/ccsrc/debug/label.cc @@ -31,7 +31,7 @@ struct NameWithTrace { std::string name; std::vector trace_labels; }; -static std::string GetTraceName(const TraceInfoPtr& trace_info, TraceLabelType trace_label) { +static std::string GetTraceName(const TraceInfoPtr &trace_info, TraceLabelType trace_label) { switch (trace_label) { case TraceLabelType::kShortSymbol: return trace_info->symbol(); @@ -42,7 +42,7 @@ static std::string GetTraceName(const TraceInfoPtr& trace_info, TraceLabelType t } } -NameWithTrace RootName(const DebugInfoPtr& debug_info, TraceLabelType trace_label) { +NameWithTrace RootName(const DebugInfoPtr &debug_info, TraceLabelType trace_label) { NameWithTrace trace_name; // find debug info after Resolve/ExpandJ/GenMetaFuncGraph, it is a new node auto temp_info = debug_info; @@ -66,9 +66,9 @@ NameWithTrace RootName(const DebugInfoPtr& debug_info, TraceLabelType trace_labe return trace_name; } -std::string CombineTraceTypes(const std::string& root_name, const std::vector& trace_labels) { +std::string CombineTraceTypes(const std::string &root_name, const std::vector &trace_labels) { std::string tags = ""; - for (auto& itr : trace_labels) { + for (auto &itr : trace_labels) { std::string symbol = itr; tags = tags + symbol; } @@ -76,12 +76,12 @@ std::string CombineTraceTypes(const std::string& root_name, const std::vector GetSourceCodeDebugInfoVec(DebugInfoPtr debug_info) { return debug_with_loc_vec; } -DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr& info) { +DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr &info) { auto debug_with_loc_vec = GetSourceCodeDebugInfoVec(info); if (debug_with_loc_vec.size() > 0) { return debug_with_loc_vec[0]; @@ -78,7 +78,7 @@ DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr& info) { } } -std::string GetDebugInfo(const DebugInfoPtr& info, SourceLineTip tip) { +std::string GetDebugInfo(const DebugInfoPtr &info, SourceLineTip tip) { if (info == nullptr) { return ""; } @@ -91,7 +91,7 @@ std::string GetDebugInfo(const DebugInfoPtr& info, SourceLineTip tip) { // a trace info identifies a node transform, so we can trace the node transform through // a link of trace info and debug info -std::string GetInfoWithAction(const std::vector& info_vec, SourceLineTip tip) { +std::string GetInfoWithAction(const std::vector &info_vec, SourceLineTip tip) { if (info_vec.size() < 1) { return ""; } @@ -109,7 +109,7 @@ std::string GetInfoWithAction(const std::vector& info_vec, SourceL return traced_info; } -std::string GetTracedDebugInfo(const DebugInfoPtr& info, SourceLineTip tip) { +std::string GetTracedDebugInfo(const DebugInfoPtr &info, SourceLineTip tip) { if (info == nullptr) { return ""; } @@ -124,7 +124,7 @@ std::string GetTracedDebugInfo(const DebugInfoPtr& info, SourceLineTip tip) { return ""; } -std::string GetDebugInfo(const DebugInfoPtr& info, const std::string& prefix, SourceLineTip tip) { +std::string GetDebugInfo(const DebugInfoPtr &info, const std::string &prefix, SourceLineTip tip) { std::ostringstream oss; if (info == nullptr) { return ""; @@ -139,7 +139,7 @@ std::string GetDebugInfo(const DebugInfoPtr& info, const std::string& prefix, So return oss.str(); } -std::string GetGraphParamString(const FuncGraphPtr& graph, abstract::AbstractBasePtrList args_spec_list) { +std::string GetGraphParamString(const FuncGraphPtr &graph, abstract::AbstractBasePtrList args_spec_list) { std::ostringstream oss; oss << "graph:" << graph->ToString() << " with args["; auto params = graph->parameters(); @@ -151,8 +151,8 @@ std::string GetGraphParamString(const FuncGraphPtr& graph, abstract::AbstractBas return oss.str(); } -void DumpInferStack(std::ostringstream& oss) { - auto& infer_stack = GetCurrenGraphInferStack(); +void DumpInferStack(std::ostringstream &oss) { + auto &infer_stack = GetCurrenGraphInferStack(); if (infer_stack.empty()) { return; } @@ -164,7 +164,7 @@ void DumpInferStack(std::ostringstream& oss) { } std::reverse(infer_vec.begin(), infer_vec.end()); int index = 0; - for (auto& item : infer_vec) { + for (auto &item : infer_vec) { auto graph_infer = std::dynamic_pointer_cast(item.first); if (graph_infer == nullptr) { MS_LOG(WARNING) << "DumpInferStack failed, got null graph evaluator"; @@ -183,7 +183,7 @@ void DumpInferStack(std::ostringstream& oss) { } void TraceGraphInfer() { - auto& infer_stack = GetCurrenGraphInferStack(); + auto &infer_stack = GetCurrenGraphInferStack(); std::ostringstream oss; if (infer_stack.empty()) { return; @@ -200,15 +200,15 @@ class AnalyzedFuncGraphExporter : public AnfExporter { AnalyzedFuncGraphExporter() : AnfExporter("", true, false) {} ~AnalyzedFuncGraphExporter() override = default; - void ExportFuncGraph(const std::string& filename, const std::vector& node_cfgs); + void ExportFuncGraph(const std::string &filename, const std::vector &node_cfgs); private: - std::string GetNodeType(const AnfNodePtr& nd) override; + std::string GetNodeType(const AnfNodePtr &nd) override; }; std::unordered_map CalcTaggedFuncGraphs() { std::unordered_map tagged_func_graphs; - auto& list = GetCNodeDebugStack(); + auto &list = GetCNodeDebugStack(); for (size_t i = 0; i < list.size(); ++i) { auto node_cfg = list[i]; auto fg = node_cfg->context()->func_graph(); @@ -223,7 +223,7 @@ void OutputAnalyzedGraphWithType() { exporter.ExportFuncGraph("analyze_fail.dat", GetCNodeDebugStack()); } -std::string AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr& node) { +std::string AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr &node) { if (node_cfg_ == nullptr) { return AnfExporter::GetNodeType(node); } @@ -248,8 +248,8 @@ std::string AnalyzedFuncGraphExporter::GetNodeType(const AnfNodePtr& node) { return oss.str(); } -void AnalyzedFuncGraphExporter::ExportFuncGraph(const std::string& filename, - const std::vector& node_cfgs) { +void AnalyzedFuncGraphExporter::ExportFuncGraph(const std::string &filename, + const std::vector &node_cfgs) { if (node_cfgs.empty()) { MS_LOG(DEBUG) << "Node configs is empty"; return; @@ -265,7 +265,7 @@ void AnalyzedFuncGraphExporter::ExportFuncGraph(const std::string& filename, auto tagged_func_graphs = CalcTaggedFuncGraphs(); // first output graph on the analysis stack - for (const auto& node_cfg : node_cfgs) { + for (const auto &node_cfg : node_cfgs) { auto fg = node_cfg->context()->func_graph(); // the graph is already output, skip it if (exported.find(fg) != exported.end()) { @@ -296,7 +296,7 @@ void AnalyzedFuncGraphExporter::ExportFuncGraph(const std::string& filename, ofs.close(); } -void GetInferStackInfo(std::ostringstream& oss) { +void GetInferStackInfo(std::ostringstream &oss) { MS_LOG(INFO) << "Get graph analysis information begin"; auto stack = GetCNodeDebugStack(); if (stack.empty()) { @@ -336,7 +336,7 @@ void GetInferStackInfo(std::ostringstream& oss) { static std::stack> graph_infer_stack; // trace the cnode infer debug info static std::vector cnode_debug_stack{}; -void TraceGraphInferEnter(const abstract::EvaluatorPtr& eval, const abstract::AnfNodeConfigPtr& node) { +void TraceGraphInferEnter(const abstract::EvaluatorPtr &eval, const abstract::AnfNodeConfigPtr &node) { if (eval == nullptr) { MS_LOG(EXCEPTION) << "GraphInferEnter got null eval"; } @@ -345,7 +345,7 @@ void TraceGraphInferEnter(const abstract::EvaluatorPtr& eval, const abstract::An } } -void TraceGraphInferLeave(const abstract::EvaluatorPtr& eval) { +void TraceGraphInferLeave(const abstract::EvaluatorPtr &eval) { if (eval == nullptr) { MS_LOG(EXCEPTION) << "GraphInferEnter got null eval"; } @@ -354,13 +354,13 @@ void TraceGraphInferLeave(const abstract::EvaluatorPtr& eval) { } } -void TraceInferCNodeEnter(const abstract::AnfNodeConfigPtr& node_cfg) { cnode_debug_stack.push_back(node_cfg); } +void TraceInferCNodeEnter(const abstract::AnfNodeConfigPtr &node_cfg) { cnode_debug_stack.push_back(node_cfg); } void TraceInferCNodeLeave() { cnode_debug_stack.pop_back(); } -std::vector& GetCNodeDebugStack() { return cnode_debug_stack; } +std::vector &GetCNodeDebugStack() { return cnode_debug_stack; } -std::stack>& GetCurrenGraphInferStack() { +std::stack> &GetCurrenGraphInferStack() { return graph_infer_stack; } void ClearTraceStack() { diff --git a/mindspore/ccsrc/debug/trace.h b/mindspore/ccsrc/debug/trace.h index 5fba86fddd..2704a80a35 100644 --- a/mindspore/ccsrc/debug/trace.h +++ b/mindspore/ccsrc/debug/trace.h @@ -31,19 +31,19 @@ namespace mindspore { namespace trace { -std::string GetDebugInfo(const DebugInfoPtr& info, SourceLineTip tip = kSourceLineTipNextLine); -std::string GetDebugInfo(const DebugInfoPtr& info, const std::string& prefix, +std::string GetDebugInfo(const DebugInfoPtr &info, SourceLineTip tip = kSourceLineTipNextLine); +std::string GetDebugInfo(const DebugInfoPtr &info, const std::string &prefix, SourceLineTip tip = kSourceLineTipNextLine); -DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr& info); +DebugInfoPtr GetSourceCodeDebugInfo(const DebugInfoPtr &info); void TraceGraphInfer(); -void GetInferStackInfo(std::ostringstream& oss); -void TraceGraphInferEnter(const abstract::EvaluatorPtr& eval, const abstract::AnfNodeConfigPtr& node); -void TraceGraphInferLeave(const abstract::EvaluatorPtr& eval); -void TraceInferCNodeEnter(const abstract::AnfNodeConfigPtr& node_cfg); +void GetInferStackInfo(std::ostringstream &oss); +void TraceGraphInferEnter(const abstract::EvaluatorPtr &eval, const abstract::AnfNodeConfigPtr &node); +void TraceGraphInferLeave(const abstract::EvaluatorPtr &eval); +void TraceInferCNodeEnter(const abstract::AnfNodeConfigPtr &node_cfg); void TraceInferCNodeLeave(); -std::vector& GetCNodeDebugStack(); -std::stack>& GetCurrenGraphInferStack(); -std::string GetAbstractStr(const abstract::AbstractBasePtr& abs); +std::vector &GetCNodeDebugStack(); +std::stack> &GetCurrenGraphInferStack(); +std::string GetAbstractStr(const abstract::AbstractBasePtr &abs); void ClearTraceStack(); } // namespace trace } // namespace mindspore diff --git a/mindspore/ccsrc/debug/trace_info.cc b/mindspore/ccsrc/debug/trace_info.cc index b01cd15010..19358e197a 100644 --- a/mindspore/ccsrc/debug/trace_info.cc +++ b/mindspore/ccsrc/debug/trace_info.cc @@ -23,7 +23,7 @@ #include "pipeline/parse/python_adapter.h" namespace mindspore { -std::string TraceInfo::GetActionBetweenNode(const DebugInfoPtr& info) { +std::string TraceInfo::GetActionBetweenNode(const DebugInfoPtr &info) { if (info == nullptr) { return ""; } diff --git a/mindspore/ccsrc/debug/trace_info.h b/mindspore/ccsrc/debug/trace_info.h index 16be9031e2..e7a8c83dad 100644 --- a/mindspore/ccsrc/debug/trace_info.h +++ b/mindspore/ccsrc/debug/trace_info.h @@ -40,13 +40,13 @@ using DebugInfoPtr = std::shared_ptr; // namespace to support intermediate representation definition class TraceInfo : public Base { public: - TraceInfo(const DebugInfoPtr& info, const std::string& full_name, const std::string& symbol) { + TraceInfo(const DebugInfoPtr &info, const std::string &full_name, const std::string &symbol) { symbol_ = symbol; full_name_ = full_name; name_ = full_name_; debug_info_ = info; } - TraceInfo(const TraceInfo& info) + TraceInfo(const TraceInfo &info) : Base(), debug_info_(info.debug_info_), symbol_(info.symbol_), full_name_(info.full_name_), name_(info.name_) {} virtual ~TraceInfo() = default; MS_DECLARE_PARENT(TraceInfo, Base); @@ -55,8 +55,8 @@ class TraceInfo : public Base { virtual std::string full_name() { return full_name_; } virtual TraceInfoPtr clone() { return shared_from_base(); } virtual std::string action_name() { return ""; } - virtual std::string GetActionBetweenNode(const DebugInfoPtr& info); - void set_debug_info(const DebugInfoPtr& info) { debug_info_ = info; } + virtual std::string GetActionBetweenNode(const DebugInfoPtr &info); + void set_debug_info(const DebugInfoPtr &info) { debug_info_ = info; } DebugInfoPtr debug_info() { return debug_info_; } DebugInfoPtr DebugInfoHasLoc(); std::vector> GetSourceCodeDebugInfo(); @@ -70,7 +70,7 @@ class TraceInfo : public Base { class TracePhi : public TraceInfo { public: - explicit TracePhi(const DebugInfoPtr& info) : TraceInfo(info, "phi", "Φ") {} + explicit TracePhi(const DebugInfoPtr &info) : TraceInfo(info, "phi", "Φ") {} MS_DECLARE_PARENT(TracePhi, TraceInfo); ~TracePhi() override = default; TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } @@ -78,8 +78,8 @@ class TracePhi : public TraceInfo { class TraceIfStmtTrueBranch : public TraceInfo { public: - TraceIfStmtTrueBranch(const TraceIfStmtTrueBranch&) = default; - explicit TraceIfStmtTrueBranch(const DebugInfoPtr& info) : TraceInfo(info, "if_true", "✓") {} + TraceIfStmtTrueBranch(const TraceIfStmtTrueBranch &) = default; + explicit TraceIfStmtTrueBranch(const DebugInfoPtr &info) : TraceInfo(info, "if_true", "✓") {} MS_DECLARE_PARENT(TraceIfStmtTrueBranch, TraceInfo); ~TraceIfStmtTrueBranch() override = default; TraceInfoPtr clone() override { @@ -89,8 +89,8 @@ class TraceIfStmtTrueBranch : public TraceInfo { class TraceIfStmtFalseBranch : public TraceInfo { public: - TraceIfStmtFalseBranch(const TraceIfStmtFalseBranch&) = default; - explicit TraceIfStmtFalseBranch(const DebugInfoPtr& info) : TraceInfo(info, "if_false", "✗") {} + TraceIfStmtFalseBranch(const TraceIfStmtFalseBranch &) = default; + explicit TraceIfStmtFalseBranch(const DebugInfoPtr &info) : TraceInfo(info, "if_false", "✗") {} MS_DECLARE_PARENT(TraceIfStmtFalseBranch, TraceInfo); ~TraceIfStmtFalseBranch() override = default; TraceInfoPtr clone() override { @@ -100,7 +100,7 @@ class TraceIfStmtFalseBranch : public TraceInfo { class TraceIfStmtAfterBranch : public TraceInfo { public: - explicit TraceIfStmtAfterBranch(const DebugInfoPtr& info) : TraceInfo(info, "if_after", "↓") {} + explicit TraceIfStmtAfterBranch(const DebugInfoPtr &info) : TraceInfo(info, "if_after", "↓") {} MS_DECLARE_PARENT(TraceIfStmtAfterBranch, TraceInfo); ~TraceIfStmtAfterBranch() override = default; TraceInfoPtr clone() override { @@ -110,7 +110,7 @@ class TraceIfStmtAfterBranch : public TraceInfo { class TraceIfExpTrueBranch : public TraceInfo { public: - explicit TraceIfExpTrueBranch(const DebugInfoPtr& info) : TraceInfo(info, "ifexp_true", "↰") {} + explicit TraceIfExpTrueBranch(const DebugInfoPtr &info) : TraceInfo(info, "ifexp_true", "↰") {} MS_DECLARE_PARENT(TraceIfExpTrueBranch, TraceInfo); ~TraceIfExpTrueBranch() override = default; TraceInfoPtr clone() override { @@ -120,7 +120,7 @@ class TraceIfExpTrueBranch : public TraceInfo { class TraceIfExpFalseBranch : public TraceInfo { public: - explicit TraceIfExpFalseBranch(const DebugInfoPtr& info) : TraceInfo(info, "ifexp_false", "↱") {} + explicit TraceIfExpFalseBranch(const DebugInfoPtr &info) : TraceInfo(info, "ifexp_false", "↱") {} MS_DECLARE_PARENT(TraceIfExpFalseBranch, TraceInfo); ~TraceIfExpFalseBranch() override = default; TraceInfoPtr clone() override { @@ -131,7 +131,7 @@ class TraceIfExpFalseBranch : public TraceInfo { class TraceCopy : public TraceInfo { public: TraceCopy() : TraceInfo(nullptr, "copy", "") {} - explicit TraceCopy(const DebugInfoPtr& info) : TraceInfo(info, "copy", "") {} + explicit TraceCopy(const DebugInfoPtr &info) : TraceInfo(info, "copy", "") {} MS_DECLARE_PARENT(TraceCopy, TraceInfo); ~TraceCopy() override = default; TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } @@ -139,7 +139,7 @@ class TraceCopy : public TraceInfo { class TraceIterator : public TraceInfo { public: - explicit TraceIterator(const DebugInfoPtr& info) : TraceInfo(info, "iterator", "@") {} + explicit TraceIterator(const DebugInfoPtr &info) : TraceInfo(info, "iterator", "@") {} MS_DECLARE_PARENT(TraceIterator, TraceInfo); ~TraceIterator() override = default; TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } @@ -147,7 +147,7 @@ class TraceIterator : public TraceInfo { class TraceWhileHeader : public TraceInfo { public: - explicit TraceWhileHeader(const DebugInfoPtr& info) : TraceInfo(info, "while_header", "⤾") {} + explicit TraceWhileHeader(const DebugInfoPtr &info) : TraceInfo(info, "while_header", "⤾") {} MS_DECLARE_PARENT(TraceWhileHeader, TraceInfo); ~TraceWhileHeader() override = default; TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } @@ -155,7 +155,7 @@ class TraceWhileHeader : public TraceInfo { class TraceWhileBody : public TraceInfo { public: - explicit TraceWhileBody(const DebugInfoPtr& info) : TraceInfo(info, "while_body", "⥁") {} + explicit TraceWhileBody(const DebugInfoPtr &info) : TraceInfo(info, "while_body", "⥁") {} MS_DECLARE_PARENT(TraceWhileBody, TraceInfo); ~TraceWhileBody() override = default; TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } @@ -163,7 +163,7 @@ class TraceWhileBody : public TraceInfo { class TraceWhileAfter : public TraceInfo { public: - explicit TraceWhileAfter(const DebugInfoPtr& info) : TraceInfo(info, "while_after", "↓") {} + explicit TraceWhileAfter(const DebugInfoPtr &info) : TraceInfo(info, "while_after", "↓") {} MS_DECLARE_PARENT(TraceWhileAfter, TraceInfo); ~TraceWhileAfter() override = default; TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } @@ -171,7 +171,7 @@ class TraceWhileAfter : public TraceInfo { class TraceForHeader : public TraceInfo { public: - explicit TraceForHeader(const DebugInfoPtr& info) : TraceInfo(info, "for_header", "⤾") {} + explicit TraceForHeader(const DebugInfoPtr &info) : TraceInfo(info, "for_header", "⤾") {} MS_DECLARE_PARENT(TraceForHeader, TraceInfo); ~TraceForHeader() override = default; TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } @@ -179,7 +179,7 @@ class TraceForHeader : public TraceInfo { class TraceForBody : public TraceInfo { public: - explicit TraceForBody(const DebugInfoPtr& info) : TraceInfo(info, "for_body", "⥁") {} + explicit TraceForBody(const DebugInfoPtr &info) : TraceInfo(info, "for_body", "⥁") {} MS_DECLARE_PARENT(TraceForBody, TraceInfo); ~TraceForBody() override = default; TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } @@ -187,7 +187,7 @@ class TraceForBody : public TraceInfo { class TraceForAfter : public TraceInfo { public: - explicit TraceForAfter(const DebugInfoPtr& info) : TraceInfo(info, "for_after", "↓") {} + explicit TraceForAfter(const DebugInfoPtr &info) : TraceInfo(info, "for_after", "↓") {} MS_DECLARE_PARENT(TraceForAfter, TraceInfo); ~TraceForAfter() override = default; TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } @@ -195,7 +195,7 @@ class TraceForAfter : public TraceInfo { class TraceEquiv : public TraceInfo { public: - explicit TraceEquiv(const DebugInfoPtr& info) : TraceInfo(info, "equiv", "equiv") {} + explicit TraceEquiv(const DebugInfoPtr &info) : TraceInfo(info, "equiv", "equiv") {} MS_DECLARE_PARENT(TraceEquiv, TraceInfo); ~TraceEquiv() override = default; TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } @@ -204,7 +204,7 @@ class TraceEquiv : public TraceInfo { class TraceGradFpropApp : public TraceInfo { public: TraceGradFpropApp() : TraceInfo(nullptr, "grad_fprop_app", "▲") {} - explicit TraceGradFpropApp(const DebugInfoPtr& info) : TraceInfo(info, "grad_fprop_app", "▲") {} + explicit TraceGradFpropApp(const DebugInfoPtr &info) : TraceInfo(info, "grad_fprop_app", "▲") {} MS_DECLARE_PARENT(TraceGradFpropApp, TraceInfo); ~TraceGradFpropApp() override = default; TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } @@ -213,7 +213,7 @@ class TraceGradFpropApp : public TraceInfo { class TraceGradBpropApp : public TraceInfo { public: TraceGradBpropApp() : TraceInfo(nullptr, "grad_bprop_app", "▼") {} - explicit TraceGradBpropApp(const DebugInfoPtr& info) : TraceInfo(info, "grad_bprop_app", "▼") {} + explicit TraceGradBpropApp(const DebugInfoPtr &info) : TraceInfo(info, "grad_bprop_app", "▼") {} MS_DECLARE_PARENT(TraceGradBpropApp, TraceInfo); ~TraceGradBpropApp() override = default; TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } @@ -222,7 +222,7 @@ class TraceGradBpropApp : public TraceInfo { class TraceGradFprop : public TraceInfo { public: TraceGradFprop() : TraceInfo(nullptr, "grad_fprop", "▶") {} - explicit TraceGradFprop(const DebugInfoPtr& info) : TraceInfo(info, "grad_fprop", "▶") {} + explicit TraceGradFprop(const DebugInfoPtr &info) : TraceInfo(info, "grad_fprop", "▶") {} MS_DECLARE_PARENT(TraceGradFprop, TraceInfo); ~TraceGradFprop() override = default; TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } @@ -231,7 +231,7 @@ class TraceGradFprop : public TraceInfo { class TraceGradBprop : public TraceInfo { public: TraceGradBprop() : TraceInfo(nullptr, "grad_bprop", "◀") {} - explicit TraceGradBprop(const DebugInfoPtr& info) : TraceInfo(info, "grad_bprop", "◀") {} + explicit TraceGradBprop(const DebugInfoPtr &info) : TraceInfo(info, "grad_bprop", "◀") {} MS_DECLARE_PARENT(TraceGradBprop, TraceInfo); ~TraceGradBprop() override = default; TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } @@ -240,7 +240,7 @@ class TraceGradBprop : public TraceInfo { class TraceGradSens : public TraceInfo { public: TraceGradSens() : TraceInfo(nullptr, "grad_sens", "∇") {} - explicit TraceGradSens(const DebugInfoPtr& info) : TraceInfo(info, "grad_sens", "∇") {} + explicit TraceGradSens(const DebugInfoPtr &info) : TraceInfo(info, "grad_sens", "∇") {} MS_DECLARE_PARENT(TraceGradSens, TraceInfo); ~TraceGradSens() override = default; TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } @@ -248,7 +248,7 @@ class TraceGradSens : public TraceInfo { class TraceSpecialize : public TraceInfo { public: - explicit TraceSpecialize(const std::string& counter) : TraceInfo(nullptr, "specialize", "") { counter_ = counter; } + explicit TraceSpecialize(const std::string &counter) : TraceInfo(nullptr, "specialize", "") { counter_ = counter; } MS_DECLARE_PARENT(TraceSpecialize, TraceInfo); std::string name() override { return full_name_ + counter_; } std::string symbol() override { return counter_ + "_"; } @@ -260,7 +260,7 @@ class TraceSpecialize : public TraceInfo { class TraceGradOperation : public TraceInfo { public: - explicit TraceGradOperation(const DebugInfoPtr& info) : TraceInfo(info, "grad_ops", "") {} + explicit TraceGradOperation(const DebugInfoPtr &info) : TraceInfo(info, "grad_ops", "") {} MS_DECLARE_PARENT(TraceGradOperation, TraceInfo); ~TraceGradOperation() override = default; TraceInfoPtr clone() override { @@ -270,7 +270,7 @@ class TraceGradOperation : public TraceInfo { class TraceForceBool : public TraceInfo { public: - explicit TraceForceBool(const DebugInfoPtr& info) : TraceInfo(info, "force_bool", "") {} + explicit TraceForceBool(const DebugInfoPtr &info) : TraceInfo(info, "force_bool", "") {} MS_DECLARE_PARENT(TraceForceBool, TraceInfo); ~TraceForceBool() override = default; TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } @@ -278,7 +278,7 @@ class TraceForceBool : public TraceInfo { class TraceExpandJ : public TraceInfo { public: - explicit TraceExpandJ(const DebugInfoPtr& info) : TraceInfo(info, "expand_j", "") {} + explicit TraceExpandJ(const DebugInfoPtr &info) : TraceInfo(info, "expand_j", "") {} MS_DECLARE_PARENT(TraceExpandJ, TraceInfo); ~TraceExpandJ() override = default; TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } @@ -286,7 +286,7 @@ class TraceExpandJ : public TraceInfo { class TraceGenMetaFuncGraph : public TraceInfo { public: - explicit TraceGenMetaFuncGraph(const DebugInfoPtr& info) : TraceInfo(info, "GenMetaFuncGraph", "") {} + explicit TraceGenMetaFuncGraph(const DebugInfoPtr &info) : TraceInfo(info, "GenMetaFuncGraph", "") {} MS_DECLARE_PARENT(TraceGenMetaFuncGraph, TraceInfo); ~TraceGenMetaFuncGraph() override = default; TraceInfoPtr clone() override { @@ -296,7 +296,7 @@ class TraceGenMetaFuncGraph : public TraceInfo { class TraceEvaluatorGenGraph : public TraceInfo { public: - explicit TraceEvaluatorGenGraph(const DebugInfoPtr& info) : TraceInfo(info, "GenEvaluatorGraph", "") {} + explicit TraceEvaluatorGenGraph(const DebugInfoPtr &info) : TraceInfo(info, "GenEvaluatorGraph", "") {} MS_DECLARE_PARENT(TraceEvaluatorGenGraph, TraceInfo); ~TraceEvaluatorGenGraph() override = default; TraceInfoPtr clone() override { @@ -306,7 +306,7 @@ class TraceEvaluatorGenGraph : public TraceInfo { class TraceResolve : public TraceInfo { public: - explicit TraceResolve(const DebugInfoPtr& info) : TraceInfo(info, "resolve", "") {} + explicit TraceResolve(const DebugInfoPtr &info) : TraceInfo(info, "resolve", "") {} MS_DECLARE_PARENT(TraceResolve, TraceInfo); ~TraceResolve() override = default; TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } @@ -315,7 +315,7 @@ class TraceResolve : public TraceInfo { class TraceTransform : public TraceInfo { public: TraceTransform() : TraceInfo(nullptr, "transform", "") { transform_name_ = ""; } - explicit TraceTransform(const std::string& transform_name) : TraceInfo(nullptr, "transform", "") { + explicit TraceTransform(const std::string &transform_name) : TraceInfo(nullptr, "transform", "") { transform_name_ = transform_name; } @@ -335,7 +335,7 @@ class TraceTransform : public TraceInfo { class TraceGenerateVarArg : public TraceInfo { public: - explicit TraceGenerateVarArg(const DebugInfoPtr& info) : TraceInfo(info, "GenerateVarArg", "") {} + explicit TraceGenerateVarArg(const DebugInfoPtr &info) : TraceInfo(info, "GenerateVarArg", "") {} MS_DECLARE_PARENT(TraceGenerateVarArg, TraceInfo); ~TraceGenerateVarArg() override = default; TraceInfoPtr clone() override { @@ -345,7 +345,7 @@ class TraceGenerateVarArg : public TraceInfo { class TraceGenerateKwArg : public TraceInfo { public: - explicit TraceGenerateKwArg(const DebugInfoPtr& info) : TraceInfo(info, "GenerateKwArg", "") {} + explicit TraceGenerateKwArg(const DebugInfoPtr &info) : TraceInfo(info, "GenerateKwArg", "") {} MS_DECLARE_PARENT(TraceGenerateKwArg, TraceInfo); ~TraceGenerateKwArg() override = default; TraceInfoPtr clone() override { @@ -355,7 +355,7 @@ class TraceGenerateKwArg : public TraceInfo { class TraceTrasformK : public TraceInfo { public: - explicit TraceTrasformK(const DebugInfoPtr& info) : TraceInfo(info, "TraceTrasformK", "") {} + explicit TraceTrasformK(const DebugInfoPtr &info) : TraceInfo(info, "TraceTrasformK", "") {} MS_DECLARE_PARENT(TraceTrasformK, TraceInfo); ~TraceTrasformK() override = default; TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } @@ -363,7 +363,7 @@ class TraceTrasformK : public TraceInfo { class TracePartialTransform : public TraceInfo { public: - explicit TracePartialTransform(const DebugInfoPtr& info) : TraceInfo(info, "PartialTransform", "") {} + explicit TracePartialTransform(const DebugInfoPtr &info) : TraceInfo(info, "PartialTransform", "") {} MS_DECLARE_PARENT(TracePartialTransform, TraceInfo); ~TracePartialTransform() override = default; TraceInfoPtr clone() override { @@ -373,7 +373,7 @@ class TracePartialTransform : public TraceInfo { class TraceGetEnv : public TraceInfo { public: - explicit TraceGetEnv(const DebugInfoPtr& info) : TraceInfo(info, "get_env", "") {} + explicit TraceGetEnv(const DebugInfoPtr &info) : TraceInfo(info, "get_env", "") {} MS_DECLARE_PARENT(TraceGetEnv, TraceInfo); ~TraceGetEnv() override = default; TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } @@ -381,7 +381,7 @@ class TraceGetEnv : public TraceInfo { class TraceDoSignature : public TraceInfo { public: - explicit TraceDoSignature(const DebugInfoPtr& info) : TraceInfo(info, "DoSignature", "") {} + explicit TraceDoSignature(const DebugInfoPtr &info) : TraceInfo(info, "DoSignature", "") {} MS_DECLARE_PARENT(TraceDoSignature, TraceInfo); ~TraceDoSignature() override = default; TraceInfoPtr clone() override { return std::make_shared(*shared_from_base()); } @@ -390,7 +390,7 @@ class TraceDoSignature : public TraceInfo { class TraceCombileLikeGraphs : public TraceInfo { public: TraceCombileLikeGraphs() : TraceInfo(nullptr, "CombileLike", "L-") {} - explicit TraceCombileLikeGraphs(const DebugInfoPtr& info) : TraceInfo(info, "CombileLike", "L-") {} + explicit TraceCombileLikeGraphs(const DebugInfoPtr &info) : TraceInfo(info, "CombileLike", "L-") {} MS_DECLARE_PARENT(TraceCombileLikeGraphs, TraceInfo); ~TraceCombileLikeGraphs() override = default; TraceInfoPtr clone() override { diff --git a/mindspore/ccsrc/device/ascend/ascend_memory_pool.cc b/mindspore/ccsrc/device/ascend/ascend_memory_pool.cc index 2c38e4290d..69c6dca576 100644 --- a/mindspore/ccsrc/device/ascend/ascend_memory_pool.cc +++ b/mindspore/ccsrc/device/ascend/ascend_memory_pool.cc @@ -21,7 +21,7 @@ namespace mindspore { namespace device { namespace ascend { -size_t AscendMemoryPool::AllocDeviceMem(size_t size, DeviceMemPtr* addr) { +size_t AscendMemoryPool::AllocDeviceMem(size_t size, DeviceMemPtr *addr) { if (has_malloc_) { MS_LOG(EXCEPTION) << "Has alloc memory pool memory !"; } @@ -37,7 +37,7 @@ size_t AscendMemoryPool::AllocDeviceMem(size_t size, DeviceMemPtr* addr) { return size; } -bool AscendMemoryPool::FreeDeviceMem(const DeviceMemPtr& addr) { +bool AscendMemoryPool::FreeDeviceMem(const DeviceMemPtr &addr) { MS_EXCEPTION_IF_NULL(addr); has_malloc_ = false; free_mem_size_ = total_mem_size_; @@ -53,7 +53,7 @@ size_t AscendMemoryPool::AlignMemorySize(size_t size) const { size_t AscendMemoryPool::mem_alloc_unit_size() const { return free_mem_size_ - 512; } -void AscendMemoryPool::set_device_mem_pool_base(uint8_t* device_mem_pool_base) { +void AscendMemoryPool::set_device_mem_pool_base(uint8_t *device_mem_pool_base) { MS_EXCEPTION_IF_NULL(device_mem_pool_base); device_mem_pool_base_ = device_mem_pool_base; } diff --git a/mindspore/ccsrc/device/ascend/ascend_memory_pool.h b/mindspore/ccsrc/device/ascend/ascend_memory_pool.h index a02bd453b2..7fa3ebc23e 100644 --- a/mindspore/ccsrc/device/ascend/ascend_memory_pool.h +++ b/mindspore/ccsrc/device/ascend/ascend_memory_pool.h @@ -26,12 +26,12 @@ namespace ascend { class AscendMemoryPool : public DynamicMemPoolBestFit { public: ~AscendMemoryPool() override = default; - AscendMemoryPool(const AscendMemoryPool&) = delete; - AscendMemoryPool& operator=(const AscendMemoryPool&) = delete; + AscendMemoryPool(const AscendMemoryPool &) = delete; + AscendMemoryPool &operator=(const AscendMemoryPool &) = delete; - size_t AllocDeviceMem(size_t size, DeviceMemPtr* addr) override; - bool FreeDeviceMem(const DeviceMemPtr& addr) override; - void set_device_mem_pool_base(uint8_t* device_mem_pool_base); + size_t AllocDeviceMem(size_t size, DeviceMemPtr *addr) override; + bool FreeDeviceMem(const DeviceMemPtr &addr) override; + void set_device_mem_pool_base(uint8_t *device_mem_pool_base); void set_device_mem_pool_size(uint64_t device_mem_pool_size) { device_mem_pool_size_ = device_mem_pool_size; free_mem_size_ = device_mem_pool_size_; @@ -40,7 +40,7 @@ class AscendMemoryPool : public DynamicMemPoolBestFit { size_t free_mem_size() override; size_t total_mem_size() override; - static AscendMemoryPool& GetInstance() { + static AscendMemoryPool &GetInstance() { static AscendMemoryPool instance; return instance; } @@ -54,7 +54,7 @@ class AscendMemoryPool : public DynamicMemPoolBestFit { private: AscendMemoryPool() = default; bool has_malloc_{false}; - uint8_t* device_mem_pool_base_{nullptr}; + uint8_t *device_mem_pool_base_{nullptr}; uint64_t device_mem_pool_size_{0}; size_t free_mem_size_{0}; size_t total_mem_size_{0}; diff --git a/mindspore/ccsrc/device/ascend/ascend_stream_assign.h b/mindspore/ccsrc/device/ascend/ascend_stream_assign.h index f7804a8ee7..9f4ea4d667 100755 --- a/mindspore/ccsrc/device/ascend/ascend_stream_assign.h +++ b/mindspore/ccsrc/device/ascend/ascend_stream_assign.h @@ -39,13 +39,13 @@ using std::vector; class AscendStreamAssign { public: - static AscendStreamAssign& GetInstance() { + static AscendStreamAssign &GetInstance() { static AscendStreamAssign instance; // Guaranteed to be destroyed. return instance; } - AscendStreamAssign(const AscendStreamAssign&) = delete; - AscendStreamAssign& operator=(const AscendStreamAssign&) = delete; + AscendStreamAssign(const AscendStreamAssign &) = delete; + AscendStreamAssign &operator=(const AscendStreamAssign &) = delete; uint32_t GetTotalStreamNum() const; // new stream policy @@ -53,19 +53,19 @@ class AscendStreamAssign { uint32_t total_independ_stream_num() const { return total_independ_stream_num_; } uint32_t total_event_num() const { return total_event_num_; } - void InsertActiveNew(const std::shared_ptr& graph_ptr); - void AssignAllNodesStream(const std::shared_ptr& graph_ptr); + void InsertActiveNew(const std::shared_ptr &graph_ptr); + void AssignAllNodesStream(const std::shared_ptr &graph_ptr); void ResetNew(); - void AssignStreamNew(const std::shared_ptr& graph_ptr); - bool IsIndependentNode(const CNodePtr& node_ptr); - const std::unordered_map& logic_to_independent_map() { return logic_to_independent_map_; } - const std::unordered_map& logic_to_physic_map() { return logic_to_physic_map_; } - const std::vector>& inner_parallel_streams() { return inner_parallel_streams_; } - void GetWaitStreams(vector* wait_active_stream_list); - const std::vector& hcom_streams() { return hcom_stream_list_; } - CNodePtr CreateSendApplyKernel(const std::shared_ptr& graph_ptr, uint32_t event_id, + void AssignStreamNew(const std::shared_ptr &graph_ptr); + bool IsIndependentNode(const CNodePtr &node_ptr); + const std::unordered_map &logic_to_independent_map() { return logic_to_independent_map_; } + const std::unordered_map &logic_to_physic_map() { return logic_to_physic_map_; } + const std::vector> &inner_parallel_streams() { return inner_parallel_streams_; } + void GetWaitStreams(vector *wait_active_stream_list); + const std::vector &hcom_streams() { return hcom_stream_list_; } + CNodePtr CreateSendApplyKernel(const std::shared_ptr &graph_ptr, uint32_t event_id, uint32_t stream_id); - CNodePtr CreateRecvApplyKernel(const std::shared_ptr& graph_ptr, uint32_t event_id, + CNodePtr CreateRecvApplyKernel(const std::shared_ptr &graph_ptr, uint32_t event_id, uint32_t stream_id); private: @@ -73,30 +73,30 @@ class AscendStreamAssign { ~AscendStreamAssign() = default; vector::iterator FindTargetOp(vector::iterator begin, vector::iterator end, - const CNodePtr& node); + const CNodePtr &node); - bool IsHcom(const CNodePtr& apply_kernel); + bool IsHcom(const CNodePtr &apply_kernel); bool IsProcessed(uint32_t logic_id); - void TransLogicToPhysic(const vector& logic_ids, vector* physic_ids); - void AssignCommonStreamId(const CNodePtr& cur_cnode_ptr, CNodePtr* pre_cnode_ptr, uint32_t* cur_index, - uint32_t* cur_stream_id); + void TransLogicToPhysic(const vector &logic_ids, vector *physic_ids); + void AssignCommonStreamId(const CNodePtr &cur_cnode_ptr, CNodePtr *pre_cnode_ptr, uint32_t *cur_index, + uint32_t *cur_stream_id); void RecordIdMap(uint32_t logic_id, uint32_t physic_id); - void UpdateStreamActive(const CNodePtr& active_ptr); - void UpdateStreamSwitch(const CNodePtr& switch_ptr, const CNodePtr& active_ptr); + void UpdateStreamActive(const CNodePtr &active_ptr); + void UpdateStreamSwitch(const CNodePtr &switch_ptr, const CNodePtr &active_ptr); bool IsTaskSink(); - void AssignIndependentStreamId(const CNodePtr& cur_cnode_ptr, uint32_t deal_logic_id); - void UpdateStreamId(const std::shared_ptr& graph_ptr); - void UpdateEventId(const std::shared_ptr& graph_ptr); - void PrintGraphExeOrders(const std::shared_ptr& graph_ptr); - void RecordFirstCommonOp(const CNodePtr& cur_cnode_ptr, uint32_t cur_node_logic_id, uint32_t cur_stream_id); - uint32_t GetLogicId(const CNodePtr& cur_cnode_ptr); + void AssignIndependentStreamId(const CNodePtr &cur_cnode_ptr, uint32_t deal_logic_id); + void UpdateStreamId(const std::shared_ptr &graph_ptr); + void UpdateEventId(const std::shared_ptr &graph_ptr); + void PrintGraphExeOrders(const std::shared_ptr &graph_ptr); + void RecordFirstCommonOp(const CNodePtr &cur_cnode_ptr, uint32_t cur_node_logic_id, uint32_t cur_stream_id); + uint32_t GetLogicId(const CNodePtr &cur_cnode_ptr); void SetCommonStreamNum(uint32_t cur_stream_id); - void FindAllReduceParallel(const std::shared_ptr& graph_ptr); + void FindAllReduceParallel(const std::shared_ptr &graph_ptr); bool IsProcessedParallelStream(uint32_t stream_id); - void GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id, std::vector* parallel_streams); - void InsertSendRecvForIndependent(const std::shared_ptr& graph_ptr); - void InsertSendRecvForHcomParallel(const std::shared_ptr& graph_ptr); - void GetNeedActiveStreams(const std::shared_ptr& graph_ptr); + void GetParallelStream(uint32_t cur_stream_id, uint32_t stream_acitve_id, std::vector *parallel_streams); + void InsertSendRecvForIndependent(const std::shared_ptr &graph_ptr); + void InsertSendRecvForHcomParallel(const std::shared_ptr &graph_ptr); + void GetNeedActiveStreams(const std::shared_ptr &graph_ptr); uint32_t total_common_stream_num_{0}; uint32_t total_independ_stream_num_{0}; diff --git a/mindspore/ccsrc/device/ascend/profiling/plugin_impl.h b/mindspore/ccsrc/device/ascend/profiling/plugin_impl.h index 668b54b78c..bf4977bf9a 100644 --- a/mindspore/ccsrc/device/ascend/profiling/plugin_impl.h +++ b/mindspore/ccsrc/device/ascend/profiling/plugin_impl.h @@ -28,14 +28,14 @@ namespace device { namespace ascend { class PluginImpl : public PluginIntf { public: - explicit PluginImpl(const std::string& module); + explicit PluginImpl(const std::string &module); ~PluginImpl() override = default; - int Init(const Reporter* reporter) override; + int Init(const Reporter *reporter) override; int UnInit() override; - static Reporter* GetPluginReporter() { return reporter_; } + static Reporter *GetPluginReporter() { return reporter_; } private: - static Reporter* reporter_; + static Reporter *reporter_; std::string module_; }; } // namespace ascend diff --git a/mindspore/ccsrc/device/ascend/profiling/profiling_engine_impl.cc b/mindspore/ccsrc/device/ascend/profiling/profiling_engine_impl.cc index 3a1dc4689b..cbecb3030d 100644 --- a/mindspore/ccsrc/device/ascend/profiling/profiling_engine_impl.cc +++ b/mindspore/ccsrc/device/ascend/profiling/profiling_engine_impl.cc @@ -20,12 +20,12 @@ namespace mindspore { namespace device { namespace ascend { -PluginIntf* ProfilingEngineImpl::CreatePlugin() { +PluginIntf *ProfilingEngineImpl::CreatePlugin() { MS_LOG(INFO) << "Create Plugin."; return new (std::nothrow) PluginImpl("Framework"); } -int ProfilingEngineImpl::ReleasePlugin(PluginIntf* plugin) { +int ProfilingEngineImpl::ReleasePlugin(PluginIntf *plugin) { if (plugin != nullptr) { delete plugin; } diff --git a/mindspore/ccsrc/device/ascend/profiling/profiling_engine_impl.h b/mindspore/ccsrc/device/ascend/profiling/profiling_engine_impl.h index e8dbfc7087..c7cbc4b7dd 100644 --- a/mindspore/ccsrc/device/ascend/profiling/profiling_engine_impl.h +++ b/mindspore/ccsrc/device/ascend/profiling/profiling_engine_impl.h @@ -29,8 +29,8 @@ class ProfilingEngineImpl : public EngineIntf { ProfilingEngineImpl() = default; ~ProfilingEngineImpl() override = default; - PluginIntf* CreatePlugin() override; - int ReleasePlugin(PluginIntf* plugin) override; + PluginIntf *CreatePlugin() override; + int ReleasePlugin(PluginIntf *plugin) override; }; } // namespace ascend } // namespace device diff --git a/mindspore/ccsrc/device/ascend/profiling/profiling_manager.cc b/mindspore/ccsrc/device/ascend/profiling/profiling_manager.cc index 29193e5cfa..c3f622ffee 100644 --- a/mindspore/ccsrc/device/ascend/profiling/profiling_manager.cc +++ b/mindspore/ccsrc/device/ascend/profiling/profiling_manager.cc @@ -35,7 +35,7 @@ using Json = nlohmann::json; namespace mindspore { namespace device { namespace ascend { -ProfilingManager& ProfilingManager::GetInstance() { +ProfilingManager &ProfilingManager::GetInstance() { static ProfilingManager inst; return inst; } @@ -45,11 +45,11 @@ ProfilingManager::ProfilingManager() : device_id_(0), prof_handle_(nullptr) { } uint64_t ProfilingManager::GetJobId() const { - const char* job_id = std::getenv("JOB_ID"); + const char *job_id = std::getenv("JOB_ID"); return ((job_id != nullptr) ? std::strtoul(job_id, nullptr, 10) : 0); } -bool ProfilingManager::ReportProfilingData(const map& op_taskId_map) const { +bool ProfilingManager::ReportProfilingData(const map &op_taskId_map) const { if (!IsProfiling()) { MS_LOG(INFO) << "No need profiling. please export PROFILING_MODE and in train mode."; return false; @@ -66,10 +66,10 @@ bool ProfilingManager::ReportProfilingData(const map& op_taskI MS_LOG(INFO) << "DistributeTask: op tasId map size = " << op_taskId_map.size(); Msprof::Engine::ReporterData reporter_data = {}; - for (const auto& iter : op_taskId_map) { + for (const auto &iter : op_taskId_map) { auto data = iter.second + ' ' + std::to_string(iter.first) + ';'; reporter_data.deviceId = UintToInt(device_id_); - reporter_data.data = (unsigned char*)(const_cast(data.c_str())); + reporter_data.data = (unsigned char *)(const_cast(data.c_str())); reporter_data.dataLen = data.size(); auto ret = memcpy_s(reporter_data.tag, MSPROF_ENGINE_MAX_TAG_LEN + 1, "framework", sizeof("framework")); if (ret != 0) { @@ -85,7 +85,7 @@ bool ProfilingManager::ReportProfilingData(const map& op_taskI return true; } -static std::vector Split(const std::string& str, const char delim) { +static std::vector Split(const std::string &str, const char delim) { std::vector elems; if (str.empty()) { @@ -116,7 +116,7 @@ bool ProfilingManager::StartupProfiling(uint32_t device_id) { device_id_ = device_id; // exp: export PROFILING_MODE=true // export PROFILING_OPTIONS=training_trace - const char* prof_options_str = std::getenv("PROFILING_OPTIONS"); + const char *prof_options_str = std::getenv("PROFILING_OPTIONS"); // register Framework to profiling int result = Msprof::Engine::RegisterEngine("Framework", engine_0_.get()); if (result != 0) { @@ -176,7 +176,7 @@ bool ProfilingManager::StopProfiling() const { MS_LOG(INFO) << "No need profiling. please export PROFILING_MODE and in train mode."; return true; } - Msprof::Engine::Reporter* reporter = PluginImpl::GetPluginReporter(); + Msprof::Engine::Reporter *reporter = PluginImpl::GetPluginReporter(); if (reporter != nullptr) { MS_LOG(INFO) << "report data end, ret = " << reporter->Flush(); } diff --git a/mindspore/ccsrc/device/gpu/blocking_queue.h b/mindspore/ccsrc/device/gpu/blocking_queue.h index ccf481858f..a1594c21a9 100644 --- a/mindspore/ccsrc/device/gpu/blocking_queue.h +++ b/mindspore/ccsrc/device/gpu/blocking_queue.h @@ -33,27 +33,27 @@ enum BlockQueueStatus_T : int { SUCCESS = 0, QUEUE_NOT_EXIST, HANDLE_NOT_EXIST, class GpuQueue { public: - GpuQueue(void* addr, size_t feature_size, size_t label_size, size_t capacity); + GpuQueue(void *addr, size_t feature_size, size_t label_size, size_t capacity); virtual ~GpuQueue(); - void RegisterRelease(const std::function& func) { host_release_ = func; } + void RegisterRelease(const std::function &func) { host_release_ = func; } inline bool IsEmpty() const { return head_ == tail_; } inline bool IsFull() const { return head_ == ((tail_ + 1) % (capacity_)); } - BlockQueueStatus_T Push(void* feature_addr, size_t feature_size, void* label_addr, size_t label_size); - BlockQueueStatus_T Front(void** feature_addr, size_t* feature_size, void** label_addr, size_t* label_size) const; + BlockQueueStatus_T Push(void *feature_addr, size_t feature_size, void *label_addr, size_t label_size); + BlockQueueStatus_T Front(void **feature_addr, size_t *feature_size, void **label_addr, size_t *label_size) const; BlockQueueStatus_T Pop(); bool Destroy(); private: struct NodeInfo { std::unique_ptr event_; - void* host_feature_addr_; - void* host_label_addr_; + void *host_feature_addr_; + void *host_label_addr_; }; - void* buffer_; + void *buffer_; size_t head_; size_t tail_; size_t feature_size_; @@ -61,10 +61,10 @@ class GpuQueue { size_t capacity_; cudaStream_t stream_; std::unique_ptr node_info_; - std::function host_release_; + std::function host_release_; - GpuQueue(const GpuQueue&) = delete; - GpuQueue& operator=(const GpuQueue&) = delete; + GpuQueue(const GpuQueue &) = delete; + GpuQueue &operator=(const GpuQueue &) = delete; }; class BlockingQueue { @@ -72,11 +72,11 @@ class BlockingQueue { BlockingQueue() : queue_(nullptr) {} ~BlockingQueue() = default; - BlockQueueStatus_T Create(void* addr, size_t feature_size, size_t label_size, size_t capacity); - void RegisterRelease(const std::function& func); - BlockQueueStatus_T Push(void* feature_addr, size_t feature_size, void* label_addr, size_t label_size, + BlockQueueStatus_T Create(void *addr, size_t feature_size, size_t label_size, size_t capacity); + void RegisterRelease(const std::function &func); + BlockQueueStatus_T Push(void *feature_addr, size_t feature_size, void *label_addr, size_t label_size, unsigned int timeout_in_sec); - BlockQueueStatus_T Front(void** feature_addr, size_t* feature_size, void** label_addr, size_t* label_size); + BlockQueueStatus_T Front(void **feature_addr, size_t *feature_size, void **label_addr, size_t *label_size); BlockQueueStatus_T Pop(); bool Destroy(); diff --git a/mindspore/ccsrc/device/gpu/distribution/collective_init.cc b/mindspore/ccsrc/device/gpu/distribution/collective_init.cc index d212c56ae7..d7ab95bbe8 100644 --- a/mindspore/ccsrc/device/gpu/distribution/collective_init.cc +++ b/mindspore/ccsrc/device/gpu/distribution/collective_init.cc @@ -20,17 +20,17 @@ namespace mindspore { namespace device { namespace gpu { -CollectiveInitializer& CollectiveInitializer::instance() { +CollectiveInitializer &CollectiveInitializer::instance() { static CollectiveInitializer instance = {}; return instance; } bool CollectiveInitializer::collective_inited() const { return collective_inited_; } -const void* CollectiveInitializer::collective_handle() const { return collective_handle_; } +const void *CollectiveInitializer::collective_handle() const { return collective_handle_; } void CollectiveInitializer::InitCollective() { - void* handle = dlopen("libgpu_collective.so", RTLD_LAZY); + void *handle = dlopen("libgpu_collective.so", RTLD_LAZY); if (handle == nullptr) { MS_LOG(EXCEPTION) << "Loading libgpu_collective.so failed. Many reasons could cause this:\n1.libgpu_collective.so is not " diff --git a/mindspore/ccsrc/device/gpu/gpu_device_manager.cc b/mindspore/ccsrc/device/gpu/gpu_device_manager.cc index b25ba2906b..e505fdc218 100644 --- a/mindspore/ccsrc/device/gpu/gpu_device_manager.cc +++ b/mindspore/ccsrc/device/gpu/gpu_device_manager.cc @@ -50,13 +50,13 @@ void GPUDeviceManager::ReleaseDevice() { CHECK_OP_RET_WITH_ERROR(GPUMemoryAllocator::GetInstance().Finalize(), "Failed to destroy gpu memory allocator"); } -bool GPUDeviceManager::CreateStream(DeviceStream* stream) { +bool GPUDeviceManager::CreateStream(DeviceStream *stream) { CHECK_OP_RET_WITH_EXCEPT(CudaDriver::CreateStream(stream), "Failed to create CUDA stream"); gpu_streams_.emplace_back(*stream); return true; } -const DeviceStream& GPUDeviceManager::default_stream() const { return default_stream_; } +const DeviceStream &GPUDeviceManager::default_stream() const { return default_stream_; } int GPUDeviceManager::device_count() const { return CudaDriver::device_count(); } @@ -76,17 +76,17 @@ uint32_t GPUDeviceManager::cur_device_id() const { return cur_dev_id_; } bool GPUDeviceManager::is_device_id_init() const { return dev_id_init_; } -const cudnnHandle_t& GPUDeviceManager::GetCudnnHandle() const { return cudnn_handle_; } +const cudnnHandle_t &GPUDeviceManager::GetCudnnHandle() const { return cudnn_handle_; } -const cublasHandle_t& GPUDeviceManager::GetCublasHandle() const { return cublas_handle_; } +const cublasHandle_t &GPUDeviceManager::GetCublasHandle() const { return cublas_handle_; } -bool GPUDeviceManager::SyncStream(const DeviceStream& stream) const { return CudaDriver::SyncStream(stream); } +bool GPUDeviceManager::SyncStream(const DeviceStream &stream) const { return CudaDriver::SyncStream(stream); } -bool GPUDeviceManager::CopyDeviceMemToHost(const HostMemPtr& dst, const DeviceMemPtr& src, size_t size) const { +bool GPUDeviceManager::CopyDeviceMemToHost(const HostMemPtr &dst, const DeviceMemPtr &src, size_t size) const { return CudaDriver::CopyDeviceMemToHost(dst, src, size); } -bool GPUDeviceManager::CopyHostMemToDevice(const DeviceMemPtr& dst, const void* src, size_t size) const { +bool GPUDeviceManager::CopyHostMemToDevice(const DeviceMemPtr &dst, const void *src, size_t size) const { return CudaDriver::CopyHostMemToDevice(dst, src, size); } } // namespace gpu diff --git a/mindspore/ccsrc/device/gpu/gpu_device_manager.h b/mindspore/ccsrc/device/gpu/gpu_device_manager.h index 3b3d2aecb5..a546b999a4 100644 --- a/mindspore/ccsrc/device/gpu/gpu_device_manager.h +++ b/mindspore/ccsrc/device/gpu/gpu_device_manager.h @@ -37,17 +37,17 @@ class GPUDeviceManager { uint32_t cur_device_id() const; bool is_device_id_init() const; - bool CreateStream(DeviceStream* stream); - bool SyncStream(const DeviceStream& stream) const; - const DeviceStream& default_stream() const; + bool CreateStream(DeviceStream *stream); + bool SyncStream(const DeviceStream &stream) const; + const DeviceStream &default_stream() const; - const cudnnHandle_t& GetCudnnHandle() const; - const cublasHandle_t& GetCublasHandle() const; + const cudnnHandle_t &GetCudnnHandle() const; + const cublasHandle_t &GetCublasHandle() const; - bool CopyDeviceMemToHost(const HostMemPtr& dst, const DeviceMemPtr& src, size_t size) const; - bool CopyHostMemToDevice(const DeviceMemPtr& dst, const void* src, size_t size) const; + bool CopyDeviceMemToHost(const HostMemPtr &dst, const DeviceMemPtr &src, size_t size) const; + bool CopyHostMemToDevice(const DeviceMemPtr &dst, const void *src, size_t size) const; - static GPUDeviceManager& GetInstance() { + static GPUDeviceManager &GetInstance() { static GPUDeviceManager instance; return instance; } @@ -55,8 +55,8 @@ class GPUDeviceManager { private: GPUDeviceManager() : dev_id_init_(false), cur_dev_id_(0) {} ~GPUDeviceManager() = default; - GPUDeviceManager(const GPUDeviceManager&) = delete; - GPUDeviceManager& operator=(const GPUDeviceManager&) = delete; + GPUDeviceManager(const GPUDeviceManager &) = delete; + GPUDeviceManager &operator=(const GPUDeviceManager &) = delete; // default CUDA stream used for all the kernels. DeviceStream default_stream_{nullptr}; diff --git a/mindspore/ccsrc/device/gpu/gpu_memory_allocator.cc b/mindspore/ccsrc/device/gpu/gpu_memory_allocator.cc index cbd43645ab..3a1a53c600 100644 --- a/mindspore/ccsrc/device/gpu/gpu_memory_allocator.cc +++ b/mindspore/ccsrc/device/gpu/gpu_memory_allocator.cc @@ -43,14 +43,14 @@ bool GPUMemoryAllocator::Finalize() { return true; } -bool GPUMemoryAllocator::AllocBufferQueueMem(size_t size, DeviceMemPtr* addr) { +bool GPUMemoryAllocator::AllocBufferQueueMem(size_t size, DeviceMemPtr *addr) { auto alloc_size = AllocDeviceMem(size, addr); buffer_q_addr_ = *addr; // Buffer queue needs to ensure that the alloc_size and size is equal. return (alloc_size == size) ? true : false; } -size_t GPUMemoryAllocator::AllocDeviceMem(size_t size, DeviceMemPtr* addr) { +size_t GPUMemoryAllocator::AllocDeviceMem(size_t size, DeviceMemPtr *addr) { if (size == 0) { MS_LOG(EXCEPTION) << "The memory alloc size is 0."; } @@ -68,7 +68,7 @@ size_t GPUMemoryAllocator::AllocDeviceMem(size_t size, DeviceMemPtr* addr) { return alloc_size; } -bool GPUMemoryAllocator::FreeDeviceMem(const DeviceMemPtr& addr) { return CudaDriver::FreeDeviceMem(addr); } +bool GPUMemoryAllocator::FreeDeviceMem(const DeviceMemPtr &addr) { return CudaDriver::FreeDeviceMem(addr); } size_t GPUMemoryAllocator::free_mem_size() { return CudaDriver::free_mem_size(); } diff --git a/mindspore/ccsrc/device/gpu/gpu_memory_allocator.h b/mindspore/ccsrc/device/gpu/gpu_memory_allocator.h index 0d2f0f8a39..36374bfaad 100644 --- a/mindspore/ccsrc/device/gpu/gpu_memory_allocator.h +++ b/mindspore/ccsrc/device/gpu/gpu_memory_allocator.h @@ -29,22 +29,22 @@ class GPUMemoryAllocator : public DynamicMemPoolBestFit { ~GPUMemoryAllocator() override = default; bool Init(); bool Finalize(); - bool AllocBufferQueueMem(size_t size, DeviceMemPtr* addr); + bool AllocBufferQueueMem(size_t size, DeviceMemPtr *addr); - size_t AllocDeviceMem(size_t size, DeviceMemPtr* addr) override; - bool FreeDeviceMem(const DeviceMemPtr& addr) override; + size_t AllocDeviceMem(size_t size, DeviceMemPtr *addr) override; + bool FreeDeviceMem(const DeviceMemPtr &addr) override; size_t free_mem_size() override; size_t total_mem_size() override; - static GPUMemoryAllocator& GetInstance() { + static GPUMemoryAllocator &GetInstance() { static GPUMemoryAllocator instance; return instance; } private: GPUMemoryAllocator() = default; - GPUMemoryAllocator(const GPUMemoryAllocator&) = delete; - GPUMemoryAllocator& operator=(const GPUMemoryAllocator&) = delete; + GPUMemoryAllocator(const GPUMemoryAllocator &) = delete; + GPUMemoryAllocator &operator=(const GPUMemoryAllocator &) = delete; // Used to track address of data buffer queue. DeviceMemPtr buffer_q_addr_{nullptr}; diff --git a/mindspore/ccsrc/device/gpu/kernel_info_setter.cc b/mindspore/ccsrc/device/gpu/kernel_info_setter.cc index 05ecf380d1..6ccb4c8cde 100644 --- a/mindspore/ccsrc/device/gpu/kernel_info_setter.cc +++ b/mindspore/ccsrc/device/gpu/kernel_info_setter.cc @@ -33,8 +33,8 @@ namespace gpu { using AnfAlgo = mindspore::session::AnfRuntimeAlgorithm; using mindspore::kernel::KernelBuildInfo; namespace { -bool CheckKernelInfo(const std::shared_ptr& alternative_kernel_info, - const std::shared_ptr& selected_kernel_info) { +bool CheckKernelInfo(const std::shared_ptr &alternative_kernel_info, + const std::shared_ptr &selected_kernel_info) { MS_EXCEPTION_IF_NULL(selected_kernel_info); MS_EXCEPTION_IF_NULL(alternative_kernel_info); size_t selected_input_num = selected_kernel_info->GetInputNum(); @@ -67,7 +67,7 @@ bool CheckKernelInfo(const std::shared_ptr& alternative_kernel_ return true; } -std::string SupportedTypeList(const CNodePtr& kernel_node) { +std::string SupportedTypeList(const CNodePtr &kernel_node) { std::string supported_type_lists = kernel::GpuKernelFactory::GetInstance().SupportedTypeList(AnfAlgo::GetCNodeName(kernel_node)); if (!supported_type_lists.empty()) { @@ -91,7 +91,7 @@ std::string SupportedTypeList(const CNodePtr& kernel_node) { return supported_type_lists; } -bool SelectAkgKernel(const CNodePtr& kernel_node, const std::shared_ptr& selected_kernel_info) { +bool SelectAkgKernel(const CNodePtr &kernel_node, const std::shared_ptr &selected_kernel_info) { MS_EXCEPTION_IF_NULL(kernel_node); MS_EXCEPTION_IF_NULL(selected_kernel_info); std::vector> kernel_info_list; @@ -110,7 +110,7 @@ bool SelectAkgKernel(const CNodePtr& kernel_node, const std::shared_ptr& alternative_kernel_info) { + [&](const std::shared_ptr &alternative_kernel_info) { return CheckKernelInfo(alternative_kernel_info, selected_kernel_info); }); if (!match) { @@ -120,7 +120,7 @@ bool SelectAkgKernel(const CNodePtr& kernel_node, const std::shared_ptrinput(input_index + 1); @@ -153,7 +153,7 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo& selected_kernel_info, co } } // namespace -void SetKernelInfo(const CNodePtr& kernel_node) { +void SetKernelInfo(const CNodePtr &kernel_node) { std::vector inputs_format; std::vector inputs_type; std::shared_ptr builder = diff --git a/mindspore/ccsrc/device/gpu/kernel_info_setter.h b/mindspore/ccsrc/device/gpu/kernel_info_setter.h index e3dc2241a9..b351f74fa3 100644 --- a/mindspore/ccsrc/device/gpu/kernel_info_setter.h +++ b/mindspore/ccsrc/device/gpu/kernel_info_setter.h @@ -27,7 +27,7 @@ namespace mindspore { namespace device { namespace gpu { -void SetKernelInfo(const CNodePtr& apply_kernel_ptr); +void SetKernelInfo(const CNodePtr &apply_kernel_ptr); class KernelAttr { public: @@ -35,24 +35,24 @@ class KernelAttr { KernelAttr() : all_same_(false) {} ~KernelAttr() = default; - KernelAttr& AddInputAttr(const TypeId& ms_type, const std::string& format = kOpFormat_DEFAULT) { + KernelAttr &AddInputAttr(const TypeId &ms_type, const std::string &format = kOpFormat_DEFAULT) { input_type_.emplace_back(ms_type, format); return *this; } - KernelAttr& AddOutputAttr(const TypeId& ms_type, const std::string& format = kOpFormat_DEFAULT) { + KernelAttr &AddOutputAttr(const TypeId &ms_type, const std::string &format = kOpFormat_DEFAULT) { output_type_.emplace_back(ms_type, format); return *this; } - KernelAttr& AddAllSameAttr(const bool& all_same) { + KernelAttr &AddAllSameAttr(const bool &all_same) { all_same_ = all_same; return *this; } - const DataType& GetInputAttr(const size_t index) const { return input_type_[index]; } - const DataType& GetOutputAttr(const size_t index) const { return output_type_[index]; } - const bool& GetAllSame() const { return all_same_; } + const DataType &GetInputAttr(const size_t index) const { return input_type_[index]; } + const DataType &GetOutputAttr(const size_t index) const { return output_type_[index]; } + const bool &GetAllSame() const { return all_same_; } size_t GetInputSize() const { return input_type_.size(); } size_t GetOutputSize() const { return output_type_.size(); } diff --git a/mindspore/ccsrc/gvar/typeid_manager.cc b/mindspore/ccsrc/gvar/typeid_manager.cc index 97250a6571..f40052411a 100644 --- a/mindspore/ccsrc/gvar/typeid_manager.cc +++ b/mindspore/ccsrc/gvar/typeid_manager.cc @@ -24,7 +24,7 @@ namespace mindspore { -struct TypeIdManager* TypeIdManager::Get() { +struct TypeIdManager *TypeIdManager::Get() { static TypeIdManager manager; return &manager; } diff --git a/mindspore/ccsrc/ir/anf.cc b/mindspore/ccsrc/ir/anf.cc index 658fb578b7..dd86e46713 100644 --- a/mindspore/ccsrc/ir/anf.cc +++ b/mindspore/ccsrc/ir/anf.cc @@ -35,14 +35,14 @@ TypePtr AnfNode::Type() const { return (abstract_ == nullptr) ? nullptr : abstra BaseShapePtr AnfNode::Shape() const { return (abstract_ == nullptr) ? nullptr : abstract_->BuildShape(); } std::string AnfNode::ToString() const { - return mindspore::label_manage::Label(const_cast(this)->shared_from_base()->debug_info()); + return mindspore::label_manage::Label(const_cast(this)->shared_from_base()->debug_info()); } -CNode::CNode(const std::vector& inputs, const FuncGraphPtr& func_graph) +CNode::CNode(const std::vector &inputs, const FuncGraphPtr &func_graph) : AnfNode(func_graph), inputs_(inputs), stop_gradient_(false) {} // Check if CNode is an apply with the specific Primitive. -bool CNode::IsApply(const PrimitivePtr& value) const { +bool CNode::IsApply(const PrimitivePtr &value) const { if (value == nullptr) { return false; } @@ -57,7 +57,7 @@ bool CNode::IsApply(const PrimitivePtr& value) const { return false; } -void CNode::set_input(size_t i, const AnfNodePtr& new_input) { inputs_[i] = new_input; } +void CNode::set_input(size_t i, const AnfNodePtr &new_input) { inputs_[i] = new_input; } std::string CNode::DebugString(int recursive_level) const { std::ostringstream buffer; @@ -68,7 +68,7 @@ std::string CNode::DebugString(int recursive_level) const { buffer << ToString() << "{"; bool is_first_node = true; int idx = 0; - for (auto& node : inputs_) { + for (auto &node : inputs_) { MS_EXCEPTION_IF_NULL(node); if (is_first_node) { is_first_node = false; @@ -85,7 +85,7 @@ std::string CNode::DebugString(int recursive_level) const { return buffer.str(); } -OperatorInfoPtr CNode::set_operator_info(const OperatorInfoPtr& operator_info) { +OperatorInfoPtr CNode::set_operator_info(const OperatorInfoPtr &operator_info) { if (operator_info_ != nullptr) { MS_LOG(WARNING) << "The CNode: " << ToString() << " has already been set OperatorInfo: " << operator_info_->name() << ", using the new one: " << operator_info->name(); @@ -173,11 +173,11 @@ std::string ValueNode::fullname_with_scope() { return fullname_with_scope_; } -void CNode::accept(AnfVisitor* v) { v->Visit(shared_from_base()); } -void ValueNode::accept(AnfVisitor* v) { v->Visit(shared_from_base()); } -void Parameter::accept(AnfVisitor* v) { v->Visit(shared_from_base()); } +void CNode::accept(AnfVisitor *v) { v->Visit(shared_from_base()); } +void ValueNode::accept(AnfVisitor *v) { v->Visit(shared_from_base()); } +void Parameter::accept(AnfVisitor *v) { v->Visit(shared_from_base()); } -bool IsPrimitiveCNode(const AnfNodePtr& node, const PrimitivePtr& value) { +bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &value) { MS_EXCEPTION_IF_NULL(node); auto cnode = node->cast(); if (cnode != nullptr) { @@ -186,7 +186,7 @@ bool IsPrimitiveCNode(const AnfNodePtr& node, const PrimitivePtr& value) { return false; } -PrimitivePtr GetCNodePrimitive(const AnfNodePtr& node) { +PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node) { if (node == nullptr) { return nullptr; } @@ -217,7 +217,7 @@ std::string GetCNodeFuncName(const CNodePtr cnode) { return ""; } -bool IsPrimitive(const AnfNodePtr& node, const PrimitivePtr& value) { +bool IsPrimitive(const AnfNodePtr &node, const PrimitivePtr &value) { if (IsValueNode(node)) { PrimitivePtr fn_value = GetValueNode(node); MS_EXCEPTION_IF_NULL(value); @@ -229,7 +229,7 @@ bool IsPrimitive(const AnfNodePtr& node, const PrimitivePtr& value) { } namespace id_generator { static std::unordered_map node_ids; -std::string get_id(const AnfNodePtr& node) { +std::string get_id(const AnfNodePtr &node) { auto type_name = node->type_name(); if (node_ids.find(type_name) == node_ids.end()) { node_ids[type_name] = 0; diff --git a/mindspore/ccsrc/ir/base.h b/mindspore/ccsrc/ir/base.h index 6a3537306f..7ccef13876 100644 --- a/mindspore/ccsrc/ir/base.h +++ b/mindspore/ccsrc/ir/base.h @@ -39,15 +39,15 @@ struct is_shared_ptr> : public std::true_type {}; class Base : public std::enable_shared_from_this { public: constexpr Base() = default; - Base(const Base& other) : std::enable_shared_from_this(other) {} - virtual bool operator==(const Base& rhs) { + Base(const Base &other) : std::enable_shared_from_this(other) {} + virtual bool operator==(const Base &rhs) { if (this == &rhs) { return true; } return false; } - virtual Base& operator=(const Base&) { return *this; } + virtual Base &operator=(const Base &) { return *this; } virtual ~Base() = default; virtual std::size_t hash() const { return tid(); } virtual std::string ToString() const { return type_name(); } @@ -57,14 +57,14 @@ class Base : public std::enable_shared_from_this { virtual const bool IsFromTypeId(uint32_t tid) const; virtual std::string type_name() const { return "Base"; } - static uint32_t GetTypeId(const char* const type_key); + static uint32_t GetTypeId(const char *const type_key); virtual uint32_t tid() const { static const uint32_t tid = GetTypeId(typeid(Base).name()); return tid; } template ::value && std::is_base_of::value, T>::type* = nullptr> + typename std::enable_if::value && std::is_base_of::value, T>::type * = nullptr> inline bool isa() const { static const uint32_t tid = GetTypeId(typeid(T).name()); return this->IsFromTypeId(tid); @@ -90,9 +90,9 @@ using BasePtr = std::shared_ptr; using BaseWeakPtr = std::weak_ptr; template -inline T* cast(U* source) { +inline T *cast(U *source) { if (source != nullptr && source->template isa()) { - return static_cast(source); + return static_cast(source); } else { return nullptr; } @@ -100,7 +100,7 @@ inline T* cast(U* source) { template < typename T, typename U, - typename std::enable_if::value && std::is_base_of::value, T>::type* = nullptr> + typename std::enable_if::value && std::is_base_of::value, T>::type * = nullptr> inline std::shared_ptr dyn_cast(const std::shared_ptr r) { if (r != nullptr && r->template isa()) { return std::static_pointer_cast(r); @@ -143,7 +143,7 @@ struct MS_EXPORT TypeIdManager { std::mutex mutex; std::atomic type_counter{0}; std::unordered_map map; - static TypeIdManager* Get(); + static TypeIdManager *Get(); TypeIdManager() : mutex(), type_counter(0), map() {} }; } // namespace mindspore diff --git a/mindspore/ccsrc/ir/dtype.cc b/mindspore/ccsrc/ir/dtype.cc index 65a42bc3fa..a6ef99177c 100644 --- a/mindspore/ccsrc/ir/dtype.cc +++ b/mindspore/ccsrc/ir/dtype.cc @@ -48,11 +48,11 @@ std::string Keyword::ToString() const { return buffer.str(); } -bool Keyword::operator==(const Type& other) const { +bool Keyword::operator==(const Type &other) const { if (!IsSameObjectType(*this, other)) { return false; } - const auto& other_keyword = static_cast(other); + const auto &other_keyword = static_cast(other); return (other_keyword.key_ == key_ && *other_keyword.value_ == *value_); } @@ -87,11 +87,11 @@ std::string Slice::ToString() const { return buffer.str(); } -bool Slice::operator==(const Type& other) const { +bool Slice::operator==(const Type &other) const { if (!IsSameObjectType(*this, other)) { return false; } - auto other_slice = static_cast(other); + auto other_slice = static_cast(other); return (*start_ == *other_slice.start_ && *stop_ == *other_slice.stop_ && *step_ == *other_slice.step_); } @@ -122,11 +122,11 @@ std::string TensorType::DumpText() const { } } -bool TensorType::operator==(const Type& other) const { +bool TensorType::operator==(const Type &other) const { if (!IsSameObjectType(*this, other)) { return false; } - auto other_elem_type = static_cast(other).element_type_; + auto other_elem_type = static_cast(other).element_type_; // When element_type_ = nullptr, which means any type of Array. if (element_type_ == nullptr && other_elem_type == nullptr) { return true; @@ -141,7 +141,7 @@ Function::Function() : Object(kObjectTypeFunction) { retval_ = nullptr; } -Function::Function(const std::vector& args, const TypePtr retval) +Function::Function(const std::vector &args, const TypePtr retval) : Object(kObjectTypeFunction, false), args_(args), retval_(retval) {} TypePtr Function::DeepCopy() const { @@ -151,7 +151,7 @@ TypePtr Function::DeepCopy() const { TypePtrList args; TypePtr retval = nullptr; (void)std::transform(args_.begin(), args_.end(), std::back_inserter(args), - [](const TypePtr& arg) { return arg->DeepCopy(); }); + [](const TypePtr &arg) { return arg->DeepCopy(); }); if (retval_ != nullptr) { retval = retval_->DeepCopy(); } @@ -159,12 +159,12 @@ TypePtr Function::DeepCopy() const { } } -bool Function::operator==(const Type& other) const { +bool Function::operator==(const Type &other) const { if (!IsSameObjectType(*this, other)) { return false; } - const auto& other_function = static_cast(other); + const auto &other_function = static_cast(other); if ((retval_ != nullptr) && (other_function.retval_ != nullptr)) { if (*retval_ != *other_function.retval_) { return false; @@ -188,7 +188,7 @@ std::string Function::ToString() const { } else { buffer << "Func[("; bool begin = true; - for (auto& attr : args_) { + for (auto &attr : args_) { if (!begin) { buffer << ", "; } else { @@ -242,34 +242,34 @@ std::string JTagged::DumpText() const { return buffer.str(); } -std::ostream& operator<<(std::ostream& os, const std::shared_ptr problem) { +std::ostream &operator<<(std::ostream &os, const std::shared_ptr problem) { MS_EXCEPTION_IF_NULL(problem); os << problem->ToString(); return os; } -std::size_t TypeHasher::operator()(TypePtr const& type) const { +std::size_t TypeHasher::operator()(TypePtr const &type) const { MS_EXCEPTION_IF_NULL(type); std::size_t hash = std::hash()(type->type_id()); return hash; } -std::size_t TypeListHasher::operator()(const TypePtrList& type_list) const { +std::size_t TypeListHasher::operator()(const TypePtrList &type_list) const { std::size_t hash_sum = 0; - for (auto& type : type_list) { + for (auto &type : type_list) { auto type_id = static_cast(type->type_id()); hash_sum = hash_combine(hash_sum, type_id); } return hash_sum; } -bool TypeEqual::operator()(TypePtr const& t1, TypePtr const& t2) const { +bool TypeEqual::operator()(TypePtr const &t1, TypePtr const &t2) const { MS_EXCEPTION_IF_NULL(t1); MS_EXCEPTION_IF_NULL(t2); return t1->type_id() == t2->type_id(); } -bool TypeListEqual::operator()(TypePtrList const& lhs, TypePtrList const& rhs) const { +bool TypeListEqual::operator()(TypePtrList const &lhs, TypePtrList const &rhs) const { if (lhs.size() != rhs.size()) { return false; } @@ -332,7 +332,7 @@ TypePtr TypeIdToType(TypeId id) { namespace { template -TypePtr StringToNumberType(const std::string& type_name, const std::string& num_type_name) { +TypePtr StringToNumberType(const std::string &type_name, const std::string &num_type_name) { TypePtr type = nullptr; if (type_name == num_type_name) { type = std::make_shared(); @@ -344,14 +344,14 @@ TypePtr StringToNumberType(const std::string& type_name, const std::string& num_ } auto bits = std::stoi(type_name.substr(num_type_name.size())); type = std::make_shared(bits); - } catch (const std::exception& e) { + } catch (const std::exception &e) { MS_LOG(EXCEPTION) << "" << num_type_name << " convert from string error " << e.what(); } } return type; } -std::vector StringToVectorOfType(const std::string& type_names) { +std::vector StringToVectorOfType(const std::string &type_names) { std::vector types; if (type_names.length() == 0) { return types; @@ -371,7 +371,7 @@ std::vector StringToVectorOfType(const std::string& type_names) { return types; } -TypePtr TensorStrToType(const std::string& type_name) { +TypePtr TensorStrToType(const std::string &type_name) { TypePtr type = nullptr; if (type_name == "Tensor") { type = std::make_shared(); @@ -388,7 +388,7 @@ TypePtr TensorStrToType(const std::string& type_name) { return nullptr; } type = std::make_shared(element_type); - } catch (const std::exception& e) { + } catch (const std::exception &e) { MS_LOG(EXCEPTION) << "" << type_name << " convert from string error " << e.what(); } } @@ -396,7 +396,7 @@ TypePtr TensorStrToType(const std::string& type_name) { return type; } -TypePtr ListStrToType(const std::string& type_name) { +TypePtr ListStrToType(const std::string &type_name) { TypePtr type = nullptr; if (type_name == "List") { type = std::make_shared(); @@ -410,12 +410,12 @@ TypePtr ListStrToType(const std::string& type_name) { std::string element_strs = type_name.substr(start, end - start); std::vector element_types = StringToVectorOfType(element_strs); bool wrong = - std::any_of(element_types.begin(), element_types.end(), [](const TypePtr& x) { return x == nullptr; }); + std::any_of(element_types.begin(), element_types.end(), [](const TypePtr &x) { return x == nullptr; }); if (wrong) { return nullptr; } type = std::make_shared(element_types); - } catch (const std::exception& e) { + } catch (const std::exception &e) { MS_LOG(EXCEPTION) << "" << type_name << " convert from string error " << e.what(); } } @@ -423,7 +423,7 @@ TypePtr ListStrToType(const std::string& type_name) { return type; } -TypePtr TupleStrToType(const std::string& type_name) { +TypePtr TupleStrToType(const std::string &type_name) { TypePtr type = nullptr; if (type_name == "Tuple") { type = std::make_shared(); @@ -437,19 +437,19 @@ TypePtr TupleStrToType(const std::string& type_name) { std::string element_strs = type_name.substr(start, end - start); std::vector element_types = StringToVectorOfType(element_strs); bool wrong = - std::any_of(element_types.begin(), element_types.end(), [](const TypePtr& x) { return x == nullptr; }); + std::any_of(element_types.begin(), element_types.end(), [](const TypePtr &x) { return x == nullptr; }); if (wrong) { return nullptr; } type = std::make_shared(element_types); - } catch (const std::exception& e) { + } catch (const std::exception &e) { MS_LOG(EXCEPTION) << "" << type_name << " convert from string error " << e.what(); } } return type; } -TypePtr FunctionStrToType(const std::string& type_name) { +TypePtr FunctionStrToType(const std::string &type_name) { TypePtr type = nullptr; if (type_name == "Function") { @@ -478,12 +478,12 @@ TypePtr FunctionStrToType(const std::string& type_name) { std::vector args_type = StringToVectorOfType(str_args); TypePtr retval = StringToType(str_retval); - bool wrong = std::any_of(args_type.begin(), args_type.end(), [](const TypePtr& x) { return x == nullptr; }); + bool wrong = std::any_of(args_type.begin(), args_type.end(), [](const TypePtr &x) { return x == nullptr; }); if (retval == nullptr || wrong) { return nullptr; } type = std::make_shared(args_type, retval); - } catch (const std::exception& e) { + } catch (const std::exception &e) { MS_LOG(EXCEPTION) << "" << type_name << " convert from string error " << e.what(); } } @@ -491,7 +491,7 @@ TypePtr FunctionStrToType(const std::string& type_name) { } } // namespace -TypePtr StringToType(const std::string& type_name) { +TypePtr StringToType(const std::string &type_name) { TypePtr type = nullptr; if (type_name.compare("None") == 0) { type = std::make_shared(); @@ -542,7 +542,7 @@ TypePtr StringToType(const std::string& type_name) { return type; } -bool IsIdentidityOrSubclass(TypePtr const& x, TypePtr const& base_type) { +bool IsIdentidityOrSubclass(TypePtr const &x, TypePtr const &base_type) { if (x == nullptr || base_type == nullptr) { MS_LOG(ERROR) << "Type is nullptr."; return false; @@ -564,7 +564,7 @@ bool IsIdentidityOrSubclass(TypePtr const& x, TypePtr const& base_type) { } } -bool IsSubType(TypePtr const& t1, TypePtr const& t2) { +bool IsSubType(TypePtr const &t1, TypePtr const &t2) { MS_EXCEPTION_IF_NULL(t1); if (t1->type_id() == kTypeUnknown) { return false; @@ -576,17 +576,17 @@ bool IsSubType(TypePtr const& t1, TypePtr const& t2) { } REGISTER_PYBIND_DEFINE( - typing, ([](py::module* const m) { + typing, ([](py::module *const m) { auto m_sub = m->def_submodule("typing", "submodule for dtype"); py::enum_(m_sub, "TypeId"); (void)m_sub.def("is_subclass", &IsIdentidityOrSubclass, "is equal or subclass"); (void)m_sub.def("load_type", &TypeIdToType, "load type"); (void)m_sub.def( - "dump_type", [](const TypePtr& t) { return t->type_id(); }, "dump type"); + "dump_type", [](const TypePtr &t) { return t->type_id(); }, "dump type"); (void)py::class_>(m_sub, "Type") .def_readonly(PYTHON_DTYPE_FLAG, &mindspore::Type::parse_info_) .def("__eq__", - [](const TypePtr& t1, const TypePtr& t2) { + [](const TypePtr &t1, const TypePtr &t2) { if (t1 != nullptr && t2 != nullptr) { return *t1 == *t2; } @@ -595,7 +595,7 @@ REGISTER_PYBIND_DEFINE( .def("__hash__", &Type::hash) .def("__str__", &Type::ToString) .def("__repr__", &Type::ReprString) - .def("__deepcopy__", [](const TypePtr& t, py::dict) { + .def("__deepcopy__", [](const TypePtr &t, py::dict) { if (t == nullptr) { return static_cast(nullptr); } @@ -605,21 +605,21 @@ REGISTER_PYBIND_DEFINE( (void)py::class_>(m_sub, "Bool") .def(py::init()) .def(py::pickle( - [](const Bool&) { // __getstate__ + [](const Bool &) { // __getstate__ return py::make_tuple(); }, - [](const py::tuple&) { // __setstate__ + [](const py::tuple &) { // __setstate__ return std::make_shared(); })); (void)py::class_>(m_sub, "Int") .def(py::init()) .def(py::init(), py::arg("nbits")) .def(py::pickle( - [](const Int& t) { // __getstate__ + [](const Int &t) { // __getstate__ /* Return a tuple that fully encodes the state of the object */ return py::make_tuple(py::int_(t.nbits())); }, - [](const py::tuple& t) { // __setstate__ + [](const py::tuple &t) { // __setstate__ if (t.size() != 1) { throw std::runtime_error("Invalid state!"); } @@ -631,11 +631,11 @@ REGISTER_PYBIND_DEFINE( .def(py::init()) .def(py::init(), py::arg("nbits")) .def(py::pickle( - [](const UInt& t) { // __getstate__ + [](const UInt &t) { // __getstate__ /* Return a tuple that fully encodes the state of the object */ return py::make_tuple(py::int_(t.nbits())); }, - [](const py::tuple& t) { // __setstate__ + [](const py::tuple &t) { // __setstate__ if (t.size() != 1) { throw std::runtime_error("Invalid state!"); } @@ -647,11 +647,11 @@ REGISTER_PYBIND_DEFINE( .def(py::init()) .def(py::init(), py::arg("nbits")) .def(py::pickle( - [](const Float& t) { // __getstate__ + [](const Float &t) { // __getstate__ /* Return a tuple that fully encodes the state of the object */ return py::make_tuple(py::int_(t.nbits())); }, - [](const py::tuple& t) { // __setstate__ + [](const py::tuple &t) { // __setstate__ if (t.size() != 1) { throw std::runtime_error("Invalid state!"); } @@ -670,11 +670,11 @@ REGISTER_PYBIND_DEFINE( .def(py::init(), py::arg("element")) .def("element_type", &TensorType::element) .def(py::pickle( - [](const TensorType& t) { // __getstate__ + [](const TensorType &t) { // __getstate__ /* Return a tuple that fully encodes the state of the object */ return py::make_tuple(py::int_(static_cast(t.element()->type_id()))); }, - [](const py::tuple& t) { // __setstate__ + [](const py::tuple &t) { // __setstate__ if (t.size() != 1) { throw std::runtime_error("Invalid state!"); } diff --git a/mindspore/ccsrc/ir/dtype.h b/mindspore/ccsrc/ir/dtype.h index e3e2099b5e..cefdf42099 100644 --- a/mindspore/ccsrc/ir/dtype.h +++ b/mindspore/ccsrc/ir/dtype.h @@ -60,7 +60,7 @@ using StringPtr = std::shared_ptr; class Keyword : public Object { public: Keyword() : Object(kObjectTypeKeyword, false), key_(""), value_(nullptr) {} - Keyword(const std::string& key, const TypePtr& value) : Object(kObjectTypeKeyword, false), key_(key), value_(value) {} + Keyword(const std::string &key, const TypePtr &value) : Object(kObjectTypeKeyword, false), key_(key), value_(value) {} ~Keyword() override = default; MS_DECLARE_PARENT(Keyword, Object) @@ -70,7 +70,7 @@ class Keyword : public Object { std::string ToString() const override; std::string DumpText() const override; - bool operator==(const Type& other) const override; + bool operator==(const Type &other) const override; std::string GetKey() const { return key_; } TypePtr GetValue() const { return value_; } @@ -84,7 +84,7 @@ using KeywordPtr = std::shared_ptr; class Slice : public Object { public: Slice() : Object(kObjectTypeSlice), start_(nullptr), stop_(nullptr), step_(nullptr) {} - Slice(const TypePtr& start, const TypePtr& stop, const TypePtr& step) + Slice(const TypePtr &start, const TypePtr &stop, const TypePtr &step) : Object(kObjectTypeSlice, false), start_(start), stop_(stop), step_(step) {} ~Slice() override = default; @@ -95,7 +95,7 @@ class Slice : public Object { std::string ToString() const override; std::string DumpText() const override; - bool operator==(const Type& other) const override; + bool operator==(const Type &other) const override; TypePtr get_start() const { return start_; } TypePtr get_stop() const { return stop_; } @@ -111,19 +111,19 @@ using SlicePtr = std::shared_ptr; class TensorType : public Object { public: TensorType() : Object(kObjectTypeTensorType) {} - explicit TensorType(const TypePtr& ele) : Object(kObjectTypeTensorType, false), element_type_(ele) {} + explicit TensorType(const TypePtr &ele) : Object(kObjectTypeTensorType, false), element_type_(ele) {} ~TensorType() override = default; MS_DECLARE_PARENT(TensorType, Object) TypeId generic_type_id() const override { return kObjectTypeTensorType; } const TypePtr element() const { return element_type_; } - void set_element(const TypePtr& element_type) { element_type_ = element_type; } + void set_element(const TypePtr &element_type) { element_type_ = element_type; } TypePtr DeepCopy() const override; std::string ToString() const override; std::string ToReprString() const override { return "tensor"; } std::string DumpText() const override; - bool operator==(const Type& other) const override; + bool operator==(const Type &other) const override; private: TypePtr element_type_; @@ -133,7 +133,7 @@ using TensorTypePtr = std::shared_ptr; class Function : public Object { public: Function(); - Function(const std::vector& args, const TypePtr retval); + Function(const std::vector &args, const TypePtr retval); ~Function() override = default; MS_DECLARE_PARENT(Function, Object) @@ -141,11 +141,11 @@ class Function : public Object { // Add temporarily for return abstraction to avoid type checking. bool IsTransparent() const { return (args_.empty()) && (retval_ == nullptr); } - const std::vector& args() const { return args_; } - const TypePtr& retval() const { return retval_; } + const std::vector &args() const { return args_; } + const TypePtr &retval() const { return retval_; } TypePtr DeepCopy() const override; - bool operator==(const Type& other) const override; + bool operator==(const Type &other) const override; std::string ToString() const override; std::string ToReprString() const override { return "function"; } @@ -158,7 +158,7 @@ using FunctionPtr = std::shared_ptr; class JTagged : public Object { public: JTagged() : Object(kObjectTypeJTagged) {} - explicit JTagged(const TypePtr& subtype) : Object(kObjectTypeJTagged, false), subtype_(subtype) {} + explicit JTagged(const TypePtr &subtype) : Object(kObjectTypeJTagged, false), subtype_(subtype) {} ~JTagged() override = default; MS_DECLARE_PARENT(JTagged, Object) @@ -213,7 +213,7 @@ using TypeTypePtr = std::shared_ptr; class Problem : public Type { public: Problem() : Type(kMetaTypeProblem), kind_(Named("unknown")) {} - explicit Problem(const Named& kind) : Type(kMetaTypeProblem), kind_(kind) {} + explicit Problem(const Named &kind) : Type(kMetaTypeProblem), kind_(kind) {} ~Problem() override = default; MS_DECLARE_PARENT(Problem, Type) @@ -222,7 +222,7 @@ class Problem : public Type { std::string ToString() const override { return kind_.name(); } std::string DumpText() const override { return "ProblemType"; } - friend std::ostream& operator<<(std::ostream& os, const std::shared_ptr problem); + friend std::ostream &operator<<(std::ostream &os, const std::shared_ptr problem); private: Named kind_; @@ -246,29 +246,29 @@ using ExternalPtr = std::shared_ptr; // helper template template -TypePtr Clone(const T& t) { +TypePtr Clone(const T &t) { return t.Clone(); } -TypePtr StringToType(const std::string& type_name); +TypePtr StringToType(const std::string &type_name); // Judge whether x is predicate or is a subclass of predicate. -bool IsIdentidityOrSubclass(TypePtr const& x, TypePtr const& base_type); +bool IsIdentidityOrSubclass(TypePtr const &x, TypePtr const &base_type); // Whether t1 is identity or a subclass of t2. -bool IsSubType(TypePtr const& t1, TypePtr const& t2 = nullptr); +bool IsSubType(TypePtr const &t1, TypePtr const &t2 = nullptr); struct TypeHasher { - std::size_t operator()(TypePtr const& type) const; + std::size_t operator()(TypePtr const &type) const; }; struct TypeListHasher { - std::size_t operator()(const TypePtrList& type_list) const; + std::size_t operator()(const TypePtrList &type_list) const; }; struct TypeEqual { - bool operator()(TypePtr const& t1, TypePtr const& t2) const; + bool operator()(TypePtr const &t1, TypePtr const &t2) const; }; struct TypeListEqual { - bool operator()(TypePtrList const& lhs, TypePtrList const& rhs) const; + bool operator()(TypePtrList const &lhs, TypePtrList const &rhs) const; }; extern const TypePtr kTypeExternal; diff --git a/mindspore/ccsrc/ir/dtype/container.cc b/mindspore/ccsrc/ir/dtype/container.cc index 8bca29f793..3f8244c2e3 100644 --- a/mindspore/ccsrc/ir/dtype/container.cc +++ b/mindspore/ccsrc/ir/dtype/container.cc @@ -24,7 +24,7 @@ #include "pybind_api/export_flags.h" namespace mindspore { -static std::string DumpTypeVector(const std::vector& elements, bool is_dumptext) { +static std::string DumpTypeVector(const std::vector &elements, bool is_dumptext) { std::ostringstream oss; bool begin = true; int cnt = 0; @@ -65,7 +65,7 @@ TypePtr List::DeepCopy() const { } else { TypePtrList elements; (void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(elements), - [](const TypePtr& ele) { return ele->DeepCopy(); }); + [](const TypePtr &ele) { return ele->DeepCopy(); }); auto copy = std::make_shared(elements); return copy; } @@ -78,11 +78,11 @@ const TypePtr List::operator[](std::size_t dim) const { return elements_[dim]; } -bool List::operator==(const Type& other) const { +bool List::operator==(const Type &other) const { if (!IsSameObjectType(*this, other)) { return false; } - const List& other_list = static_cast(other); + const List &other_list = static_cast(other); if (elements_.size() != other_list.elements_.size()) { return false; } @@ -94,8 +94,8 @@ bool List::operator==(const Type& other) const { return true; } -Class::Class(const Named& tag, const ClassAttrVector& attributes, - const std::unordered_map& methods) +Class::Class(const Named &tag, const ClassAttrVector &attributes, + const std::unordered_map &methods) : Object(kObjectTypeClass, false), attributes_(attributes), tag_(tag), methods_(methods) {} std::string List::ToString() const { @@ -122,7 +122,7 @@ std::string List::DumpText() const { return buffer.str(); } -bool Class::operator==(const Type& other) const { +bool Class::operator==(const Type &other) const { // Class is cached for each pyobj in ParseDataClass, so ClassPtr is one by one map to pyobj. return &other == this; } @@ -143,7 +143,7 @@ std::string Class::ToString() const { } else { bool begin = true; buffer << "cls." << tag_ << "["; - for (auto& attr : attributes_) { + for (auto &attr : attributes_) { if (!begin) { buffer << ", "; } else { @@ -163,7 +163,7 @@ std::string Class::DumpText() const { } else { bool begin = true; buffer << "Cls." << tag_ << "["; - for (auto& attr : attributes_) { + for (auto &attr : attributes_) { if (!begin) { buffer << ", "; } else { @@ -182,17 +182,17 @@ TypePtr Tuple::DeepCopy() const { } else { TypePtrList elements; (void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(elements), - [](const TypePtr& ele) { return ele->DeepCopy(); }); + [](const TypePtr &ele) { return ele->DeepCopy(); }); auto copy = std::make_shared(elements); return copy; } } -bool Tuple::operator==(const Type& other) const { +bool Tuple::operator==(const Type &other) const { if (!IsSameObjectType(*this, other)) { return false; } - auto other_tuple = static_cast(other); + auto other_tuple = static_cast(other); if (elements_.size() != other_tuple.elements_.size()) { return false; } @@ -242,7 +242,7 @@ TypePtr Dictionary::DeepCopy() const { std::vector> kv; (void)std::transform( key_values_.begin(), key_values_.end(), std::back_inserter(kv), - [](const std::pair& item) { return std::make_pair(item.first, item.second->DeepCopy()); }); + [](const std::pair &item) { return std::make_pair(item.first, item.second->DeepCopy()); }); return std::make_shared(kv); } } @@ -259,7 +259,7 @@ std::string Dictionary::ToString() const { std::ostringstream buffer; std::vector keys; std::vector values; - for (const auto& kv : key_values_) { + for (const auto &kv : key_values_) { keys.push_back(kv.first); values.push_back(kv.second); } @@ -276,12 +276,12 @@ std::string Dictionary::ToString() const { std::string Dictionary::DumpText() const { return ToString(); } -bool Dictionary::operator==(const mindspore::Type& other) const { +bool Dictionary::operator==(const mindspore::Type &other) const { if (!IsSameObjectType(*this, other)) { return false; } - const auto& other_dict = static_cast(other); + const auto &other_dict = static_cast(other); if (key_values_.size() != other_dict.key_values_.size()) { return false; } diff --git a/mindspore/ccsrc/ir/dtype/container.h b/mindspore/ccsrc/ir/dtype/container.h index 04ed484cf7..0612d24c4d 100644 --- a/mindspore/ccsrc/ir/dtype/container.h +++ b/mindspore/ccsrc/ir/dtype/container.h @@ -40,10 +40,10 @@ namespace mindspore { class List : public Object { public: List() : Object(kObjectTypeList) {} - List(const std::initializer_list& objs) + List(const std::initializer_list &objs) : Object(kObjectTypeList, false), elements_(objs.begin(), objs.end()) {} // Shadow copy; - explicit List(const TypePtrList& obj) : Object(kObjectTypeList, false), elements_(obj) {} + explicit List(const TypePtrList &obj) : Object(kObjectTypeList, false), elements_(obj) {} ~List() override {} MS_DECLARE_PARENT(List, Object) @@ -51,7 +51,7 @@ class List : public Object { TypeId generic_type_id() const override { return kObjectTypeList; } TypePtr DeepCopy() const override; - bool operator==(const Type& other) const override; + bool operator==(const Type &other) const override; std::size_t size() const { return elements_.size(); } TypePtrList elements() const { return elements_; } std::string ToString() const override; @@ -68,22 +68,22 @@ using ClassAttrVector = std::vector>; class Class : public Object { public: Class() : Object(kObjectTypeClass), tag_(Named("Class")) {} - Class(const Named& tag, const ClassAttrVector& attributes, const std::unordered_map& methods); + Class(const Named &tag, const ClassAttrVector &attributes, const std::unordered_map &methods); ~Class() override {} MS_DECLARE_PARENT(Class, Object) TypeId generic_type_id() const override { return kObjectTypeClass; } - bool operator==(const Type& other) const override; + bool operator==(const Type &other) const override; TypePtr DeepCopy() const override; std::string ToString() const override; std::string DumpText() const override; - void set_value(const std::unordered_map& v) { attributes_value_ = v; } + void set_value(const std::unordered_map &v) { attributes_value_ = v; } Named tag() { return tag_; } std::unordered_map GetValue() { return attributes_value_; } std::unordered_map methods() { return methods_; } - ClassAttrVector& GetAttributes() { return attributes_; } + ClassAttrVector &GetAttributes() { return attributes_; } ClassAttrVector attributes_; @@ -99,11 +99,11 @@ class Tuple : public Object { public: Tuple() : Object(kObjectTypeTuple) {} // usage : Tuple t = {std::make_shared(), std::make_shared(32)}; - Tuple(const std::initializer_list& objs) + Tuple(const std::initializer_list &objs) : Object(kObjectTypeTuple, false), elements_(objs.begin(), objs.end()) {} // Shadow copy - explicit Tuple(const TypePtrList& objs) : Object(kObjectTypeTuple, false), elements_(objs.begin(), objs.end()) {} + explicit Tuple(const TypePtrList &objs) : Object(kObjectTypeTuple, false), elements_(objs.begin(), objs.end()) {} ~Tuple() override {} MS_DECLARE_PARENT(Tuple, Object) @@ -115,7 +115,7 @@ class Tuple : public Object { std::string ToReprString() const override { return "tuple_"; } std::string DumpText() const override; const TypePtr operator[](size_t dim) const; - bool operator==(const Type& other) const override; + bool operator==(const Type &other) const override; TypePtrList elements() const { return elements_; } std::size_t size() const { return elements_.size(); } @@ -128,7 +128,7 @@ using TuplePtr = std::shared_ptr; class Dictionary : public Object { public: Dictionary() : Object(kObjectTypeDictionary) {} - explicit Dictionary(const std::vector>& key_values) + explicit Dictionary(const std::vector> &key_values) : Object(kObjectTypeDictionary, false), key_values_(key_values) {} ~Dictionary() override {} @@ -136,7 +136,7 @@ class Dictionary : public Object { TypeId generic_type_id() const override { return kObjectTypeDictionary; } - bool operator==(const Type& other) const override; + bool operator==(const Type &other) const override; TypePtr DeepCopy() const override; std::string ToString() const override; std::string DumpText() const override; diff --git a/mindspore/ccsrc/ir/dtype/number.cc b/mindspore/ccsrc/ir/dtype/number.cc index d9ef6bb3bd..44ac9e8e6a 100644 --- a/mindspore/ccsrc/ir/dtype/number.cc +++ b/mindspore/ccsrc/ir/dtype/number.cc @@ -24,11 +24,11 @@ #include "pybind_api/export_flags.h" namespace mindspore { -bool Number::operator==(const Type& other) const { +bool Number::operator==(const Type &other) const { if (!IsSameObjectType(*this, other)) { return false; } - auto other_number = static_cast(other); + auto other_number = static_cast(other); return ((number_type_ == other_number.number_type_) && (nbits_ == other_number.nbits_)); } diff --git a/mindspore/ccsrc/ir/dtype/number.h b/mindspore/ccsrc/ir/dtype/number.h index cb3b0a607c..3930f51d73 100644 --- a/mindspore/ccsrc/ir/dtype/number.h +++ b/mindspore/ccsrc/ir/dtype/number.h @@ -49,12 +49,12 @@ class Number : public Object { TypeId type_id() const override { return number_type_; } TypeId generic_type_id() const override { return kObjectTypeNumber; } - bool operator==(const Type& other) const override; + bool operator==(const Type &other) const override; TypePtr DeepCopy() const override { return std::make_shared(); } std::string ToString() const override { return "Number"; } std::string ToReprString() const override { return "number"; } std::string DumpText() const override { return "Number"; } - std::string GetTypeName(const std::string& type_name) const { + std::string GetTypeName(const std::string &type_name) const { std::ostringstream oss; oss << type_name; if (nbits() != 0) { diff --git a/mindspore/ccsrc/ir/dtype/ref.h b/mindspore/ccsrc/ir/dtype/ref.h index 7f1dc4a95f..7d8159289f 100644 --- a/mindspore/ccsrc/ir/dtype/ref.h +++ b/mindspore/ccsrc/ir/dtype/ref.h @@ -51,7 +51,7 @@ class RefKeyType : public Object { class RefType : public Object { public: RefType() : Object(kObjectTypeRef) {} - RefType(const TypePtr& subtype, const TypePtr& subtype_origin) + RefType(const TypePtr &subtype, const TypePtr &subtype_origin) : Object(kObjectTypeRef, false), subtype_(subtype), subtype_origin_(subtype_origin) {} ~RefType() override {} MS_DECLARE_PARENT(RefType, Object) diff --git a/mindspore/ccsrc/ir/dtype/type.cc b/mindspore/ccsrc/ir/dtype/type.cc index 6fbd7f8111..30bf0c8e3f 100644 --- a/mindspore/ccsrc/ir/dtype/type.cc +++ b/mindspore/ccsrc/ir/dtype/type.cc @@ -69,7 +69,7 @@ TypeId FloatBitsToTypeId(const int nbits) { } } -const char* MetaIdLabel(const TypeId& v) { +const char *MetaIdLabel(const TypeId &v) { switch (v) { case kTypeUnknown: return "kTypeUnknown"; @@ -92,7 +92,7 @@ const char* MetaIdLabel(const TypeId& v) { } } -const char* ObjectIdLabel(const TypeId& v) { +const char *ObjectIdLabel(const TypeId &v) { switch (v) { case kObjectTypeNumber: return "kObjectTypeNumber"; @@ -129,7 +129,7 @@ const char* ObjectIdLabel(const TypeId& v) { } } -const char* NumberIdLabel(const TypeId& v) { +const char *NumberIdLabel(const TypeId &v) { switch (v) { case kNumberTypeBool: return "kNumberTypeBool"; @@ -166,7 +166,7 @@ const char* NumberIdLabel(const TypeId& v) { } } -const char* TypeIdLabel(const TypeId& v) { +const char *TypeIdLabel(const TypeId &v) { if (v < kMetaTypeEnd) { return MetaIdLabel(v); } else { @@ -190,14 +190,14 @@ TypeId NormalizeTypeId(const TypeId type_id) { } } -bool IsSameObjectType(const Type& lhs, const Type& rhs) { +bool IsSameObjectType(const Type &lhs, const Type &rhs) { if ((lhs.meta_type() != kMetaTypeObject) || (rhs.meta_type() != kMetaTypeObject)) { return false; } return lhs.object_type() == rhs.object_type(); } -size_t GetTypeByte(const TypePtr& type_ptr) { +size_t GetTypeByte(const TypePtr &type_ptr) { if (type_ptr && type_ptr->isa()) { auto number = dyn_cast(type_ptr); if (!number) { @@ -212,9 +212,9 @@ size_t GetTypeByte(const TypePtr& type_ptr) { } } -bool Type::operator==(const Value& other) const { +bool Type::operator==(const Value &other) const { if (other.isa()) { - auto other_type = static_cast(&other); + auto other_type = static_cast(&other); return *this == *other_type; } else { return false; @@ -226,12 +226,12 @@ abstract::AbstractBasePtr Type::ToAbstract() { return ptr; } -std::ostream& operator<<(std::ostream& os, const Type& type) { +std::ostream &operator<<(std::ostream &os, const Type &type) { os << type.ToString(); return os; } -std::ostream& operator<<(std::ostream& os, const TypePtr type) { +std::ostream &operator<<(std::ostream &os, const TypePtr type) { os << type->ToString(); return os; } @@ -244,17 +244,17 @@ bool Object::equal(const TypePtr other) const { return false; } -std::ostream& operator<<(std::ostream& os, const Object& obj) { +std::ostream &operator<<(std::ostream &os, const Object &obj) { os << obj.ToString(); return os; } -std::ostream& operator<<(std::ostream& os, const std::shared_ptr obj) { +std::ostream &operator<<(std::ostream &os, const std::shared_ptr obj) { os << obj->ToString(); return os; } -std::ostream& operator<<(std::ostream& os, const TypePtrList& types) { +std::ostream &operator<<(std::ostream &os, const TypePtrList &types) { os << "["; for (size_t i = 0; i < types.size(); ++i) { if (i > 0) { diff --git a/mindspore/ccsrc/ir/dtype/type.h b/mindspore/ccsrc/ir/dtype/type.h index 9454596538..0528bccf03 100644 --- a/mindspore/ccsrc/ir/dtype/type.h +++ b/mindspore/ccsrc/ir/dtype/type.h @@ -95,10 +95,10 @@ enum TypeId : int { TypeId IntBitsToTypeId(const int nbits); TypeId UIntBitsToTypeId(const int nbits); TypeId FloatBitsToTypeId(const int nbits); -const char* TypeIdLabel(const TypeId& v); +const char *TypeIdLabel(const TypeId &v); TypeId NormalizeTypeId(const TypeId type_id); -bool IsSameObjectType(const Type& lhs, const Type& rhs); -size_t GetTypeByte(const TypePtr& type_ptr); +bool IsSameObjectType(const Type &lhs, const Type &rhs); +size_t GetTypeByte(const TypePtr &type_ptr); // Base class for all types // forward declaration. @@ -110,14 +110,14 @@ class Type : public Value { ~Type() override = default; MS_DECLARE_PARENT(Type, Value) - bool operator==(const Value& other) const override; + bool operator==(const Value &other) const override; TypeId meta_type() const { return meta_type_; } virtual TypeId type_id() const { return meta_type_; } virtual TypeId generic_type_id() const { return kMetaTypeType; } - virtual bool operator!=(const Type& other) const { return !(*this == other); } - virtual bool operator==(const Type& other) const { return this->type_id() == other.type_id(); } + virtual bool operator!=(const Type &other) const { return !(*this == other); } + virtual bool operator==(const Type &other) const { return this->type_id() == other.type_id(); } virtual bool equal(const TypePtr other) const { return *this == *other; } virtual TypeId object_type() const { return kTypeUnknown; } @@ -134,8 +134,8 @@ class Type : public Value { bool IsUnknown() const { return (meta_type_ == kMetaTypeType); } bool IsGeneric() const { return is_generic_; } abstract::AbstractBasePtr ToAbstract() override; - friend std::ostream& operator<<(std::ostream& os, const Type& type); - friend std::ostream& operator<<(std::ostream& os, const TypePtr type); + friend std::ostream &operator<<(std::ostream &os, const Type &type); + friend std::ostream &operator<<(std::ostream &os, const TypePtr type); const bool parse_info_ = true; @@ -163,14 +163,14 @@ class Object : public Type { bool equal(const TypePtr other) const override; std::string ToString() const override { return std::string("Object:") + TypeIdLabel(object_type_); } - friend std::ostream& operator<<(std::ostream& os, const Object& obj); - friend std::ostream& operator<<(std::ostream& os, const std::shared_ptr obj); + friend std::ostream &operator<<(std::ostream &os, const Object &obj); + friend std::ostream &operator<<(std::ostream &os, const std::shared_ptr obj); private: const TypeId object_type_; }; -std::ostream& operator<<(std::ostream& os, const TypePtrList& types); +std::ostream &operator<<(std::ostream &os, const TypePtrList &types); } // namespace mindspore #endif // MINDSPORE_CCSRC_IR_DTYPE_TYPE_H_ diff --git a/mindspore/ccsrc/ir/func_graph.cc b/mindspore/ccsrc/ir/func_graph.cc index 93fd9c0936..8a58f320f1 100644 --- a/mindspore/ccsrc/ir/func_graph.cc +++ b/mindspore/ccsrc/ir/func_graph.cc @@ -61,7 +61,7 @@ FuncGraph::FuncGraph() AbstractFunctionPtr FuncGraph::abstract() { AbstractBasePtrList args_spec_list; - for (auto& p : parameters_) { + for (auto &p : parameters_) { MS_EXCEPTION_IF_NULL(p); if (p->abstract() == nullptr) { MS_LOG(ERROR) << "Error!!"; @@ -78,7 +78,7 @@ AbstractFunctionPtr FuncGraph::abstract() { return std::make_shared(args_spec_list, output()->abstract()); } -abstract::AbstractBasePtr FuncGraph::MakeAbstractClosure(const abstract::AnalysisContextPtr& context) { +abstract::AbstractBasePtr FuncGraph::MakeAbstractClosure(const abstract::AnalysisContextPtr &context) { AnalysisContextPtr temp_context = context; if (temp_context == nullptr) { temp_context = abstract::AnalysisContext::DummyContext(); @@ -96,7 +96,7 @@ AnfNodePtr FuncGraph::output() const { } } -void FuncGraph::set_output(const AnfNodePtr& value, bool force_new_ret) { +void FuncGraph::set_output(const AnfNodePtr &value, bool force_new_ret) { if (force_new_ret || return_ == nullptr) { std::vector params({NewValueNode(prim::kPrimReturn), value}); FuncGraphPtr this_graph = shared_from_base(); @@ -125,7 +125,7 @@ ParameterPtr FuncGraph::add_parameter() { return p; } -void FuncGraph::add_parameter(const ParameterPtr& p) { +void FuncGraph::add_parameter(const ParameterPtr &p) { if (manager_.lock()) { std::vector new_params = parameters_; new_params.push_back(p); @@ -135,7 +135,7 @@ void FuncGraph::add_parameter(const ParameterPtr& p) { } } -ParameterPtr FuncGraph::AddWeightParameter(const std::string& name) { +ParameterPtr FuncGraph::AddWeightParameter(const std::string &name) { FuncGraphPtr this_graph = shared_from_base(); ParameterPtr p = std::make_shared(this_graph); p->set_name(name); @@ -154,14 +154,14 @@ ParameterPtr FuncGraph::AddWeightParameter(const std::string& name) { return p; } -bool FuncGraph::has_flag(const std::string& flag) { +bool FuncGraph::has_flag(const std::string &flag) { if (flags_.count(flag)) { return flags_[flag]; } return false; } -CNodePtr FuncGraph::NewCNode(const std::vector& inputs) { +CNodePtr FuncGraph::NewCNode(const std::vector &inputs) { CNodePtr cnode = std::make_shared(inputs, shared_from_base()); if (has_flag(GRAPH_FLAG_HAS_EFFECT)) { order_.push_back(cnode); @@ -170,7 +170,7 @@ CNodePtr FuncGraph::NewCNode(const std::vector& inputs) { return cnode; } -CNodePtr FuncGraph::NewCNodeWithScope(const std::vector& inputs, const ScopePtr& scope) { +CNodePtr FuncGraph::NewCNodeWithScope(const std::vector &inputs, const ScopePtr &scope) { CNodePtr app = NewCNode(inputs); app->set_scope(scope); return app; @@ -178,13 +178,13 @@ CNodePtr FuncGraph::NewCNodeWithScope(const std::vector& inputs, con void FuncGraph::DumpCNodeList() { MS_LOG(INFO) << "FuncGraph " << ToString() << " has following CNode in code order:"; - for (const auto& cnode : order_) { + for (const auto &cnode : order_) { MS_LOG(INFO) << cnode->DebugString(); } } std::string FuncGraph::ToString() const { - return mindspore::label_manage::Label(const_cast(this)->shared_from_base()->debug_info()); + return mindspore::label_manage::Label(const_cast(this)->shared_from_base()->debug_info()); } GraphDebugInfoPtr FuncGraph::debug_info() { @@ -195,38 +195,38 @@ GraphDebugInfoPtr FuncGraph::debug_info() { return this->debug_info_; } -const AnfNodeSet& FuncGraph::nodes() { +const AnfNodeSet &FuncGraph::nodes() { auto mng = manager_.lock(); MS_EXCEPTION_IF_NULL(mng); - auto& nodes = mng->nodes(); + auto &nodes = mng->nodes(); return nodes[shared_from_base()]; } -const AnfNodeCounterMap& FuncGraph::value_nodes() { +const AnfNodeCounterMap &FuncGraph::value_nodes() { auto mng = manager_.lock(); MS_EXCEPTION_IF_NULL(mng); - auto& cts = mng->valuenodes(); + auto &cts = mng->valuenodes(); return cts[shared_from_base()]; } -const AnfNodeCounterMap& FuncGraph::free_variables_direct() { +const AnfNodeCounterMap &FuncGraph::free_variables_direct() { auto mng = manager_.lock(); MS_EXCEPTION_IF_NULL(mng); - auto& fv_direct = mng->free_variables_direct(); + auto &fv_direct = mng->free_variables_direct(); return fv_direct[shared_from_base()]; } -const BaseRefCounterMap& FuncGraph::free_variables_total() { +const BaseRefCounterMap &FuncGraph::free_variables_total() { auto mng = manager_.lock(); MS_EXCEPTION_IF_NULL(mng); - auto& fv_total = mng->free_variables_total(); + auto &fv_total = mng->free_variables_total(); return fv_total[shared_from_base()]; } std::vector FuncGraph::free_variables_nodes() { std::vector nodes; - const auto& fv_total = this->free_variables_total(); - for (auto& p : fv_total) { + const auto &fv_total = this->free_variables_total(); + for (auto &p : fv_total) { auto key = p.first; if (utils::isa(key)) { nodes.push_back(utils::cast(key)); @@ -238,8 +238,8 @@ std::vector FuncGraph::free_variables_nodes() { std::vector FuncGraph::free_variables_func_graphs() { std::vector func_graphs; - const auto& fv_total = this->free_variables_total(); - for (auto& p : fv_total) { + const auto &fv_total = this->free_variables_total(); + for (auto &p : fv_total) { auto key = p.first; if (utils::isa(key)) { func_graphs.push_back(utils::cast(key)); @@ -249,31 +249,31 @@ std::vector FuncGraph::free_variables_func_graphs() { return func_graphs; } -const FuncGraphCounterMap& FuncGraph::func_graphs_used() { +const FuncGraphCounterMap &FuncGraph::func_graphs_used() { auto mng = manager_.lock(); MS_EXCEPTION_IF_NULL(mng); - auto& used = mng->func_graphs_used(); + auto &used = mng->func_graphs_used(); return used[shared_from_base()]; } -const FuncGraphSet& FuncGraph::func_graphs_used_total() { +const FuncGraphSet &FuncGraph::func_graphs_used_total() { auto mng = manager_.lock(); MS_EXCEPTION_IF_NULL(mng); - auto& used = mng->func_graphs_used_total(shared_from_base()); + auto &used = mng->func_graphs_used_total(shared_from_base()); return used; } -const FuncGraphCounterMap& FuncGraph::func_graph_users() { +const FuncGraphCounterMap &FuncGraph::func_graph_users() { auto mng = manager_.lock(); MS_EXCEPTION_IF_NULL(mng); - auto& users = mng->func_graph_users(); + auto &users = mng->func_graph_users(); return users[shared_from_base()]; } -const AnfNodeCounterMap& FuncGraph::func_graph_user_cnodes() { +const AnfNodeCounterMap &FuncGraph::func_graph_user_cnodes() { auto mng = manager_.lock(); MS_EXCEPTION_IF_NULL(mng); - auto& users = mng->func_graph_user_cnodes(); + auto &users = mng->func_graph_user_cnodes(); return users[shared_from_base()]; } @@ -288,13 +288,13 @@ FuncGraphPtr FuncGraph::parent() { return mng->parent(shared_from_base()); } -const FuncGraphSet& FuncGraph::children() { +const FuncGraphSet &FuncGraph::children() { auto mng = manager_.lock(); MS_EXCEPTION_IF_NULL(mng); return mng->children(shared_from_base()); } -const FuncGraphSet& FuncGraph::scope() { +const FuncGraphSet &FuncGraph::scope() { auto mng = manager_.lock(); MS_EXCEPTION_IF_NULL(mng); return mng->scopes(shared_from_base()); @@ -312,9 +312,9 @@ std::shared_ptr> FuncGraph::recursive_graphs() { return mng->recursive_graphs(shared_from_base()); } -void FuncGraph::DumpFuncGraph(const std::string& path) { draw::Draw(path + ".dot", shared_from_base()); } +void FuncGraph::DumpFuncGraph(const std::string &path) { draw::Draw(path + ".dot", shared_from_base()); } -AnfNodePtr FuncGraph::GetDefaultValueByName(const std::string& name) { +AnfNodePtr FuncGraph::GetDefaultValueByName(const std::string &name) { auto itr = this->parameter_default_value_.find(name); if (itr == parameter_default_value_.end()) { return nullptr; @@ -330,9 +330,9 @@ AnfNodePtr FuncGraph::GetDefaultValueByName(const std::string& name) { } // set the default values -void FuncGraph::SetDefaultValues(const std::vector& name_list, const std::vector& value_list) { +void FuncGraph::SetDefaultValues(const std::vector &name_list, const std::vector &value_list) { auto all_is_null = std::all_of(value_list.begin(), value_list.end(), - [](const AnfNodePtr& node) { return IsValueNode(node); }); + [](const AnfNodePtr &node) { return IsValueNode(node); }); if (value_list.empty()) { all_is_null = true; } @@ -348,7 +348,7 @@ void FuncGraph::ClearDefaultValues() { parameter_default_value_.clear(); } size_t FuncGraph::GetDefaultValueCount() { int null_count = std::count_if(parameter_default_value_.begin(), parameter_default_value_.end(), - [](const std::pair& pair) { return IsValueNode(pair.second); }); + [](const std::pair &pair) { return IsValueNode(pair.second); }); return parameter_default_value_.size() - IntToSize(null_count); } @@ -425,7 +425,7 @@ int FuncGraph::GetPositionalArgsCount() const { return count - kwonlyargs_count_ - SizeToInt(hyper_param_count_); } -AnfNodePtr FuncGraph::GetParameterByName(const std::string& name) { +AnfNodePtr FuncGraph::GetParameterByName(const std::string &name) { for (size_t i = 0; i < parameters_.size(); ++i) { MS_EXCEPTION_IF_NULL(parameters_[i]); auto param_cast = parameters_[i]->cast(); @@ -437,9 +437,9 @@ AnfNodePtr FuncGraph::GetParameterByName(const std::string& name) { return nullptr; } -void FuncGraph::GenerateVarParams(const FuncGraphPtr& specialized_graph, - std::vector* specialized_parameter_list, - std::unordered_map* repl_nodes, int variable_args_count, +void FuncGraph::GenerateVarParams(const FuncGraphPtr &specialized_graph, + std::vector *specialized_parameter_list, + std::unordered_map *repl_nodes, int variable_args_count, int pos_args_input_count) { // if there is variable argument, pass the input arguments that does not match positional args to it as a tuple if (specialized_graph->has_vararg()) { @@ -472,14 +472,14 @@ void FuncGraph::GenerateVarParams(const FuncGraphPtr& specialized_graph, } } -void FuncGraph::GenerateKwParams(const FuncGraphPtr& specialized_graph, - std::vector* specialized_parameter_list, - const std::vector& kwarg_list, - std::unordered_map* repl_nodes) { +void FuncGraph::GenerateKwParams(const FuncGraphPtr &specialized_graph, + std::vector *specialized_parameter_list, + const std::vector &kwarg_list, + std::unordered_map *repl_nodes) { std::vector kwarg_keys_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)}; std::vector kwarg_values_tuple_nodes = {NewValueNode(prim::kPrimMakeTuple)}; - for (const auto& kwarg : kwarg_list) { + for (const auto &kwarg : kwarg_list) { MS_EXCEPTION_IF_NULL(kwarg); std::string kw_param_name = kwarg->get_key(); MS_EXCEPTION_IF_NULL(specialized_graph); @@ -493,7 +493,7 @@ void FuncGraph::GenerateKwParams(const FuncGraphPtr& specialized_graph, std::string param_name = specialized_graph->GetVariableKwargName() + "[" + kw_param_name + "]"; MS_EXCEPTION_IF_NULL(specialized_parameter_list); auto find_kw_arg_in_list = std::any_of(specialized_parameter_list->begin(), specialized_parameter_list->end(), - [param_name](const AnfNodePtr& node) { + [param_name](const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); auto param = node->cast(); return param != nullptr && param->name() == param_name; @@ -526,10 +526,10 @@ void FuncGraph::GenerateKwParams(const FuncGraphPtr& specialized_graph, GenerateKwargReplNode(specialized_graph, repl_nodes, kwarg_keys_tuple_nodes, kwarg_values_tuple_nodes); } -void FuncGraph::GenerateKwargReplNode(const FuncGraphPtr& specialized_graph, - std::unordered_map* repl_nodes, - const std::vector& kwarg_keys_tuple_nodes, - const std::vector& kwarg_values_tuple_nodes) { +void FuncGraph::GenerateKwargReplNode(const FuncGraphPtr &specialized_graph, + std::unordered_map *repl_nodes, + const std::vector &kwarg_keys_tuple_nodes, + const std::vector &kwarg_values_tuple_nodes) { if (has_kwarg()) { MS_EXCEPTION_IF_NULL(specialized_graph); TraceManager::DebugTrace( @@ -544,7 +544,7 @@ void FuncGraph::GenerateKwargReplNode(const FuncGraphPtr& specialized_graph, } } -bool FuncGraph::NeedGenerate(const std::vector& kwarg_list) { +bool FuncGraph::NeedGenerate(const std::vector &kwarg_list) { // if the function does not have any vararg/kwarg/kwonly/default value/kw args input // return the original graph if (!has_vararg() && kwonlyargs_count() == 0 && !has_kwarg() && GetDefaultValueCount() == 0 && kwarg_list.empty()) { @@ -558,9 +558,9 @@ bool FuncGraph::NeedGenerate(const std::vector& return true; } -void FuncGraph::GenerateDefaultValue(const FuncGraphPtr& specialized_graph, - const std::vector& specialized_parameter_list, - std::unordered_map* repl_nodes) { +void FuncGraph::GenerateDefaultValue(const FuncGraphPtr &specialized_graph, + const std::vector &specialized_parameter_list, + std::unordered_map *repl_nodes) { MS_EXCEPTION_IF_NULL(specialized_graph); for (size_t i = 0; i < specialized_graph->parameters().size() - hyper_param_count(); ++i) { auto param_node = specialized_graph->parameters()[i]; @@ -583,10 +583,10 @@ void FuncGraph::GenerateDefaultValue(const FuncGraphPtr& specialized_graph, } } -FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList& args_spec_list) { +FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList &args_spec_list) { std::vector kwarg_list; size_t arguments_count = args_spec_list.size(); - for (const auto& arg : args_spec_list) { + for (const auto &arg : args_spec_list) { // if it is a keyword argument MS_EXCEPTION_IF_NULL(arg); if (arg->isa()) { @@ -619,11 +619,11 @@ FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList& args_spec_list) MS_EXCEPTION_IF_NULL(specialized_graph); auto params = specialized_graph->parameters(); (void)std::transform(params.end() - SizeToInt(hyper_param_count()), params.end(), - std::back_inserter(specialized_parameter_list), [](const AnfNodePtr& node) { return node; }); + std::back_inserter(specialized_parameter_list), [](const AnfNodePtr &node) { return node; }); std::shared_ptr manager = mindspore::Manage(specialized_graph, false); auto tr = manager->Transact(); - for (auto& node_pair : repl_nodes) { + for (auto &node_pair : repl_nodes) { MS_LOG(DEBUG) << "GenerateGraph replace:" << node_pair.first->DebugString() << "-" << node_pair.second->DebugString(); (void)tr.Replace(node_pair.first, node_pair.second); @@ -638,7 +638,7 @@ FuncGraphPtr FuncGraph::GenerateGraph(const AbstractBasePtrList& args_spec_list) return specialized_graph; } -void FuncGraph::add_parameter_obj_node(const AnfNodePtr& p) { paramter_obj_nodes_.push_back(p); } +void FuncGraph::add_parameter_obj_node(const AnfNodePtr &p) { paramter_obj_nodes_.push_back(p); } std::list FuncGraph::GetOrderedCnodes() { if (has_flag(GRAPH_FLAG_HAS_EFFECT)) { @@ -651,7 +651,7 @@ std::list FuncGraph::GetOrderedCnodes() { std::list cnodes; auto nodes = TopoSort(get_return(), SuccDepends, BelongSameGraph); - for (const auto& node : nodes) { + for (const auto &node : nodes) { auto cnode = dyn_cast(node); if (cnode) { cnodes.push_back(cnode); @@ -679,7 +679,7 @@ void FuncGraph::EraseUnusedNodeInOrder() { } } -void FuncGraph::EraseUnusedNodeInOrder(const AnfNodePtr& n) { +void FuncGraph::EraseUnusedNodeInOrder(const AnfNodePtr &n) { if (has_flag(GRAPH_FLAG_HAS_EFFECT) && n && n->isa()) { order_.remove(n->cast()); MS_LOG(DEBUG) << "Remove the node" << n->DebugString() << " from order list."; @@ -690,7 +690,7 @@ void FuncGraph::CheckOrder() { if (has_flag(GRAPH_FLAG_HAS_EFFECT)) { MS_LOG(DEBUG) << "Check graph " << ToString(); for (auto it = order_.begin(); it != order_.end(); (void)it++) { - for (const auto& input_node : (*it)->inputs()) { + for (const auto &input_node : (*it)->inputs()) { if (input_node && input_node->isa() && input_node->func_graph() == shared_from_base()) { // Need to reorder the wrong order node. auto found = std::find(order_.begin(), it, input_node); @@ -705,7 +705,7 @@ void FuncGraph::CheckOrder() { } auto mng = manager_.lock(); if (mng != nullptr) { - const auto& nodes = mng->nodes()[shared_from_base()]; + const auto &nodes = mng->nodes()[shared_from_base()]; if (nodes.size() != (order_.size() + parameters_.size())) { DumpCNodeList(); MS_LOG(EXCEPTION) << "CNode order size " << order_.size() << " is not equal to managed node size " @@ -718,7 +718,7 @@ void FuncGraph::CheckOrder() { const char kPrimHasEffect[] = "_side_effect_flag"; -bool FuncGraph::HasEffect(const CNodePtr& cnode) { +bool FuncGraph::HasEffect(const CNodePtr &cnode) { auto prim = GetCNodePrimitive(cnode); if (prim != nullptr && prim->isa()) { auto do_sig = prim->cast(); @@ -739,9 +739,9 @@ bool FuncGraph::HasEffect(const CNodePtr& cnode) { return false; } -std::shared_ptr> FindRoots(const std::vector& segment) { +std::shared_ptr> FindRoots(const std::vector &segment) { std::shared_ptr> roots = std::make_shared>(segment); - for (const auto& node : segment) { + for (const auto &node : segment) { if (roots->size() == 1) { return roots; } @@ -757,9 +757,9 @@ std::shared_ptr> FindRoots(const std::vector& seg return roots; } -std::shared_ptr> FindLeaves(const std::vector& segment) { +std::shared_ptr> FindLeaves(const std::vector &segment) { std::shared_ptr> nodes = std::make_shared>(segment); - for (const auto& node : segment) { + for (const auto &node : segment) { if (nodes->size() == 1) { return nodes; } @@ -790,7 +790,7 @@ void FuncGraph::ReleaseFullOrderToEffectOrder() { if (has_flag(GRAPH_FLAG_HAS_EFFECT)) { std::list depends_order; std::vector segment; - for (const auto& cnode : order_) { + for (const auto &cnode : order_) { if (IsPrimitiveCNode(cnode, prim::kPrimReturn)) { continue; } @@ -830,7 +830,7 @@ void FuncGraph::ReleaseFullOrderToEffectOrder() { } } -void FuncGraph::SetEffectDepends(const std::vector& depend_inputs) { +void FuncGraph::SetEffectDepends(const std::vector &depend_inputs) { auto old_ret = output(); std::vector inputs{NewValueNode(prim::kPrimDepend), old_ret}; (void)inputs.insert(inputs.end(), depend_inputs.begin(), depend_inputs.end()); diff --git a/mindspore/ccsrc/ir/func_graph_cloner.cc b/mindspore/ccsrc/ir/func_graph_cloner.cc index d90cdbacf2..c086b8d7d1 100644 --- a/mindspore/ccsrc/ir/func_graph_cloner.cc +++ b/mindspore/ccsrc/ir/func_graph_cloner.cc @@ -26,29 +26,29 @@ // namespace to support intermediate representation definition namespace mindspore { -Cloner::Cloner(const FuncGraphPtrList& func_graphs, bool clone_all_valuenodes, bool clone_all_child_graphs, - bool clone_all_used_graphs, const TraceInfoPtr& relation, const TraceInfoPtr& target_relation) +Cloner::Cloner(const FuncGraphPtrList &func_graphs, bool clone_all_valuenodes, bool clone_all_child_graphs, + bool clone_all_used_graphs, const TraceInfoPtr &relation, const TraceInfoPtr &target_relation) : clone_all_valuenodes_(clone_all_valuenodes), clone_all_child_graphs_(clone_all_child_graphs), clone_all_used_graphs_(clone_all_used_graphs), relation_(relation), target_relation_(target_relation == nullptr ? relation : target_relation) { - for (auto& func_graph : func_graphs) { + for (auto &func_graph : func_graphs) { AddClone(func_graph); } scope_ = kDefaultScope; type_ = kBasic; } -void Cloner::AddClone(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph, - const AnfNodePtrList& params, CloneType type) { +void Cloner::AddClone(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph, + const AnfNodePtrList ¶ms, CloneType type) { if (func_graph != nullptr) { todo_.push_back({.origin = func_graph, .target = target_func_graph, .params = params}); type_ = type; } } -void Cloner::CloneNode(const AnfNodePtr& node, const FuncGraphPtr& target) { +void Cloner::CloneNode(const AnfNodePtr &node, const FuncGraphPtr &target) { MS_EXCEPTION_IF_NULL(node); if (repl_node_.find(node) != repl_node_.end() || node->isa()) { return; @@ -60,7 +60,7 @@ void Cloner::CloneNode(const AnfNodePtr& node, const FuncGraphPtr& target) { } } -void Cloner::CloneParameter(const AnfNodePtr& node, const FuncGraphPtr& target, bool is_add) { +void Cloner::CloneParameter(const AnfNodePtr &node, const FuncGraphPtr &target, bool is_add) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(target); TraceManager::DebugTrace(node->debug_info(), relation_); @@ -77,7 +77,7 @@ void Cloner::CloneParameter(const AnfNodePtr& node, const FuncGraphPtr& target, TraceManager::EndTrace(); } -void Cloner::CloneCNode(const AnfNodePtr& node, const FuncGraphPtr& target) { +void Cloner::CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(target); TraceManager::DebugTrace(node->debug_info(), relation_); @@ -91,7 +91,7 @@ void Cloner::CloneCNode(const AnfNodePtr& node, const FuncGraphPtr& target) { TraceManager::EndTrace(); } -void Cloner::CloneValueNode(const AnfNodePtr& node) { +void Cloner::CloneValueNode(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); TraceManager::DebugTrace(node->debug_info(), relation_); ValueNodePtr new_const = NewValueNode(GetValueNode(node)); @@ -102,7 +102,7 @@ void Cloner::CloneValueNode(const AnfNodePtr& node) { TraceManager::EndTrace(); } -void Cloner::CloneValueNode(const AnfNodePtr& node, const FuncGraphPtr& target) { +void Cloner::CloneValueNode(const AnfNodePtr &node, const FuncGraphPtr &target) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(target); TraceManager::DebugTrace(node->debug_info(), relation_); @@ -114,14 +114,14 @@ void Cloner::CloneValueNode(const AnfNodePtr& node, const FuncGraphPtr& target) TraceManager::EndTrace(); } -void Cloner::CloneValueNodes(const FuncGraphPtr& func_graph) { +void Cloner::CloneValueNodes(const FuncGraphPtr &func_graph) { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(manager_); if (!clone_all_valuenodes_) { return; } - auto& value_nodes = manager_->valuenodes()[func_graph]; - for (auto& value_node : value_nodes) { + auto &value_nodes = manager_->valuenodes()[func_graph]; + for (auto &value_node : value_nodes) { auto old_node = value_node.first; MS_EXCEPTION_IF_NULL(old_node); if (repl_node_.count(old_node) == 0) { @@ -130,38 +130,38 @@ void Cloner::CloneValueNodes(const FuncGraphPtr& func_graph) { } } -void Cloner::AddChildGraphs(const FuncGraphPtr& func_graph) { +void Cloner::AddChildGraphs(const FuncGraphPtr &func_graph) { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(manager_); if (!clone_all_child_graphs_) { return; } - auto& scopes = manager_->scopes(func_graph); - for (auto& graph : scopes) { + auto &scopes = manager_->scopes(func_graph); + for (auto &graph : scopes) { if (graph != func_graph) { todo_.push_back({graph, nullptr, {}}); } } } -void Cloner::AddTotalGraphs(const FuncGraphPtr& func_graph) { +void Cloner::AddTotalGraphs(const FuncGraphPtr &func_graph) { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(manager_); if (!clone_all_used_graphs_) { return; } - auto& used_graphs = manager_->func_graphs_used()[func_graph]; - for (auto& used_graph : used_graphs) { + auto &used_graphs = manager_->func_graphs_used()[func_graph]; + for (auto &used_graph : used_graphs) { todo_.push_back({used_graph.first, nullptr, {}}); } } -void Cloner::CloneFuncGraphDefaultValues(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph) { +void Cloner::CloneFuncGraphDefaultValues(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(target_func_graph); - for (auto& item : func_graph->parameter_default_value()) { + for (auto &item : func_graph->parameter_default_value()) { auto nodes = DeepLinkedGraphSearch(item.second); - for (auto& node : nodes) { + for (auto &node : nodes) { MS_EXCEPTION_IF_NULL(node); if (node->isa()) { CloneNode(node, target_func_graph); @@ -172,7 +172,7 @@ void Cloner::CloneFuncGraphDefaultValues(const FuncGraphPtr& func_graph, const F } } -void Cloner::CloneFuncGraphValueNodes(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph) { +void Cloner::CloneFuncGraphValueNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(target_func_graph); MS_EXCEPTION_IF_NULL(manager_); @@ -182,15 +182,15 @@ void Cloner::CloneFuncGraphValueNodes(const FuncGraphPtr& func_graph, const Func } target_func_graph->set_return(return_node); - auto& value_nodes = manager_->func_graph_valuenodes()[func_graph]; - for (auto& value_node : value_nodes) { + auto &value_nodes = manager_->func_graph_valuenodes()[func_graph]; + for (auto &value_node : value_nodes) { CloneValueNode(value_node.first, target_func_graph); } } -void Cloner::InlineCloneParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList& params) { +void Cloner::InlineCloneParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList ¶ms) { MS_EXCEPTION_IF_NULL(func_graph); - auto& old_params = func_graph->parameters(); + auto &old_params = func_graph->parameters(); if (old_params.size() != params.size()) { MS_LOG(EXCEPTION) << "Origin params size[" << old_params.size() << "], inline params size[" << params.size() << "]"; return; @@ -200,7 +200,7 @@ void Cloner::InlineCloneParameters(const FuncGraphPtr& func_graph, const AnfNode } } -void Cloner::SetFuncGraphInfo(const FuncGraphPtr& func_graph, FuncGraphPtr* const target_func_graph) { +void Cloner::SetFuncGraphInfo(const FuncGraphPtr &func_graph, FuncGraphPtr *const target_func_graph) { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(target_func_graph); TraceManager::DebugTrace(func_graph->debug_info(), target_relation_); @@ -215,33 +215,33 @@ void Cloner::SetFuncGraphInfo(const FuncGraphPtr& func_graph, FuncGraphPtr* cons TraceManager::EndTrace(); } -void Cloner::CloneParameters(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph) { +void Cloner::CloneParameters(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(target_func_graph); - auto& params = func_graph->parameters(); - for (auto& param : params) { + auto ¶ms = func_graph->parameters(); + for (auto ¶m : params) { CloneParameter(param, target_func_graph, true); } repl_func_graph_[func_graph] = target_func_graph; } -void Cloner::GenParameters(const FuncGraphPtr& func_graph) { +void Cloner::GenParameters(const FuncGraphPtr &func_graph) { MS_EXCEPTION_IF_NULL(func_graph); - auto& free_vars = manager_->free_variables_total(); + auto &free_vars = manager_->free_variables_total(); auto iter = free_vars.find(func_graph); if (iter == free_vars.end()) { return; } - for (auto& fv_map : iter->second) { - auto& free_var = fv_map.first; + for (auto &fv_map : iter->second) { + auto &free_var = fv_map.first; if (utils::isa(free_var)) { repl_func_graph_params_[func_graph].push_back(AddParameter(func_graph, utils::cast(free_var))); } } } -void Cloner::CloneParameter(const ParameterPtr& param, const AnfNodePtr& node) { +void Cloner::CloneParameter(const ParameterPtr ¶m, const AnfNodePtr &node) { param->set_abstract(node->abstract()); if (node->isa()) { ParameterPtr old_param = dyn_cast(node); @@ -252,7 +252,7 @@ void Cloner::CloneParameter(const ParameterPtr& param, const AnfNodePtr& node) { } } -ParameterPtr Cloner::AddParameter(const FuncGraphPtr& func_graph, const AnfNodePtr& node, bool is_add) { +ParameterPtr Cloner::AddParameter(const FuncGraphPtr &func_graph, const AnfNodePtr &node, bool is_add) { TraceManager::DebugTrace(std::make_shared(node->debug_info())); ParameterPtr param = std::make_shared(func_graph); TraceManager::EndTrace(); @@ -265,11 +265,11 @@ ParameterPtr Cloner::AddParameter(const FuncGraphPtr& func_graph, const AnfNodeP return param; } -void Cloner::AddParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList& params, - AnfNodePtrList* const lift_params, AnfNodePtrList* const input_params) { +void Cloner::AddParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList ¶ms, + AnfNodePtrList *const lift_params, AnfNodePtrList *const input_params) { AnfNodePtrList parameters; std::unordered_set old_params; - for (auto& param : func_graph->parameters()) { + for (auto ¶m : func_graph->parameters()) { auto iter = repl_node_.find(param); if (iter != repl_node_.end()) { (void)old_params.insert(iter->second); @@ -280,7 +280,7 @@ void Cloner::AddParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList& } } AnfNodePtr new_param = nullptr; - for (auto& param : params) { + for (auto ¶m : params) { auto old_param = repl_node_[param]; if (old_param->isa() && old_param->func_graph() == func_graph) { repl_node_[old_param] = old_param; @@ -301,10 +301,10 @@ void Cloner::AddParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList& func_graph->set_parameters(parameters); } -void Cloner::AddInputs(const FuncGraphPtr& func_graph_user, const FuncGraphPtr& func_graph, - const AnfNodePtrList& params) { +void Cloner::AddInputs(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph, + const AnfNodePtrList ¶ms) { AnfNodePtr node = nullptr; - auto& repl_func_graph = repl_map_func_graph_[func_graph_user]; + auto &repl_func_graph = repl_map_func_graph_[func_graph_user]; auto iter = repl_func_graph.find(func_graph); if (iter == repl_func_graph.end()) { node = func_graph_user->NewCNode({NewValueNode(prim::kPrimPartial), NewValueNode(func_graph)}); @@ -322,9 +322,9 @@ void Cloner::AddInputs(const FuncGraphPtr& func_graph_user, const FuncGraphPtr& OrderParameters(func_graph, inputs); } -void Cloner::OrderParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList& inputs) { +void Cloner::OrderParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList &inputs) { std::unordered_set old_params; - for (auto& param : func_graph->parameters()) { + for (auto ¶m : func_graph->parameters()) { (void)old_params.insert(repl_node_[param]); } std::unordered_set new_params; @@ -339,7 +339,7 @@ void Cloner::OrderParameters(const FuncGraphPtr& func_graph, const AnfNodePtrLis (void)new_params.insert(new_param); } } - for (auto& param : func_graph->parameters()) { + for (auto ¶m : func_graph->parameters()) { if (new_params.find(param) == new_params.end()) { parameters.push_back(param); } @@ -347,9 +347,9 @@ void Cloner::OrderParameters(const FuncGraphPtr& func_graph, const AnfNodePtrLis func_graph->set_parameters(parameters); } -void Cloner::SetEdges(const FuncGraphPtr& func_graph) { +void Cloner::SetEdges(const FuncGraphPtr &func_graph) { MS_EXCEPTION_IF_NULL(func_graph); - for (auto& node : func_graph->nodes()) { + for (auto &node : func_graph->nodes()) { if (node == nullptr) { continue; } @@ -358,17 +358,17 @@ void Cloner::SetEdges(const FuncGraphPtr& func_graph) { continue; } auto cnode = node->cast(); - auto& inputs = cnode->inputs(); + auto &inputs = cnode->inputs(); for (size_t i = 0; i < inputs.size(); i++) { - auto& input = inputs[i]; + auto &input = inputs[i]; if (IsValueNode(input)) { auto graph = GetValueNode(input); - auto& repl_func_graph = repl_map_func_graph_[func_graph]; + auto &repl_func_graph = repl_map_func_graph_[func_graph]; if (repl_func_graph.find(graph) != repl_func_graph.end()) { transaction_.SetEdge(cnode, SizeToInt(i), repl_func_graph[graph]); } } else { - auto& repl_node = repl_map_node_[func_graph]; + auto &repl_node = repl_map_node_[func_graph]; if (repl_node.find(input) != repl_node.end()) { transaction_.SetEdge(cnode, SizeToInt(i), repl_node[input]); } @@ -377,8 +377,8 @@ void Cloner::SetEdges(const FuncGraphPtr& func_graph) { } } -void Cloner::LiftParameters(const FuncGraphPtr& func_graph_user, const FuncGraphPtr& func_graph, - const AnfNodePtrList& params) { +void Cloner::LiftParameters(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph, + const AnfNodePtrList ¶ms) { AnfNodePtrList lift_params; AnfNodePtrList input_params; AddParameters(func_graph_user, params, &lift_params, &input_params); @@ -386,16 +386,16 @@ void Cloner::LiftParameters(const FuncGraphPtr& func_graph_user, const FuncGraph if (lift_params.empty()) { return; } - for (auto& user : func_graph_user->func_graph_users()) { + for (auto &user : func_graph_user->func_graph_users()) { LiftParameters(user.first, func_graph_user, lift_params); } } void Cloner::Lift() { - for (auto& func_graph_params : repl_func_graph_params_) { - auto& func_graph = func_graph_params.first; - auto& params = func_graph_params.second; - for (auto& user : func_graph->func_graph_users()) { + for (auto &func_graph_params : repl_func_graph_params_) { + auto &func_graph = func_graph_params.first; + auto ¶ms = func_graph_params.second; + for (auto &user : func_graph->func_graph_users()) { LiftParameters(user.first, func_graph, params); } } @@ -404,18 +404,18 @@ void Cloner::Lift() { void Cloner::LiftParameters() { MS_EXCEPTION_IF_NULL(manager_); transaction_ = manager_->Transact(); - const FuncGraphSet& func_graphs = manager_->func_graphs(); - for (auto& func_graph : func_graphs) { + const FuncGraphSet &func_graphs = manager_->func_graphs(); + for (auto &func_graph : func_graphs) { GenParameters(func_graph); } Lift(); - for (auto& func_graph : func_graphs) { + for (auto &func_graph : func_graphs) { SetEdges(func_graph); } transaction_.Commit(); } -bool Cloner::CheckStatus(const FuncGraphPtr& func_graph, bool is_inline) { +bool Cloner::CheckStatus(const FuncGraphPtr &func_graph, bool is_inline) { MS_EXCEPTION_IF_NULL(func_graph); // Make sure only inline once if (status_.count(func_graph) != 0) { @@ -430,12 +430,12 @@ bool Cloner::CheckStatus(const FuncGraphPtr& func_graph, bool is_inline) { return true; } -void Cloner::CloneAllNodes(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph) { +void Cloner::CloneAllNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph) { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(target_func_graph); MS_EXCEPTION_IF_NULL(manager_); - const AnfNodeSet& nodes = manager_->nodes()[func_graph]; - for (auto& node : nodes) { + const AnfNodeSet &nodes = manager_->nodes()[func_graph]; + for (auto &node : nodes) { CloneNode(node, target_func_graph); } } @@ -449,7 +449,7 @@ void Cloner::Run() { // Basic and Inline Clone FuncGraphPtrList func_graphs; (void)std::transform(todo_.begin(), todo_.end(), std::back_inserter(func_graphs), - [](const CloneInfo& item) -> FuncGraphPtr { return item.origin; }); + [](const CloneInfo &item) -> FuncGraphPtr { return item.origin; }); manager_ = Manage(func_graphs, false); CloneNodes(); LinkEdges(); @@ -495,13 +495,13 @@ void Cloner::CloneNodes() { } void Cloner::LinkEdges() { - for (auto& node_pair : nodes_) { + for (auto &node_pair : nodes_) { CNodePtr old_node = node_pair.first; CNodePtr new_node = node_pair.second; MS_EXCEPTION_IF_NULL(old_node); MS_EXCEPTION_IF_NULL(new_node); - for (auto& input : old_node->inputs()) { - auto& new_input = (repl_node_.count(input) == 0) ? input : repl_node_[input]; + for (auto &input : old_node->inputs()) { + auto &new_input = (repl_node_.count(input) == 0) ? input : repl_node_[input]; new_node->add_input(new_input); } } @@ -509,10 +509,10 @@ void Cloner::LinkEdges() { // For the graphs cloned, update its default value map to the cloned nodes void Cloner::SetDefaults() { - for (auto& item : graph_set_) { + for (auto &item : graph_set_) { MS_EXCEPTION_IF_NULL(item); if (repl_func_graph_.count(item) != 0) { - for (auto& param_def : item->parameter_default_value()) { + for (auto ¶m_def : item->parameter_default_value()) { MS_EXCEPTION_IF_NULL(repl_func_graph_[item]); if (repl_node_.count(param_def.second) != 0) { repl_func_graph_[item]->set_param_default_value(param_def.first, repl_node_[param_def.second]); @@ -524,7 +524,7 @@ void Cloner::SetDefaults() { } } -AnfNodePtr Cloner::CloneDisconnected(const AnfNodePtr& root) { +AnfNodePtr Cloner::CloneDisconnected(const AnfNodePtr &root) { MS_EXCEPTION_IF_NULL(root); if (repl_func_graph_.find(root->func_graph()) == repl_func_graph_.end()) { MS_LOG(EXCEPTION) << "Cannot find func graph " << root->func_graph()->ToString() << " in cloner."; @@ -537,7 +537,7 @@ AnfNodePtr Cloner::CloneDisconnected(const AnfNodePtr& root) { MS_LOG(EXCEPTION) << "Failed in clone for node " << root->DebugString() << "."; } -AnfNodePtr Cloner::operator[](const AnfNodePtr& node) { +AnfNodePtr Cloner::operator[](const AnfNodePtr &node) { #ifdef ENABLE_PROFILE double time = GetTime(); #endif @@ -548,7 +548,7 @@ AnfNodePtr Cloner::operator[](const AnfNodePtr& node) { return ((repl_node_.count(node) == 0) ? node : repl_node_[node]); } -FuncGraphPtr Cloner::operator[](const FuncGraphPtr& func_graph) { +FuncGraphPtr Cloner::operator[](const FuncGraphPtr &func_graph) { #ifdef ENABLE_PROFILE double time = GetTime(); #endif @@ -559,14 +559,14 @@ FuncGraphPtr Cloner::operator[](const FuncGraphPtr& func_graph) { return ((repl_func_graph_.count(func_graph) == 0) ? func_graph : repl_func_graph_[func_graph]); } -FuncGraphPtr BasicClone(const FuncGraphPtr& func_graph) { +FuncGraphPtr BasicClone(const FuncGraphPtr &func_graph) { MS_EXCEPTION_IF_NULL(func_graph); Cloner cloner({func_graph}, false, true, true, std::make_shared(), nullptr); return cloner[func_graph]; } -AnfNodePtr InlineClone(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph, - const AnfNodePtrList& func_graph_args, const ScopePtr& scope) { +AnfNodePtr InlineClone(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph, + const AnfNodePtrList &func_graph_args, const ScopePtr &scope) { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(target_func_graph); Cloner cloner({}, false); @@ -577,14 +577,14 @@ AnfNodePtr InlineClone(const FuncGraphPtr& func_graph, const FuncGraphPtr& targe return cloner[func_graph->output()]; } -FuncGraphPtr LiftingClone(const FuncGraphPtr& func_graph) { +FuncGraphPtr LiftingClone(const FuncGraphPtr &func_graph) { MS_EXCEPTION_IF_NULL(func_graph); Cloner cloner({}, false); cloner.AddClone(func_graph, nullptr, {}, kLifting); return cloner[func_graph]; } -ClonerPtr SpecializerClone(const FuncGraphPtr& func_graph, const TraceInfoPtr& relation) { +ClonerPtr SpecializerClone(const FuncGraphPtr &func_graph, const TraceInfoPtr &relation) { MS_EXCEPTION_IF_NULL(func_graph); FuncGraphPtrList func_graphs = {func_graph}; ClonerPtr cloner = @@ -599,14 +599,14 @@ ClonerPtr SpecializerClone(const FuncGraphPtr& func_graph, const TraceInfoPtr& r return cloner; } -FuncGraphPtr TransformableClone(const FuncGraphPtr& func_graph, const TraceInfoPtr& relation) { +FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, const TraceInfoPtr &relation) { MS_EXCEPTION_IF_NULL(func_graph); TraceManager::DebugTrace(func_graph->debug_info(), relation); auto new_func_graph = std::make_shared(); TraceManager::EndTrace(); - auto& parameters = func_graph->parameters(); - (void)std::for_each(parameters.begin(), parameters.end(), [&new_func_graph](const AnfNodePtr& param) -> void { + auto ¶meters = func_graph->parameters(); + (void)std::for_each(parameters.begin(), parameters.end(), [&new_func_graph](const AnfNodePtr ¶m) -> void { MS_EXCEPTION_IF_NULL(param); TraceManager::DebugTrace(std::make_shared(param->debug_info())); (void)new_func_graph->add_parameter(); @@ -622,7 +622,7 @@ FuncGraphPtr TransformableClone(const FuncGraphPtr& func_graph, const TraceInfoP new_func_graph->set_kwonlyargs_count(func_graph->kwonlyargs_count()); new_func_graph->set_hyper_param_count(func_graph->hyper_param_count()); new_func_graph->set_is_generate(func_graph->is_generated()); - for (auto& item : func_graph->parameter_default_value()) { + for (auto &item : func_graph->parameter_default_value()) { new_func_graph->set_param_default_value(item.first, cloner[item.second]); } diff --git a/mindspore/ccsrc/ir/func_graph_cloner.h b/mindspore/ccsrc/ir/func_graph_cloner.h index dd228cf79f..426cf447a3 100644 --- a/mindspore/ccsrc/ir/func_graph_cloner.h +++ b/mindspore/ccsrc/ir/func_graph_cloner.h @@ -43,26 +43,26 @@ struct CloneInfo { class Cloner { public: - explicit Cloner(const FuncGraphPtrList& func_graphs = {}, bool clone_all_valuenodes = false, + explicit Cloner(const FuncGraphPtrList &func_graphs = {}, bool clone_all_valuenodes = false, bool clone_all_child_graphs = true, bool clone_all_used_graphs = false, - const TraceInfoPtr& relation = std::make_shared(), - const TraceInfoPtr& target_relation = nullptr); + const TraceInfoPtr &relation = std::make_shared(), + const TraceInfoPtr &target_relation = nullptr); ~Cloner() = default; - void AddClone(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph = nullptr, - const AnfNodePtrList& params = {}, CloneType type = kBasic); + void AddClone(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph = nullptr, + const AnfNodePtrList ¶ms = {}, CloneType type = kBasic); void Run(); // Interfaces for specializer - AnfNodePtr CloneDisconnected(const AnfNodePtr& root); - AnfNodePtr operator[](const AnfNodePtr& node); - FuncGraphPtr operator[](const FuncGraphPtr& func_graph); + AnfNodePtr CloneDisconnected(const AnfNodePtr &root); + AnfNodePtr operator[](const AnfNodePtr &node); + FuncGraphPtr operator[](const FuncGraphPtr &func_graph); // Map of replicate nodes and graphs - std::unordered_map* cloned_node() { return &repl_node_; } + std::unordered_map *cloned_node() { return &repl_node_; } std::unordered_map cloned_func_graph() { return repl_func_graph_; } // Scope of cloned graphs - void set_scope(const ScopePtr& scope) { scope_ = scope; } + void set_scope(const ScopePtr &scope) { scope_ = scope; } const ScopePtr scope() const { return scope_; } std::unordered_map repl_node_; @@ -71,31 +71,31 @@ class Cloner { void CloneNodes(); void LinkEdges(); void SetDefaults(); - void CloneNode(const AnfNodePtr& node, const FuncGraphPtr& target); - void CloneValueNode(const AnfNodePtr& node); - void CloneValueNode(const AnfNodePtr& node, const FuncGraphPtr& target); - void CloneCNode(const AnfNodePtr& node, const FuncGraphPtr& target); - void CloneParameter(const AnfNodePtr& node, const FuncGraphPtr& target, bool is_add = false); - void CloneValueNodes(const FuncGraphPtr& func_graph); - void AddChildGraphs(const FuncGraphPtr& func_graph); - void AddTotalGraphs(const FuncGraphPtr& func_graph); - bool CheckStatus(const FuncGraphPtr& func_graph, bool is_inline); - void CloneAllNodes(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph); - void CloneFuncGraphValueNodes(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph); - void CloneFuncGraphDefaultValues(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph); - void InlineCloneParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList& params); - void SetFuncGraphInfo(const FuncGraphPtr& func_graph, FuncGraphPtr* const target_func_graph); - void CloneParameters(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph); - void GenParameters(const FuncGraphPtr& func_graph); - void CloneParameter(const ParameterPtr& param, const AnfNodePtr& node); - ParameterPtr AddParameter(const FuncGraphPtr& func_graph, const AnfNodePtr& node, bool is_add = true); - void AddParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList& params, AnfNodePtrList* const lift_params, - AnfNodePtrList* const input_params); - void AddInputs(const FuncGraphPtr& func_graph_user, const FuncGraphPtr& func_graph, const AnfNodePtrList& params); - void OrderParameters(const FuncGraphPtr& func_graph, const AnfNodePtrList& inputs); - void SetEdges(const FuncGraphPtr& func_graph); - void LiftParameters(const FuncGraphPtr& func_graph_user, const FuncGraphPtr& func_graph, - const AnfNodePtrList& params); + void CloneNode(const AnfNodePtr &node, const FuncGraphPtr &target); + void CloneValueNode(const AnfNodePtr &node); + void CloneValueNode(const AnfNodePtr &node, const FuncGraphPtr &target); + void CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target); + void CloneParameter(const AnfNodePtr &node, const FuncGraphPtr &target, bool is_add = false); + void CloneValueNodes(const FuncGraphPtr &func_graph); + void AddChildGraphs(const FuncGraphPtr &func_graph); + void AddTotalGraphs(const FuncGraphPtr &func_graph); + bool CheckStatus(const FuncGraphPtr &func_graph, bool is_inline); + void CloneAllNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph); + void CloneFuncGraphValueNodes(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph); + void CloneFuncGraphDefaultValues(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph); + void InlineCloneParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList ¶ms); + void SetFuncGraphInfo(const FuncGraphPtr &func_graph, FuncGraphPtr *const target_func_graph); + void CloneParameters(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph); + void GenParameters(const FuncGraphPtr &func_graph); + void CloneParameter(const ParameterPtr ¶m, const AnfNodePtr &node); + ParameterPtr AddParameter(const FuncGraphPtr &func_graph, const AnfNodePtr &node, bool is_add = true); + void AddParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList ¶ms, AnfNodePtrList *const lift_params, + AnfNodePtrList *const input_params); + void AddInputs(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph, const AnfNodePtrList ¶ms); + void OrderParameters(const FuncGraphPtr &func_graph, const AnfNodePtrList &inputs); + void SetEdges(const FuncGraphPtr &func_graph); + void LiftParameters(const FuncGraphPtr &func_graph_user, const FuncGraphPtr &func_graph, + const AnfNodePtrList ¶ms); void Lift(); void LiftParameters(); @@ -118,17 +118,17 @@ class Cloner { std::unordered_map repl_func_graph_params_; }; -FuncGraphPtr BasicClone(const FuncGraphPtr& func_graph); +FuncGraphPtr BasicClone(const FuncGraphPtr &func_graph); -AnfNodePtr InlineClone(const FuncGraphPtr& func_graph, const FuncGraphPtr& target_func_graph, - const AnfNodePtrList& func_graph_args, const ScopePtr& scope = nullptr); +AnfNodePtr InlineClone(const FuncGraphPtr &func_graph, const FuncGraphPtr &target_func_graph, + const AnfNodePtrList &func_graph_args, const ScopePtr &scope = nullptr); -FuncGraphPtr LiftingClone(const FuncGraphPtr& func_graph); +FuncGraphPtr LiftingClone(const FuncGraphPtr &func_graph); -ClonerPtr SpecializerClone(const FuncGraphPtr& func_graph, const TraceInfoPtr& relation); +ClonerPtr SpecializerClone(const FuncGraphPtr &func_graph, const TraceInfoPtr &relation); -FuncGraphPtr TransformableClone(const FuncGraphPtr& func_graph, - const TraceInfoPtr& relation = std::make_shared()); +FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, + const TraceInfoPtr &relation = std::make_shared()); } // namespace mindspore #endif // MINDSPORE_CCSRC_IR_FUNC_GRAPH_CLONER_H_ diff --git a/mindspore/ccsrc/ir/manager.cc b/mindspore/ccsrc/ir/manager.cc index 889a091711..a53c9e95ae 100644 --- a/mindspore/ccsrc/ir/manager.cc +++ b/mindspore/ccsrc/ir/manager.cc @@ -27,17 +27,17 @@ namespace mindspore { -FuncGraphManagerPtr MakeManager(const std::vector& func_graphs, bool manage) { +FuncGraphManagerPtr MakeManager(const std::vector &func_graphs, bool manage) { auto m = std::make_shared(func_graphs, manage); m->Init(); return m; } -FuncGraphManagerPtr Manage(const std::vector& func_graphs, bool manage) { +FuncGraphManagerPtr Manage(const std::vector &func_graphs, bool manage) { FuncGraphManagerPtr m = nullptr; bool root = false; - for (auto& fg : func_graphs) { + for (auto &fg : func_graphs) { if (fg == nullptr) { continue; } @@ -53,7 +53,7 @@ FuncGraphManagerPtr Manage(const std::vector& func_graphs, bool ma root = true; } - for (auto& fg : func_graphs) { + for (auto &fg : func_graphs) { if (fg == nullptr) { continue; } @@ -67,7 +67,7 @@ FuncGraphManagerPtr Manage(FuncGraphPtr func_graph, bool manage) { return Manage(func_graphs, manage); } -FuncGraphManager::FuncGraphManager(const std::vector& roots, bool manage) +FuncGraphManager::FuncGraphManager(const std::vector &roots, bool manage) : roots_(roots), is_manage_(manage) { Reset(); } @@ -103,12 +103,12 @@ void FuncGraphManager::Init() { auto roots = roots_; roots_ = FuncGraphSet(); - for (auto& fg : roots) { + for (auto &fg : roots) { AddFuncGraph(fg, true); } } -FuncGraphSet& FuncGraphManager::func_graph_parents_total(const FuncGraphPtr& fg) const { +FuncGraphSet &FuncGraphManager::func_graph_parents_total(const FuncGraphPtr &fg) const { MS_EXCEPTION_IF_NULL(fg); MS_LOG(DEBUG) << "Start func_graph_parents_total func graph " << fg->ToString(); func_graph_parents_total_->Recompute(fg); @@ -116,7 +116,7 @@ FuncGraphSet& FuncGraphManager::func_graph_parents_total(const FuncGraphPtr& fg) return func_graph_parents_total_->func_graph_parents_total_analysis()[fg]; } -FuncGraphPtr FuncGraphManager::parent(const FuncGraphPtr& fg) const { +FuncGraphPtr FuncGraphManager::parent(const FuncGraphPtr &fg) const { MS_EXCEPTION_IF_NULL(fg); MS_EXCEPTION_IF_NULL(func_graph_parent_); MS_LOG(DEBUG) << "Start parents func graph " << fg->ToString(); @@ -129,7 +129,7 @@ FuncGraphPtr FuncGraphManager::parent(const FuncGraphPtr& fg) const { return func_graph_parent_->parent_analysis()[fg]; } -FuncGraphSet& FuncGraphManager::children(const FuncGraphPtr& fg) const { +FuncGraphSet &FuncGraphManager::children(const FuncGraphPtr &fg) const { MS_EXCEPTION_IF_NULL(fg); MS_EXCEPTION_IF_NULL(children_); MS_LOG(DEBUG) << "Start child func graph " << fg->ToString(); @@ -137,7 +137,7 @@ FuncGraphSet& FuncGraphManager::children(const FuncGraphPtr& fg) const { return children_->children_analysis()[fg]; } -FuncGraphSet& FuncGraphManager::scopes(const FuncGraphPtr& fg) const { +FuncGraphSet &FuncGraphManager::scopes(const FuncGraphPtr &fg) const { MS_EXCEPTION_IF_NULL(fg); MS_EXCEPTION_IF_NULL(scopes_); MS_LOG(DEBUG) << "Start scopes func graph:" << fg->ToString(); @@ -146,19 +146,19 @@ FuncGraphSet& FuncGraphManager::scopes(const FuncGraphPtr& fg) const { return scopes_->scope_analysis()[fg]; } -FVTotalMap& FuncGraphManager::free_variables_total() const { +FVTotalMap &FuncGraphManager::free_variables_total() const { MS_EXCEPTION_IF_NULL(free_variables_total_); free_variables_total_->Recompute(); return free_variables_total_->fv_total_analysis(); } -FuncGraphSet& FuncGraphManager::func_graphs_used_total(const FuncGraphPtr& fg) const { +FuncGraphSet &FuncGraphManager::func_graphs_used_total(const FuncGraphPtr &fg) const { MS_EXCEPTION_IF_NULL(func_graphs_used_total_); func_graphs_used_total_->Recompute(fg); return func_graphs_used_total_->func_graph_used_total_analysis()[fg]; } -bool FuncGraphManager::recursive(const FuncGraphPtr& fg) const { +bool FuncGraphManager::recursive(const FuncGraphPtr &fg) const { MS_EXCEPTION_IF_NULL(fg); recursive_->Recompute(fg); if (recursive_->recursive_analysis().count(fg) == 0) { @@ -168,7 +168,7 @@ bool FuncGraphManager::recursive(const FuncGraphPtr& fg) const { return recursive_->recursive_analysis()[fg]; } -std::shared_ptr> FuncGraphManager::recursive_graphs(const FuncGraphPtr& fg) const { +std::shared_ptr> FuncGraphManager::recursive_graphs(const FuncGraphPtr &fg) const { MS_EXCEPTION_IF_NULL(fg); if (recursive(fg)) { if (!recursive_->recursive_map().count(fg)) { @@ -185,7 +185,7 @@ std::shared_ptr> FuncGraphManager::recursive_graphs(cons } } -bool FuncGraphManager::func_graph_j_total(const FuncGraphPtr& fg) const { +bool FuncGraphManager::func_graph_j_total(const FuncGraphPtr &fg) const { MS_EXCEPTION_IF_NULL(j_total_); MS_EXCEPTION_IF_NULL(fg); j_total_->Recompute(fg); @@ -225,10 +225,10 @@ void FuncGraphManager::Clear() { signals_->InvalidateComputer(); } -void FuncGraphManager::KeepRoots(const std::vector& func_graphs) { +void FuncGraphManager::KeepRoots(const std::vector &func_graphs) { MS_LOG(DEBUG) << "Start keep roots"; bool root_exist = false; - for (auto& item : func_graphs) { + for (auto &item : func_graphs) { if (roots_.contains(item)) { root_exist = true; break; @@ -245,17 +245,17 @@ void FuncGraphManager::KeepRoots(const std::vector& func_graphs) { roots = roots_; } else { roots_.clear(); - for (auto& item : roots) { + for (auto &item : roots) { AddFuncGraph(item, true); } } FuncGraphSet keep; - for (auto& item : roots) { + for (auto &item : roots) { MS_LOG(DEBUG) << "roots: " << item->ToString(); keep.update(func_graphs_used_total(item)); #ifdef DEBUG - for (auto& k : keep) { + for (auto &k : keep) { MS_LOG(DEBUG) << "keep: " << k->ToString(); } #endif @@ -264,7 +264,7 @@ void FuncGraphManager::KeepRoots(const std::vector& func_graphs) { } else { Clear(); FuncGraphSet roots(func_graphs); - for (auto& item : roots) { + for (auto &item : roots) { AddFuncGraph(item, true); } } @@ -276,7 +276,7 @@ void FuncGraphManager::RemoveRoots() { MaybeDropFuncGraphs(func_graphs_, true); } -void FuncGraphManager::AddIntoManaged(const FuncGraphPtr& fg) { +void FuncGraphManager::AddIntoManaged(const FuncGraphPtr &fg) { MS_EXCEPTION_IF_NULL(fg); if (is_manage_) { if (fg->manager() != nullptr && (&(*fg->manager()) != this)) { @@ -288,7 +288,7 @@ void FuncGraphManager::AddIntoManaged(const FuncGraphPtr& fg) { func_graphs_.add(fg); } -void FuncGraphManager::MaybeDropFuncGraphs(const FuncGraphSet& func_graphs, bool ignore_users) { +void FuncGraphManager::MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool ignore_users) { FuncGraphSet todo(func_graphs); std::set dropped; // int count = 0; @@ -301,7 +301,7 @@ void FuncGraphManager::MaybeDropFuncGraphs(const FuncGraphSet& func_graphs, bool continue; } MS_EXCEPTION_IF_NULL(func_graph_users_); - auto& users = func_graph_users_->count_func_graphs_map()[func_graph]; + auto &users = func_graph_users_->count_func_graphs_map()[func_graph]; if (!users.empty() && !ignore_users) { MS_LOG(DEBUG) << "Cannot drop as users not empty: " << func_graph->ToString(); continue; @@ -315,7 +315,7 @@ void FuncGraphManager::MaybeDropFuncGraphs(const FuncGraphSet& func_graphs, bool todo.update(MaybeDropNodes(return_vec)); } MS_EXCEPTION_IF_NULL(signals_); - for (auto& fg : dropped) { + for (auto &fg : dropped) { MS_EXCEPTION_IF_NULL(fg); signals_->DropFuncGraph(fg); all_nodes_.difference_update(fg->parameters()); @@ -331,7 +331,7 @@ void FuncGraphManager::ProcessEdge(AnfNodePtr node, int index, AnfNodePtr inp, E MS_EXCEPTION_IF_NULL(inp); if (direction == kDecEdge) { MS_LOG(DEBUG) << "Remove node " << node->ToString() << " input[" << index << "] " << inp->ToString(); - auto& users_node = node_users_[inp]; + auto &users_node = node_users_[inp]; if (!users_node.contains(make_pair(node, index))) { return; } @@ -346,26 +346,26 @@ void FuncGraphManager::ProcessEdge(AnfNodePtr node, int index, AnfNodePtr inp, E MS_LOG(DEBUG) << "Input[" << index << "] is const graph " << inp->ToString(); AddFuncGraph(GetValueNode(inp)); } - auto& users_node = node_users_[inp]; + auto &users_node = node_users_[inp]; users_node.add(make_pair(node, index)); MS_EXCEPTION_IF_NULL(signals_); signals_->AddEdge(node, index, inp); } } -void FuncGraphManager::ProcessInputs(const AnfNodePtr& node, EdgeProcessDirection direction) { +void FuncGraphManager::ProcessInputs(const AnfNodePtr &node, EdgeProcessDirection direction) { MS_EXCEPTION_IF_NULL(node); if (node->isa()) { auto cnode = node->cast(); int index = 0; - for (auto& inp : cnode->inputs()) { + for (auto &inp : cnode->inputs()) { ProcessEdge(cnode, index, inp, direction); ++index; } } } -IncludeType FuncGraphManager::Limit(const AnfNodePtr& node) { +IncludeType FuncGraphManager::Limit(const AnfNodePtr &node) { if (all_nodes_.contains(node)) { return EXCLUDE; } else { @@ -373,9 +373,9 @@ IncludeType FuncGraphManager::Limit(const AnfNodePtr& node) { } } -void FuncGraphManager::AcquireNodes(const std::vector& nodes) { +void FuncGraphManager::AcquireNodes(const std::vector &nodes) { AnfNodeSet acq; - for (auto& node : nodes) { + for (auto &node : nodes) { std::function limit = std::bind(&FuncGraphManager::Limit, this, std::placeholders::_1); AnfNodeSet new_nodes = AnfNodeSet(DeepScopedGraphSearch(node, limit)); @@ -384,7 +384,7 @@ void FuncGraphManager::AcquireNodes(const std::vector& nodes) { acq.update(new_nodes); } - for (auto& node : acq) { + for (auto &node : acq) { MS_EXCEPTION_IF_NULL(node); FuncGraphPtr fg = node->func_graph(); if (fg != nullptr) { @@ -395,7 +395,7 @@ void FuncGraphManager::AcquireNodes(const std::vector& nodes) { } } -FuncGraphSetPtr FuncGraphManager::MaybeDropNodes(const std::vector& nodes) { +FuncGraphSetPtr FuncGraphManager::MaybeDropNodes(const std::vector &nodes) { AnfNodeSet nodes_ordered(nodes); FuncGraphSetPtr func_graphs_to_check = std::make_shared(); MS_EXCEPTION_IF_NULL(signals_); @@ -406,7 +406,7 @@ FuncGraphSetPtr FuncGraphManager::MaybeDropNodes(const std::vector& if (!all_nodes_.contains(node)) { continue; } - AnfNodeIndexSet& users = node_users_[node]; + AnfNodeIndexSet &users = node_users_[node]; std::vector parameters; if (!users.empty() || @@ -431,13 +431,13 @@ FuncGraphSetPtr FuncGraphManager::MaybeDropNodes(const std::vector& return func_graphs_to_check; } -void FuncGraphManager::SetParameters(const FuncGraphPtr& fg, const std::vector& parameters) { +void FuncGraphManager::SetParameters(const FuncGraphPtr &fg, const std::vector ¶meters) { auto tr = Transact(); tr.SetParameters(fg, parameters); tr.Commit(); } -bool FuncGraphManager::Replace(const AnfNodePtr& old_node, const AnfNodePtr& new_node) { +bool FuncGraphManager::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) { auto tr = Transact(); bool success = tr.Replace(old_node, new_node); if (success) { @@ -446,13 +446,13 @@ bool FuncGraphManager::Replace(const AnfNodePtr& old_node, const AnfNodePtr& new return success; } -void FuncGraphManager::SetEdge(const AnfNodePtr& node, int index, const AnfNodePtr& value) { +void FuncGraphManager::SetEdge(const AnfNodePtr &node, int index, const AnfNodePtr &value) { auto tr = Transact(); tr.SetEdge(node, index, value); tr.Commit(); } -void FuncGraphManager::MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr target, const ScopePtr& scope) { +void FuncGraphManager::MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr target, const ScopePtr &scope) { AnfNodePtr source_return = source->get_return(); AnfNodePtr source_output = source->output(); AnfNodePtr source_prim = source_return->cast()->input(0); @@ -466,23 +466,23 @@ void FuncGraphManager::MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr t (void)all_nodes_.erase(source_return); (void)node_users_.erase(source_return); signals_->DropNode(source_return); - for (auto& node : source->nodes()) { + for (auto &node : source->nodes()) { node->set_func_graph(target); if (node->scope() == kDefaultScope) { node->set_scope(scope); } } - for (auto& used : source->func_graphs_used()) { + for (auto &used : source->func_graphs_used()) { (void)func_graph_users_->Inc(used.first, target, used.second); (void)this->func_graph_users()[used.first].erase(source); } - for (auto& child : this->func_graph_child_direct()[source]) { + for (auto &child : this->func_graph_child_direct()[source]) { (void)func_graph_parents_direct_->Inc(child.first, target, child.second); (void)this->func_graph_parents_direct()[child.first].erase(source); } - for (auto& fv_count : this->free_variables_direct()[source]) { + for (auto &fv_count : this->free_variables_direct()[source]) { auto fv_g = fv_count.first->func_graph(); - auto& count_on_g = this->func_graph_child_direct()[fv_g]; + auto &count_on_g = this->func_graph_child_direct()[fv_g]; auto pair = count_on_g.find(source); if (fv_g != target && pair != count_on_g.end()) { (void)func_graph_child_direct_->Inc(fv_g, target, pair->second); @@ -504,9 +504,9 @@ FuncGraphTransaction FuncGraphManager::Transact() { return tr; } -void FuncGraphManager::ParseChanges(const std::vector& changes, EdgeTupleCounter* add_edges, - EdgeTupleCounter* rm_edges, Counter* adds, Counter* rms) { - for (auto& iter : changes) { +void FuncGraphManager::ParseChanges(const std::vector &changes, EdgeTupleCounter *add_edges, + EdgeTupleCounter *rm_edges, Counter *adds, Counter *rms) { + for (auto &iter : changes) { auto operation = iter.op; auto args = iter.args; if (operation == Change::kTxSetEdge) { @@ -521,10 +521,10 @@ void FuncGraphManager::ParseChanges(const std::vector& changes, EdgeTupl auto param = args.cast(); MS_EXCEPTION_IF_NULL(param.func_graph); auto old_parameters = param.func_graph->parameters(); - for (auto& p : param.params) { + for (auto &p : param.params) { (*adds)[p] += 1; } - for (auto& p : old_parameters) { + for (auto &p : old_parameters) { (*rms)[p] += 1; } param.func_graph->set_parameters(param.params); @@ -532,7 +532,7 @@ void FuncGraphManager::ParseChanges(const std::vector& changes, EdgeTupl } } -void FuncGraphManager::CommitChanges(const std::vector& changes) { +void FuncGraphManager::CommitChanges(const std::vector &changes) { EdgeTupleCounter add_edges; EdgeTupleCounter rm_edges; Counter adds; @@ -540,7 +540,7 @@ void FuncGraphManager::CommitChanges(const std::vector& changes) { ParseChanges(changes, &add_edges, &rm_edges, &adds, &rms); auto sub_edges = add_edges - rm_edges; - for (auto& iter : sub_edges) { + for (auto &iter : sub_edges) { auto root_node = iter.first.first; int index = iter.first.second.first; auto new_node = iter.first.second.second; @@ -550,12 +550,12 @@ void FuncGraphManager::CommitChanges(const std::vector& changes) { auto sub_nodes = adds - rms; std::vector nodes; (void)std::transform(sub_nodes.begin(), sub_nodes.end(), std::back_inserter(nodes), - [](const std::pair& iter) -> AnfNodePtr { return iter.first; }); + [](const std::pair &iter) -> AnfNodePtr { return iter.first; }); AcquireNodes(nodes); auto sub_edges_reverse = rm_edges - add_edges; - for (auto& iter : sub_edges_reverse) { + for (auto &iter : sub_edges_reverse) { auto root_node = iter.first.first; int index = iter.first.second.first; auto old_node = iter.first.second.second; @@ -566,17 +566,17 @@ void FuncGraphManager::CommitChanges(const std::vector& changes) { std::vector nodes_reverse; (void)std::transform(sub_nodes_reverse.begin(), sub_nodes_reverse.end(), std::back_inserter(nodes_reverse), - [](const std::pair& iter) -> AnfNodePtr { return iter.first; }); + [](const std::pair &iter) -> AnfNodePtr { return iter.first; }); auto drop_func_graphs = MaybeDropNodes(nodes_reverse); MaybeDropFuncGraphs(*drop_func_graphs); } -void FuncGraphTransaction::SetParameters(FuncGraphPtr fg, const std::vector& params) { +void FuncGraphTransaction::SetParameters(FuncGraphPtr fg, const std::vector ¶ms) { changes_.emplace_back(Change::kTxSetParams, ArgsOfSetParams{fg, params}); } -bool FuncGraphTransaction::Replace(const AnfNodePtr& old_node, const AnfNodePtr& new_node) { +bool FuncGraphTransaction::Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node) { MS_EXCEPTION_IF_NULL(old_node); MS_EXCEPTION_IF_NULL(new_node); FuncGraphPtr old_func_graph = old_node->func_graph(); @@ -585,14 +585,14 @@ bool FuncGraphTransaction::Replace(const AnfNodePtr& old_node, const AnfNodePtr& return false; } auto users = manager_->node_users()[old_node]; - for (auto& node : users) { + for (auto &node : users) { SetEdge(node.first, node.second, new_node); } return true; } -void FuncGraphTransaction::SetEdge(const AnfNodePtr& src_node, int k, const AnfNodePtr& v) { +void FuncGraphTransaction::SetEdge(const AnfNodePtr &src_node, int k, const AnfNodePtr &v) { if (k < 0) { MS_LOG(EXCEPTION) << "Invalid value k = " << k; } @@ -610,7 +610,7 @@ void FuncGraphTransaction::Commit() { manager_->CommitChanges(changes); } -FuncGraphAnalysis::FuncGraphAnalysis(const FuncGraphManager* const manager) +FuncGraphAnalysis::FuncGraphAnalysis(const FuncGraphManager *const manager) : manager_(manager), include_func_graph_none_(false) { manager_->signals()->AddFuncGraph.connect(this, &FuncGraphAnalysis::OnAddFuncGraph); manager_->signals()->DropFuncGraph.connect(this, &FuncGraphAnalysis::OnDropFuncGraph); @@ -619,7 +619,7 @@ FuncGraphAnalysis::FuncGraphAnalysis(const FuncGraphManager* const manager) manager_->signals()->MoveAllCNode.connect(this, &FuncGraphAnalysis::OnMoveAllCNode); } -NodesCollector::NodesCollector(const FuncGraphManager* const m) : DepCollector(m), nodes_analysis_() { +NodesCollector::NodesCollector(const FuncGraphManager *const m) : DepCollector(m), nodes_analysis_() { include_func_graph_none_ = true; nodes_analysis_[nullptr] = AnfNodeSet(); @@ -646,7 +646,7 @@ void NodesCollector::OnDropNode(AnfNodePtr n) { void NodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { // change the owner of node except for the src's return node - for (auto& it : nodes_analysis_[src]) { + for (auto &it : nodes_analysis_[src]) { nodes_analysis_[dst].add(it); } (void)nodes_analysis_.erase(src); @@ -654,15 +654,15 @@ void NodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { void DepCollector::OnAddEdge(AnfNodePtr node, int index, AnfNodePtr inp) { OnModEdge(node, index, inp, kIncEdge); } -DepCollector::DepCollector(const FuncGraphManager* const manager) : FuncGraphAnalysis(manager) { +DepCollector::DepCollector(const FuncGraphManager *const manager) : FuncGraphAnalysis(manager) { MS_EXCEPTION_IF_NULL(manager_); manager_->signals()->InvalidateCollector.connect(this, &DepCollector::OnInvalidateCollector); } void DepCollector::OnDropEdge(AnfNodePtr node, int index, AnfNodePtr inp) { OnModEdge(node, index, inp, kDecEdge); } -bool CounterAnfNodeCollector::Inc(const FuncGraphPtr& func_graph, const AnfNodePtr& key, int count = 1) { - auto& d = count_nodes_map_[func_graph]; +bool CounterAnfNodeCollector::Inc(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count = 1) { + auto &d = count_nodes_map_[func_graph]; if (d.count(key) == 0) { d[key] = count; return true; @@ -672,9 +672,9 @@ bool CounterAnfNodeCollector::Inc(const FuncGraphPtr& func_graph, const AnfNodeP return false; } -bool CounterAnfNodeCollector::Dec(const FuncGraphPtr& func_graph, const AnfNodePtr& key, int count = 1) { +bool CounterAnfNodeCollector::Dec(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count = 1) { MS_EXCEPTION_IF_NULL(func_graph); - auto& d = count_nodes_map_[func_graph]; + auto &d = count_nodes_map_[func_graph]; if (d.count(key) != 0) { if (d[key] == count) { (void)d.erase(key); @@ -690,7 +690,7 @@ bool CounterAnfNodeCollector::Dec(const FuncGraphPtr& func_graph, const AnfNodeP return false; } -bool CounterAnfNodeCollector::Mod(const FuncGraphPtr& func_graph, const AnfNodePtr& key, int count) { +bool CounterAnfNodeCollector::Mod(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count) { if (count > 0) { return Inc(func_graph, key, count); } else if (count < 0) { @@ -701,8 +701,8 @@ bool CounterAnfNodeCollector::Mod(const FuncGraphPtr& func_graph, const AnfNodeP } } -bool CounterFuncGraphCollector::Inc(const FuncGraphPtr& func_graph, const FuncGraphPtr& key, int count = 1) { - auto& d = count_func_graphs_map_[func_graph]; +bool CounterFuncGraphCollector::Inc(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count = 1) { + auto &d = count_func_graphs_map_[func_graph]; if (d.count(key) == 0) { d[key] = count; return true; @@ -712,8 +712,8 @@ bool CounterFuncGraphCollector::Inc(const FuncGraphPtr& func_graph, const FuncGr return false; } -bool CounterFuncGraphCollector::Dec(const FuncGraphPtr& func_graph, const FuncGraphPtr& key, int count = 1) { - auto& d = count_func_graphs_map_[func_graph]; +bool CounterFuncGraphCollector::Dec(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count = 1) { + auto &d = count_func_graphs_map_[func_graph]; if (d.count(key) != 0) { if (d[key] == count) { (void)d.erase(key); @@ -729,7 +729,7 @@ bool CounterFuncGraphCollector::Dec(const FuncGraphPtr& func_graph, const FuncGr return false; } -bool CounterFuncGraphCollector::Mod(const FuncGraphPtr& func_graph, const FuncGraphPtr& key, int count) { +bool CounterFuncGraphCollector::Mod(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count) { if (count > 0) { return Inc(func_graph, key, count); } else if (count < 0) { @@ -748,7 +748,7 @@ void ValueNodesCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgePr } void ValueNodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { - for (auto& it : count_nodes_map_[src]) { + for (auto &it : count_nodes_map_[src]) { (void)Inc(dst, it.first, it.second); } (void)count_nodes_map_.erase(src); @@ -762,7 +762,7 @@ void FuncGraphValueNodesCollector::OnModEdge(AnfNodePtr, int, AnfNodePtr inp, Ed } void FuncGraphValueNodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { - for (auto& it : count_nodes_map_[src]) { + for (auto &it : count_nodes_map_[src]) { (void)Inc(dst, it.first, it.second); } (void)count_nodes_map_.erase(src); @@ -779,7 +779,7 @@ void FVDirectCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeProc } void FVDirectCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { - for (auto& it : count_nodes_map_[src]) { + for (auto &it : count_nodes_map_[src]) { FuncGraphPtr fg2 = it.first->func_graph(); if (fg2 != dst) { (void)Inc(dst, it.first, it.second); @@ -788,7 +788,7 @@ void FVDirectCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { (void)count_nodes_map_.erase(src); } -static FuncGraphPtr ParentProxy(const FuncGraphPtr& fg) { +static FuncGraphPtr ParentProxy(const FuncGraphPtr &fg) { FuncGraphPtr gn = std::make_shared(); (void)gn->transforms().insert(std::make_pair("proxy", FuncGraphTransform(fg))); return gn; @@ -805,7 +805,7 @@ void FuncGraphChildDirect::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, EdgeP } void FuncGraphChildDirect::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { - for (auto& it : count_func_graphs_map_[src]) { + for (auto &it : count_func_graphs_map_[src]) { FuncGraphPtr fg = it.first; if (fg != dst) { (void)Inc(dst, fg, it.second); @@ -835,7 +835,7 @@ void FuncGraphParentsDirectCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr } void FuncGraphParentsDirectCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { - for (auto& it : count_func_graphs_map_[src]) { + for (auto &it : count_func_graphs_map_[src]) { if (it.first != dst) { (void)Inc(dst, it.first, it.second); } @@ -852,7 +852,7 @@ void FuncGraphsUsedCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, Ed void FuncGraphsUsedCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { // all graph use in src need to change to dst, so meger the to dst use - for (auto& it : count_func_graphs_map_[src]) { + for (auto &it : count_func_graphs_map_[src]) { (void)Inc(dst, it.first, it.second); } (void)count_func_graphs_map_[dst].erase(src); @@ -879,7 +879,7 @@ void FuncGraphUserNodesCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp } void FuncGraphUserNodesCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { - for (auto& it : count_nodes_map_[src]) { + for (auto &it : count_nodes_map_[src]) { (void)Inc(dst, it.first, it.second); } (void)count_nodes_map_.erase(src); @@ -895,13 +895,13 @@ void FuncGraphJDirectCollector::OnModEdge(AnfNodePtr node, int, AnfNodePtr inp, void FuncGraphJDirectCollector::OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) { // all graph use in src need to change to dst, so meger the to dst use - for (auto& it : count_func_graphs_map_[src]) { + for (auto &it : count_func_graphs_map_[src]) { (void)Inc(dst, it.first, it.second); } (void)count_func_graphs_map_.erase(src); } -DepComputer::DepComputer(const FuncGraphManager* const manager) : FuncGraphAnalysis(manager) { +DepComputer::DepComputer(const FuncGraphManager *const manager) : FuncGraphAnalysis(manager) { MS_EXCEPTION_IF_NULL(manager_); manager_->signals()->InvalidateComputer.connect(this, &DepComputer::OnInvalidateComputer); validate_ = false; @@ -914,20 +914,20 @@ void DepComputer::Recompute() { } } -void DepComputer::Recompute(const FuncGraphPtr& fg) { +void DepComputer::Recompute(const FuncGraphPtr &fg) { if (func_graphs_validate_.count(fg) == 0 || !func_graphs_validate_[fg]) { RealRecompute(fg); func_graphs_validate_[fg] = true; } } -FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(const FuncGraphPtr& fg, const FuncGraphSetPtr& path) { +FuncGraphSetPtr FuncGraphParentsTotalComputer::SeekParents(const FuncGraphPtr &fg, const FuncGraphSetPtr &path) { if (path == nullptr || path->contains(fg)) { return std::make_shared(); } FuncGraphSetPtr parents = std::make_shared(); - FuncGraphToFuncGraphCounterMap& deps = *all_parents_direct_; - for (auto& dep : deps[fg]) { + FuncGraphToFuncGraphCounterMap &deps = *all_parents_direct_; + for (auto &dep : deps[fg]) { MS_EXCEPTION_IF_NULL(dep.first); auto proxy = dep.first->transforms().find("proxy"); if (proxy != dep.first->transforms().end()) { @@ -950,7 +950,7 @@ void FuncGraphParentsTotalComputer::RealRecompute(FuncGraphPtr fg) { MS_LOG(DEBUG) << "FuncGraphParentsTotalComputer end: " << func_graph_parents_total_analysis_[fg].size(); } -bool set_len_compare(const FuncGraphSetPair& lhs, const FuncGraphSetPair& rhs) { +bool set_len_compare(const FuncGraphSetPair &lhs, const FuncGraphSetPair &rhs) { auto l1 = lhs.second.size(); auto l2 = rhs.second.size(); return l1 < l2; @@ -970,9 +970,9 @@ void ParentComputer::RealRecompute(FuncGraphPtr fg) { } else { // return nearest parent as parent FuncGraphSet deps_copy(deps); - for (auto& dep : deps) { + for (auto &dep : deps) { auto parent_deps = this->manager_->func_graph_parents_total(dep); - for (auto& p_d : parent_deps) { + for (auto &p_d : parent_deps) { if (deps_copy.count(p_d)) { (void)deps_copy.erase(p_d); } @@ -988,7 +988,7 @@ void ParentComputer::RealRecompute(FuncGraphPtr fg) { void ChildrenComputer::RealRecompute(FuncGraphPtr fg) { MS_EXCEPTION_IF_NULL(manager_); auto used_fg_total = manager_->func_graphs_used_total(fg); - for (auto& used_fg : used_fg_total) { + for (auto &used_fg : used_fg_total) { if (manager_->parent(used_fg) == fg) { children_analysis_[fg].add(used_fg); } @@ -997,11 +997,11 @@ void ChildrenComputer::RealRecompute(FuncGraphPtr fg) { void ScopeComputer::RealRecompute(FuncGraphPtr fg) { MS_EXCEPTION_IF_NULL(manager_); - auto& children = manager_->children(fg); + auto &children = manager_->children(fg); scope_analysis_[fg] = FuncGraphSet(); scope_analysis_[fg].add(fg); - for (auto& child : children) { + for (auto &child : children) { scope_analysis_[fg].add(child); } } @@ -1010,20 +1010,20 @@ void FVTotalComputer::RealRecompute() { auto manager = DepComputer::manager_; MS_EXCEPTION_IF_NULL(manager); - for (auto& fg : manager->func_graphs()) { + for (auto &fg : manager->func_graphs()) { fv_total_analysis_[fg] = OrderedMap(); count_nodes_map_[fg] = OrderedMap(); count_func_graphs_map_[fg] = OrderedMap(); } - for (auto& fg : manager->func_graphs()) { + for (auto &fg : manager->func_graphs()) { AnfNodeCounterMap items = manager->free_variables_direct()[fg]; - for (auto& iter : items) { + for (auto &iter : items) { auto curr = fg; while (curr) { (void)CounterAnfNodeCollector::Mod(curr, iter.first, iter.second); curr = manager->parent(curr); - const AnfNodeSet& nodes = manager->nodes()[curr]; + const AnfNodeSet &nodes = manager->nodes()[curr]; if (nodes.contains(iter.first)) { break; } @@ -1031,7 +1031,7 @@ void FVTotalComputer::RealRecompute() { } auto items_fg = manager->func_graphs_used()[fg]; - for (auto& iter : items_fg) { + for (auto &iter : items_fg) { auto p = manager->parent(iter.first); if (p == nullptr) { continue; @@ -1043,13 +1043,13 @@ void FVTotalComputer::RealRecompute() { } } } - for (auto& fg : manager->func_graphs()) { - auto& fvp = count_nodes_map_[fg]; - auto& fvg = count_func_graphs_map_[fg]; - for (auto& item : fvp) { + for (auto &fg : manager->func_graphs()) { + auto &fvp = count_nodes_map_[fg]; + auto &fvg = count_func_graphs_map_[fg]; + for (auto &item : fvp) { fv_total_analysis_[fg][item.first] = item.second; } - for (auto& item : fvg) { + for (auto &item : fvg) { fv_total_analysis_[fg][item.first] = item.second; } } @@ -1057,15 +1057,15 @@ void FVTotalComputer::RealRecompute() { void FuncGraphsUsedTotalComputer::RealRecompute(FuncGraphPtr fg) { MS_EXCEPTION_IF_NULL(manager_); - auto& used = this->manager_->func_graphs_used(); + auto &used = this->manager_->func_graphs_used(); std::vector todo; std::vector todo_new; todo.push_back(fg); while (!todo.empty()) { todo_new.clear(); - for (auto& gt : todo) { - for (auto& item : used[gt]) { + for (auto > : todo) { + for (auto &item : used[gt]) { auto used_fg = item.first; if (used_fg == fg) { func_graph_used_total_analysis_[fg].add(used_fg); @@ -1082,17 +1082,17 @@ void FuncGraphsUsedTotalComputer::RealRecompute(FuncGraphPtr fg) { } } -bool CheckRecursive(const FuncGraphManager* const manager, const FuncGraphPtr& fg) { +bool CheckRecursive(const FuncGraphManager *const manager, const FuncGraphPtr &fg) { MS_EXCEPTION_IF_NULL(manager); - auto& used = manager->func_graphs_used(); + auto &used = manager->func_graphs_used(); std::vector todo; std::vector todo_new; todo.push_back(fg); FuncGraphSet used_total; while (!todo.empty()) { todo_new.clear(); - for (auto& gt : todo) { - for (auto& item : used[gt]) { + for (auto > : todo) { + for (auto &item : used[gt]) { auto used_g = item.first; if (used_g == fg) { return true; @@ -1112,7 +1112,7 @@ void RecursiveComputer::RealRecompute(FuncGraphPtr fg) { this->recursive_analysis_[fg] = CheckRecursive(this->manager_, fg); } -void RecursiveComputer::CheckRecursiveGraphs(const FuncGraphPtr& fg, std::list* trace) { +void RecursiveComputer::CheckRecursiveGraphs(const FuncGraphPtr &fg, std::list *trace) { MS_EXCEPTION_IF_NULL(trace); auto res = std::find(trace->begin(), trace->end(), fg); // find recursive @@ -1124,7 +1124,7 @@ void RecursiveComputer::CheckRecursiveGraphs(const FuncGraphPtr& fg, std::listpush_back(fg); - auto& used_fgs = manager_->func_graphs_used()[fg]; + auto &used_fgs = manager_->func_graphs_used()[fg]; for (auto iter = used_fgs.begin(); iter != used_fgs.end(); (void)iter++) { CheckRecursiveGraphs(iter->first, trace); } @@ -1135,14 +1135,14 @@ void RecursiveComputer::CheckRecursiveGraphs(const FuncGraphPtr& fg, std::listcontains(fg)) { MS_LOG(DEBUG) << "" << fg->ToString() << " had been checked"; return false; } MS_EXCEPTION_IF_NULL(manager_); - auto& func_graph_counter_map = manager_->func_graph_j_direct(); + auto &func_graph_counter_map = manager_->func_graph_j_direct(); if (!func_graph_counter_map[fg].empty()) { // check g1->J(fg)->g2->g cycle; auto contains_j = @@ -1156,8 +1156,8 @@ bool FuncGraphJTotalComputer::SeekJ(const FuncGraphPtr& fg, const FuncGraphSetPt path->add(fg); // check if func graphs used contains J(func_graph); - auto& used = this->manager_->func_graphs_used(); - for (auto& item : used[fg]) { + auto &used = this->manager_->func_graphs_used(); + for (auto &item : used[fg]) { auto used_g = item.first; if (SeekJ(used_g, path)) { MS_LOG(DEBUG) << "" << fg->ToString() << " users func graph " << used_g->ToString() diff --git a/mindspore/ccsrc/ir/manager.h b/mindspore/ccsrc/ir/manager.h index aaf5a0aa5f..54c1e8a692 100644 --- a/mindspore/ccsrc/ir/manager.h +++ b/mindspore/ccsrc/ir/manager.h @@ -46,13 +46,13 @@ class FuncGraphManager; using FuncGraphManagerPtr = std::shared_ptr; struct AnfNodeIndexPairHasher { - std::size_t operator()(const std::pair& p1) const { - return std::hash{}(p1.first.get()); + std::size_t operator()(const std::pair &p1) const { + return std::hash{}(p1.first.get()); } }; struct AnfNodeIndexPairEqual { - bool operator()(const std::pair& lhs, const std::pair& rhs) const { + bool operator()(const std::pair &lhs, const std::pair &rhs) const { return lhs == rhs; } }; @@ -63,14 +63,14 @@ using FuncGraphSetPair = std::pair; using FuncGraphSetPtr = std::shared_ptr; using EdgeTuple = std::pair>; struct EdgeTupleHasher { - std::size_t operator()(const EdgeTuple& p1) const { - return hash_combine({std::hash{}(p1.first.get()), std::hash{}(p1.second.first), - std::hash{}(p1.second.second.get())}); + std::size_t operator()(const EdgeTuple &p1) const { + return hash_combine({std::hash{}(p1.first.get()), std::hash{}(p1.second.first), + std::hash{}(p1.second.second.get())}); } }; struct EdgeTupleEqual { - bool operator()(const EdgeTuple& lhs, const EdgeTuple& rhs) const { + bool operator()(const EdgeTuple &lhs, const EdgeTuple &rhs) const { return lhs.first == rhs.first && lhs.second.first == rhs.second.first && lhs.second.second == rhs.second.second; } }; @@ -82,9 +82,9 @@ using EdgeTupleCounter = Counter; // FuncGraphManagerPtr: return created manager FuncGraphManagerPtr Manage(FuncGraphPtr func_graph, bool manage = true); -FuncGraphManagerPtr Manage(const std::vector& func_graphs, bool manage = true); +FuncGraphManagerPtr Manage(const std::vector &func_graphs, bool manage = true); -FuncGraphManagerPtr MakeManager(const std::vector& func_graphs = {}, bool manage = true); +FuncGraphManagerPtr MakeManager(const std::vector &func_graphs = {}, bool manage = true); struct Signals { Signal AddFuncGraph; @@ -106,7 +106,7 @@ using FuncGraphToAnfNodeCounterMap = OrderedMap; // graphs analysis which compute in write, read needn't recompute class DepCollector : public FuncGraphAnalysis { public: - explicit DepCollector(const FuncGraphManager* manager); + explicit DepCollector(const FuncGraphManager *manager); ~DepCollector() override = default; void Reset() { ExtraReset(); } @@ -155,10 +155,10 @@ class DepCollector : public FuncGraphAnalysis { class NodesCollector final : public DepCollector { public: - explicit NodesCollector(const FuncGraphManager* m); + explicit NodesCollector(const FuncGraphManager *m); ~NodesCollector() override = default; - const FuncGraphToAnfNodeMap& nodes_analysis() const { return nodes_analysis_; } + const FuncGraphToAnfNodeMap &nodes_analysis() const { return nodes_analysis_; } size_t size() const override { return nodes_analysis_.size(); } void OnAddFuncGraph(FuncGraphPtr fg) override { nodes_analysis_[fg] = AnfNodeSet(); } @@ -176,16 +176,16 @@ class NodesCollector final : public DepCollector { class CounterFuncGraphCollector : public DepCollector { public: - explicit CounterFuncGraphCollector(const FuncGraphManager* m) : DepCollector(m) {} + explicit CounterFuncGraphCollector(const FuncGraphManager *m) : DepCollector(m) {} ~CounterFuncGraphCollector() override = default; - FuncGraphToFuncGraphCounterMap& count_func_graphs_map() { return count_func_graphs_map_; } + FuncGraphToFuncGraphCounterMap &count_func_graphs_map() { return count_func_graphs_map_; } // inherit from FuncGraphAnalysis size_t size() const override { return count_func_graphs_map_.size(); } void OnAddFuncGraph(FuncGraphPtr fg) final { count_func_graphs_map_[fg] = OrderedMap(); } void OnDropFuncGraph(FuncGraphPtr fg) final { (void)count_func_graphs_map_.erase(fg); } - bool Inc(const FuncGraphPtr& func_graph, const FuncGraphPtr& key, int count); - bool Dec(const FuncGraphPtr& func_graph, const FuncGraphPtr& key, int count); - bool Mod(const FuncGraphPtr& func_graph, const FuncGraphPtr& key, int count); + bool Inc(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count); + bool Dec(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count); + bool Mod(const FuncGraphPtr &func_graph, const FuncGraphPtr &key, int count); FuncGraphToFuncGraphCounterMap count_func_graphs_map_; @@ -195,17 +195,17 @@ class CounterFuncGraphCollector : public DepCollector { class CounterAnfNodeCollector : public DepCollector { public: - explicit CounterAnfNodeCollector(const FuncGraphManager* m) : DepCollector(m) {} + explicit CounterAnfNodeCollector(const FuncGraphManager *m) : DepCollector(m) {} ~CounterAnfNodeCollector() override = default; - FuncGraphToAnfNodeCounterMap& count_nodes_map() { return count_nodes_map_; } + FuncGraphToAnfNodeCounterMap &count_nodes_map() { return count_nodes_map_; } size_t size() const override { return count_nodes_map_.size(); } void OnAddFuncGraph(FuncGraphPtr fg) final { count_nodes_map_[fg] = OrderedMap(); } void OnDropFuncGraph(FuncGraphPtr fg) final { (void)count_nodes_map_.erase(fg); } - bool Inc(const FuncGraphPtr& func_graph, const AnfNodePtr& key, int count); - bool Dec(const FuncGraphPtr& func_graph, const AnfNodePtr& key, int count); - bool Mod(const FuncGraphPtr& func_graph, const AnfNodePtr& key, int count); + bool Inc(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count); + bool Dec(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count); + bool Mod(const FuncGraphPtr &func_graph, const AnfNodePtr &key, int count); FuncGraphToAnfNodeCounterMap count_nodes_map_; @@ -215,7 +215,7 @@ class CounterAnfNodeCollector : public DepCollector { class ValueNodesCollector final : public CounterAnfNodeCollector { public: - explicit ValueNodesCollector(const FuncGraphManager* m) : CounterAnfNodeCollector(m) {} + explicit ValueNodesCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {} ~ValueNodesCollector() override = default; void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; @@ -225,7 +225,7 @@ class ValueNodesCollector final : public CounterAnfNodeCollector { class FuncGraphValueNodesCollector final : public CounterAnfNodeCollector { public: - explicit FuncGraphValueNodesCollector(const FuncGraphManager* m) : CounterAnfNodeCollector(m) {} + explicit FuncGraphValueNodesCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {} ~FuncGraphValueNodesCollector() override = default; void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; @@ -235,7 +235,7 @@ class FuncGraphValueNodesCollector final : public CounterAnfNodeCollector { class FVDirectCollector final : public CounterAnfNodeCollector { public: - explicit FVDirectCollector(const FuncGraphManager* m) : CounterAnfNodeCollector(m) {} + explicit FVDirectCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {} ~FVDirectCollector() override = default; void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; @@ -245,7 +245,7 @@ class FVDirectCollector final : public CounterAnfNodeCollector { class FuncGraphChildDirect final : public CounterFuncGraphCollector { public: - explicit FuncGraphChildDirect(const FuncGraphManager* m) : CounterFuncGraphCollector(m) {} + explicit FuncGraphChildDirect(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {} void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; ~FuncGraphChildDirect() override = default; @@ -260,7 +260,7 @@ class FuncGraphChildDirect final : public CounterFuncGraphCollector { // 2.direct parent: if graph g's node a used free_variable node in graph f, g's direct parent is f key is g, value is f class FuncGraphParentsDirectCollector final : public CounterFuncGraphCollector { public: - explicit FuncGraphParentsDirectCollector(const FuncGraphManager* m) : CounterFuncGraphCollector(m) {} + explicit FuncGraphParentsDirectCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {} ~FuncGraphParentsDirectCollector() override = default; void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; @@ -271,7 +271,7 @@ class FuncGraphParentsDirectCollector final : public CounterFuncGraphCollector { // graph's all used graphs: key is g, value is g used graph class FuncGraphsUsedCollector final : public CounterFuncGraphCollector { public: - explicit FuncGraphsUsedCollector(const FuncGraphManager* m) : CounterFuncGraphCollector(m) {} + explicit FuncGraphsUsedCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {} void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; ~FuncGraphsUsedCollector() override = default; @@ -282,7 +282,7 @@ class FuncGraphsUsedCollector final : public CounterFuncGraphCollector { // graph's all user graphs: key is g, value is graphs who used g class FuncGraphUsersCollector final : public CounterFuncGraphCollector { public: - explicit FuncGraphUsersCollector(const FuncGraphManager* m) : CounterFuncGraphCollector(m) {} + explicit FuncGraphUsersCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {} void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; ~FuncGraphUsersCollector() override = default; @@ -293,7 +293,7 @@ class FuncGraphUsersCollector final : public CounterFuncGraphCollector { // graph's all user cnodes: key is g, value is cnodes who used g class FuncGraphUserNodesCollector final : public CounterAnfNodeCollector { public: - explicit FuncGraphUserNodesCollector(const FuncGraphManager* m) : CounterAnfNodeCollector(m) {} + explicit FuncGraphUserNodesCollector(const FuncGraphManager *m) : CounterAnfNodeCollector(m) {} void OnMoveAllCNode(FuncGraphPtr src, FuncGraphPtr dst) override; ~FuncGraphUserNodesCollector() override = default; @@ -303,7 +303,7 @@ class FuncGraphUserNodesCollector final : public CounterAnfNodeCollector { class FuncGraphJDirectCollector final : public CounterFuncGraphCollector { public: - explicit FuncGraphJDirectCollector(const FuncGraphManager* m) : CounterFuncGraphCollector(m) {} + explicit FuncGraphJDirectCollector(const FuncGraphManager *m) : CounterFuncGraphCollector(m) {} void OnMoveAllCNode(FuncGraphPtr src, const FuncGraphPtr dst) override; ~FuncGraphJDirectCollector() override = default; @@ -316,7 +316,7 @@ using FuncGraphToFuncGraphSetMap = OrderedMap; // graphs analysis which need dynamic compute by DepCollector in each read class DepComputer : public FuncGraphAnalysis { public: - explicit DepComputer(const FuncGraphManager* manager); + explicit DepComputer(const FuncGraphManager *manager); ~DepComputer() override = default; void Reset() { @@ -329,11 +329,11 @@ class DepComputer : public FuncGraphAnalysis { void Recompute(); - void Recompute(const FuncGraphPtr& fg); + void Recompute(const FuncGraphPtr &fg); bool IsValidate() const { return validate_; } - bool IsValidate(const FuncGraphPtr& fg) { return func_graphs_validate_[fg]; } + bool IsValidate(const FuncGraphPtr &fg) { return func_graphs_validate_[fg]; } void OnAddFuncGraph(FuncGraphPtr) final { Reset(); } @@ -354,10 +354,10 @@ class DepComputer : public FuncGraphAnalysis { // graph g's all direct or proxy parents class FuncGraphParentsTotalComputer final : public DepComputer { public: - explicit FuncGraphParentsTotalComputer(const FuncGraphManager* m) : DepComputer(m), all_parents_direct_(nullptr) {} + explicit FuncGraphParentsTotalComputer(const FuncGraphManager *m) : DepComputer(m), all_parents_direct_(nullptr) {} ~FuncGraphParentsTotalComputer() override { all_parents_direct_ = nullptr; } - FuncGraphToFuncGraphSetMap& func_graph_parents_total_analysis() { return func_graph_parents_total_analysis_; } + FuncGraphToFuncGraphSetMap &func_graph_parents_total_analysis() { return func_graph_parents_total_analysis_; } size_t size() const override { return func_graph_parents_total_analysis_.size(); } @@ -369,10 +369,10 @@ class FuncGraphParentsTotalComputer final : public DepComputer { void RealRecompute(FuncGraphPtr fg) override; private: - FuncGraphSetPtr SeekParents(const FuncGraphPtr& fg, const FuncGraphSetPtr& path = std::make_shared()); + FuncGraphSetPtr SeekParents(const FuncGraphPtr &fg, const FuncGraphSetPtr &path = std::make_shared()); // when SeekParents calls itself recursively, it can access these variables by class member // other than pass by formal parameters, it can save 1 parameter for SeekParents(). - FuncGraphToFuncGraphCounterMap* all_parents_direct_; + FuncGraphToFuncGraphCounterMap *all_parents_direct_; }; using FuncGraphToFuncGraphMap = OrderedMap; @@ -380,10 +380,10 @@ using FuncGraphToFuncGraphMap = OrderedMap; // graph's nearest parent in parents total class ParentComputer final : public DepComputer { public: - explicit ParentComputer(const FuncGraphManager* m) : DepComputer(m) {} + explicit ParentComputer(const FuncGraphManager *m) : DepComputer(m) {} ~ParentComputer() override = default; - FuncGraphToFuncGraphMap& parent_analysis() { return parent_analysis_; } + FuncGraphToFuncGraphMap &parent_analysis() { return parent_analysis_; } size_t size() const override { return parent_analysis_.size(); } @@ -398,10 +398,10 @@ class ParentComputer final : public DepComputer { // graph's children graph except self class ChildrenComputer final : public DepComputer { public: - explicit ChildrenComputer(const FuncGraphManager* m) : DepComputer(m) {} + explicit ChildrenComputer(const FuncGraphManager *m) : DepComputer(m) {} ~ChildrenComputer() override = default; - FuncGraphToFuncGraphSetMap& children_analysis() { return children_analysis_; } + FuncGraphToFuncGraphSetMap &children_analysis() { return children_analysis_; } size_t size() const override { return children_analysis_.size(); } @@ -416,10 +416,10 @@ class ChildrenComputer final : public DepComputer { // graph's children graph include self class ScopeComputer final : public DepComputer { public: - explicit ScopeComputer(const FuncGraphManager* m) : DepComputer(m) {} + explicit ScopeComputer(const FuncGraphManager *m) : DepComputer(m) {} ~ScopeComputer() override = default; - FuncGraphToFuncGraphSetMap& scope_analysis() { return scope_analysis_; } + FuncGraphToFuncGraphSetMap &scope_analysis() { return scope_analysis_; } size_t size() const override { return scope_analysis_.size(); } @@ -435,11 +435,11 @@ using FVTotalMap = OrderedMap* trace); + void CheckRecursiveGraphs(const FuncGraphPtr &fg, std::list *trace); size_t size() const override { return recursive_analysis_.size(); } @@ -497,10 +497,10 @@ class RecursiveComputer final : public DepComputer { class FuncGraphJTotalComputer final : public DepComputer { public: - explicit FuncGraphJTotalComputer(const FuncGraphManager* m) : DepComputer(m) {} + explicit FuncGraphJTotalComputer(const FuncGraphManager *m) : DepComputer(m) {} ~FuncGraphJTotalComputer() override = default; - FuncGraphToBoolMap& j_total_analysis() { return j_total_analysis_; } + FuncGraphToBoolMap &j_total_analysis() { return j_total_analysis_; } size_t size() const override { return j_total_analysis_.size(); } @@ -510,12 +510,12 @@ class FuncGraphJTotalComputer final : public DepComputer { void ExtraReset() override { j_total_analysis_.clear(); } void RealRecompute(FuncGraphPtr fg) override; - bool SeekJ(const FuncGraphPtr& fg, const FuncGraphSetPtr& path); + bool SeekJ(const FuncGraphPtr &fg, const FuncGraphSetPtr &path); }; class FuncGraphManager : public std::enable_shared_from_this { public: - explicit FuncGraphManager(const std::vector& roots, bool manage = true); + explicit FuncGraphManager(const std::vector &roots, bool manage = true); ~FuncGraphManager() { if (is_manage_) { RemoveRoots(); @@ -526,71 +526,71 @@ class FuncGraphManager : public std::enable_shared_from_this { void Init(); void Clear(); void AddFuncGraph(FuncGraphPtr func_graph, bool is_root = false); - void KeepRoots(const std::vector& roots = {}); + void KeepRoots(const std::vector &roots = {}); void RemoveRoots(); - void SetParameters(const FuncGraphPtr& fg, const std::vector& parameters); - void MaybeDropFuncGraphs(const FuncGraphSet& func_graphs, bool ignore_users = false); - bool Replace(const AnfNodePtr& old_node, const AnfNodePtr& new_node); - void SetEdge(const AnfNodePtr& node, int index, const AnfNodePtr& value); - void MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr target, const ScopePtr& scope); + void SetParameters(const FuncGraphPtr &fg, const std::vector ¶meters); + void MaybeDropFuncGraphs(const FuncGraphSet &func_graphs, bool ignore_users = false); + bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node); + void SetEdge(const AnfNodePtr &node, int index, const AnfNodePtr &value); + void MoveAllCNodeDropGraph(FuncGraphPtr source, FuncGraphPtr target, const ScopePtr &scope); FuncGraphTransaction Transact(); - void CommitChanges(const std::vector& changes); + void CommitChanges(const std::vector &changes); bool IsManaged() const { return is_manage_; } - const FuncGraphSet& roots() const { return roots_; } + const FuncGraphSet &roots() const { return roots_; } - const FuncGraphSet& func_graphs() const { return func_graphs_; } + const FuncGraphSet &func_graphs() const { return func_graphs_; } - AnfNodeSet& all_nodes() { return all_nodes_; } + AnfNodeSet &all_nodes() { return all_nodes_; } - NodeUsersMap& node_users() { return node_users_; } + NodeUsersMap &node_users() { return node_users_; } - FuncGraphToAnfNodeMap& nodes() const { return nodes_->nodes_analysis_; } + FuncGraphToAnfNodeMap &nodes() const { return nodes_->nodes_analysis_; } - FuncGraphToAnfNodeCounterMap& valuenodes() const { return valuenodes_->count_nodes_map_; } + FuncGraphToAnfNodeCounterMap &valuenodes() const { return valuenodes_->count_nodes_map_; } - FuncGraphToAnfNodeCounterMap& free_variables_direct() const { return free_variables_direct_->count_nodes_map_; } + FuncGraphToAnfNodeCounterMap &free_variables_direct() const { return free_variables_direct_->count_nodes_map_; } - FuncGraphToAnfNodeCounterMap& func_graph_valuenodes() const { return func_graph_valuenodes_->count_nodes_map_; } + FuncGraphToAnfNodeCounterMap &func_graph_valuenodes() const { return func_graph_valuenodes_->count_nodes_map_; } - FuncGraphToFuncGraphCounterMap& func_graphs_used() const { return func_graphs_used_->count_func_graphs_map_; } + FuncGraphToFuncGraphCounterMap &func_graphs_used() const { return func_graphs_used_->count_func_graphs_map_; } - FuncGraphToFuncGraphCounterMap& func_graph_users() const { return func_graph_users_->count_func_graphs_map_; } + FuncGraphToFuncGraphCounterMap &func_graph_users() const { return func_graph_users_->count_func_graphs_map_; } - FuncGraphToAnfNodeCounterMap& func_graph_user_cnodes() const { return func_graph_user_cnodes_->count_nodes_map_; } + FuncGraphToAnfNodeCounterMap &func_graph_user_cnodes() const { return func_graph_user_cnodes_->count_nodes_map_; } - FuncGraphToFuncGraphCounterMap& func_graph_child_direct() const { + FuncGraphToFuncGraphCounterMap &func_graph_child_direct() const { return func_graph_child_direct_->count_func_graphs_map_; } - FuncGraphToFuncGraphCounterMap& func_graph_parents_direct() const { + FuncGraphToFuncGraphCounterMap &func_graph_parents_direct() const { return func_graph_parents_direct_->count_func_graphs_map_; } - FuncGraphToFuncGraphCounterMap& func_graph_j_direct() const { return func_graph_j_direct_->count_func_graphs_map_; } + FuncGraphToFuncGraphCounterMap &func_graph_j_direct() const { return func_graph_j_direct_->count_func_graphs_map_; } - FVTotalMap& free_variables_total() const; + FVTotalMap &free_variables_total() const; - FuncGraphSet& func_graph_parents_total(const FuncGraphPtr& fg) const; + FuncGraphSet &func_graph_parents_total(const FuncGraphPtr &fg) const; - FuncGraphSet& scopes(const FuncGraphPtr& fg) const; + FuncGraphSet &scopes(const FuncGraphPtr &fg) const; - FuncGraphPtr parent(const FuncGraphPtr& fg) const; + FuncGraphPtr parent(const FuncGraphPtr &fg) const; - FuncGraphSet& children(const FuncGraphPtr& fg) const; + FuncGraphSet &children(const FuncGraphPtr &fg) const; - FuncGraphSet& func_graphs_used_total(const FuncGraphPtr& fg) const; + FuncGraphSet &func_graphs_used_total(const FuncGraphPtr &fg) const; - bool recursive(const FuncGraphPtr& fg) const; - std::shared_ptr> recursive_graphs(const FuncGraphPtr& fg) const; + bool recursive(const FuncGraphPtr &fg) const; + std::shared_ptr> recursive_graphs(const FuncGraphPtr &fg) const; - bool func_graph_j_total(const FuncGraphPtr& fg) const; + bool func_graph_j_total(const FuncGraphPtr &fg) const; std::shared_ptr signals() const { return signals_; } - IncludeType Limit(const AnfNodePtr& node); + IncludeType Limit(const AnfNodePtr &node); // Static Analysis NodeUsersMap node_users_; @@ -610,13 +610,13 @@ class FuncGraphManager : public std::enable_shared_from_this { std::shared_ptr func_graph_parent_; private: - void AddIntoManaged(const FuncGraphPtr& fg); + void AddIntoManaged(const FuncGraphPtr &fg); void ProcessEdge(AnfNodePtr node, int index, AnfNodePtr inp, EdgeProcessDirection direction); - void ProcessInputs(const AnfNodePtr& node, EdgeProcessDirection direction); - void AcquireNodes(const std::vector& nodes); - FuncGraphSetPtr MaybeDropNodes(const std::vector& nodes); - void ParseChanges(const std::vector& changes, EdgeTupleCounter* add_edges, EdgeTupleCounter* rm_edges, - Counter* adds, Counter* rms); + void ProcessInputs(const AnfNodePtr &node, EdgeProcessDirection direction); + void AcquireNodes(const std::vector &nodes); + FuncGraphSetPtr MaybeDropNodes(const std::vector &nodes); + void ParseChanges(const std::vector &changes, EdgeTupleCounter *add_edges, EdgeTupleCounter *rm_edges, + Counter *adds, Counter *rms); FuncGraphSet roots_; // managed roots FuncGraphSet func_graphs_; // managed func graphs @@ -637,7 +637,7 @@ class FuncGraphManager : public std::enable_shared_from_this { class FuncGraphTransaction { public: - explicit FuncGraphTransaction(FuncGraphManager* manager) : manager_(manager), changes_() { + explicit FuncGraphTransaction(FuncGraphManager *manager) : manager_(manager), changes_() { MS_EXCEPTION_IF_NULL(manager_); if (!manager_->IsManaged()) { MS_LOG(DEBUG) << "The manager is not managed yet"; @@ -648,19 +648,19 @@ class FuncGraphTransaction { ~FuncGraphTransaction() { manager_ = nullptr; } // set parameters of a func graph - void SetParameters(FuncGraphPtr fg, const std::vector& params); + void SetParameters(FuncGraphPtr fg, const std::vector ¶ms); // replace old_node with new_node - bool Replace(const AnfNodePtr& old_node, const AnfNodePtr& new_node); + bool Replace(const AnfNodePtr &old_node, const AnfNodePtr &new_node); // set esge, i.e., declare setting node.inputs[key] to value. - void SetEdge(const AnfNodePtr& src_node, int k, const AnfNodePtr& v); + void SetEdge(const AnfNodePtr &src_node, int k, const AnfNodePtr &v); // commit all changes void Commit(); private: - FuncGraphManager* manager_; + FuncGraphManager *manager_; std::vector changes_; }; @@ -668,9 +668,9 @@ class FuncGraphTransaction { struct ArgsOfSetParams { FuncGraphPtr func_graph; std::vector params; - bool operator==(const ArgsOfSetParams& other) const { return &other == this; } + bool operator==(const ArgsOfSetParams &other) const { return &other == this; } - friend std::ostream& operator<<(std::ostream& os, const ArgsOfSetParams&) { + friend std::ostream &operator<<(std::ostream &os, const ArgsOfSetParams &) { os << "[ArgsOfSetParams]"; return os; } @@ -681,9 +681,9 @@ struct ArgsOfSetEdge { CNodePtr root_node; AnfNodePtr new_node; size_t index; - bool operator==(const ArgsOfSetEdge& other) const { return &other == this; } + bool operator==(const ArgsOfSetEdge &other) const { return &other == this; } - friend std::ostream& operator<<(std::ostream& os, const ArgsOfSetEdge& other) { + friend std::ostream &operator<<(std::ostream &os, const ArgsOfSetEdge &other) { os << "[ArgsOfSetEdge]"; return os; } @@ -693,7 +693,7 @@ struct Change { enum OpName { kTxSetParams, kTxSetEdge }; OpName op; Any args; - Change(OpName name, const Any& para) : op(name), args(para) {} + Change(OpName name, const Any ¶) : op(name), args(para) {} }; } // namespace mindspore diff --git a/mindspore/ccsrc/ir/meta_func_graph.h b/mindspore/ccsrc/ir/meta_func_graph.h index 69da925e3d..482b5f9025 100644 --- a/mindspore/ccsrc/ir/meta_func_graph.h +++ b/mindspore/ccsrc/ir/meta_func_graph.h @@ -42,25 +42,25 @@ namespace mindspore { // generate a graph corresponding to these types. class MetaFuncGraph : public FuncGraphBase { public: - explicit MetaFuncGraph(const std::string& name) : name_(name) { cache_.clear(); } + explicit MetaFuncGraph(const std::string &name) : name_(name) { cache_.clear(); } ~MetaFuncGraph() override = default; MS_DECLARE_PARENT(MetaFuncGraph, FuncGraphBase); - abstract::AbstractBasePtr MakeAbstractClosure(const AnfNodePtr& anf_node); + abstract::AbstractBasePtr MakeAbstractClosure(const AnfNodePtr &anf_node); // Return normalized versions of the arguments. // By default, this returns args unchanged. - virtual abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList& args_spec_list) const { + virtual abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList &args_spec_list) const { return args_spec_list; } - const std::vector& signatures() const { return signatures_; } - void set_signatures(const std::vector& signatures) { signatures_ = signatures; } + const std::vector &signatures() const { return signatures_; } + void set_signatures(const std::vector &signatures) { signatures_ = signatures; } // Generate a Graph for the given abstract arguments. - virtual FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList& args_spec_list) { + virtual FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList &args_spec_list) { TypePtrList types; (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(types), - [](const AbstractBasePtr& arg) -> TypePtr { + [](const AbstractBasePtr &arg) -> TypePtr { MS_EXCEPTION_IF_NULL(arg); return arg->BuildType(); }); @@ -81,7 +81,7 @@ class MetaFuncGraph : public FuncGraphBase { } // Generate a Graph for this type signature. - virtual FuncGraphPtr GenerateFromTypes(const TypePtrList&) { + virtual FuncGraphPtr GenerateFromTypes(const TypePtrList &) { MS_LOG(EXCEPTION) << "Undefine the method of generating graph from types."; } @@ -89,8 +89,8 @@ class MetaFuncGraph : public FuncGraphBase { std::string ToString() const override { return name_; } std::size_t hash() const override { return tid(); } - virtual bool operator==(const MetaFuncGraph& other) const { return &other == this; } - bool operator==(const Value& other) const override { + virtual bool operator==(const MetaFuncGraph &other) const { return &other == this; } + bool operator==(const Value &other) const override { if (other.isa()) { return &other == this; } else { diff --git a/mindspore/ccsrc/ir/meta_tensor.cc b/mindspore/ccsrc/ir/meta_tensor.cc index e9221039a7..5bb9ae3c06 100644 --- a/mindspore/ccsrc/ir/meta_tensor.cc +++ b/mindspore/ccsrc/ir/meta_tensor.cc @@ -31,7 +31,7 @@ namespace mindspore { namespace tensor { -void DataBuf2Contiguous(const py::array& src, py::array* const dest) { +void DataBuf2Contiguous(const py::array &src, py::array *const dest) { if (dest == nullptr) { MS_LOG(EXCEPTION) << "Failed to copy data to a contiguous buffer as dest is nullptr!"; } @@ -55,9 +55,9 @@ void DataBuf2Contiguous(const py::array& src, py::array* const dest) { // MetaTensor has default type_id_ which is TypeId::kTypeUnknown. MetaTensor::MetaTensor() : data_type_(TypeId::kTypeUnknown) {} -MetaTensor::MetaTensor(const TypeId data_type, const std::vector& shape) : data_type_(data_type), shape_(shape) {} +MetaTensor::MetaTensor(const TypeId data_type, const std::vector &shape) : data_type_(data_type), shape_(shape) {} -MetaTensor::MetaTensor(const TypePtr& type_ptr, const py::tuple& shape) { +MetaTensor::MetaTensor(const TypePtr &type_ptr, const py::tuple &shape) { TypeId data_type = TypeId::kTypeUnknown; if (type_ptr != nullptr) { data_type = type_ptr->type_id(); @@ -69,10 +69,10 @@ MetaTensor::MetaTensor(const TypePtr& type_ptr, const py::tuple& shape) { } } -MetaTensor::MetaTensor(const MetaTensor& meta_tensor) +MetaTensor::MetaTensor(const MetaTensor &meta_tensor) : Value(meta_tensor), data_type_(meta_tensor.data_type()), shape_(meta_tensor.shape()) {} -MetaTensor& MetaTensor::operator=(const MetaTensor& meta_tensor) { +MetaTensor &MetaTensor::operator=(const MetaTensor &meta_tensor) { if (&meta_tensor == this) { return *this; } @@ -84,7 +84,7 @@ MetaTensor& MetaTensor::operator=(const MetaTensor& meta_tensor) { return *this; } -bool MetaTensor::operator==(const MetaTensor& meta_tensor) const { +bool MetaTensor::operator==(const MetaTensor &meta_tensor) const { return data_type_ == meta_tensor.data_type() && shape_ == meta_tensor.shape(); } @@ -117,7 +117,7 @@ TypePtr MetaTensor::SetDtype(const TypePtr type_ptr) { return type_ptr; } -void MetaTensor::SetDeviceInfo(const std::string& format, const TypePtr& data_type) { +void MetaTensor::SetDeviceInfo(const std::string &format, const TypePtr &data_type) { DeviceInfo info(format, data_type); set_device_info(info); } @@ -138,7 +138,7 @@ std::string MetaTensor::DumpText() const { return oss.str(); } -Tensor::Tensor(const TypePtr& type_ptr, const py::tuple& shape) { +Tensor::Tensor(const TypePtr &type_ptr, const py::tuple &shape) { TypeId data_type = TypeId::kTypeUnknown; if (type_ptr != nullptr) { data_type = type_ptr->type_id(); @@ -151,24 +151,24 @@ Tensor::Tensor(const TypePtr& type_ptr, const py::tuple& shape) { init(data_type_, shape_, &data_); } -Tensor::Tensor(TypeId data_type, const std::vector& shape) { init(data_type, shape, &data_); } +Tensor::Tensor(TypeId data_type, const std::vector &shape) { init(data_type, shape, &data_); } -Tensor::Tensor(const py::array& input, const TypePtr& data_type) { init(input, data_type); } +Tensor::Tensor(const py::array &input, const TypePtr &data_type) { init(input, data_type); } -Tensor::Tensor(const py::list& input, const TypePtr& data_type) { init(py::array(input), data_type); } +Tensor::Tensor(const py::list &input, const TypePtr &data_type) { init(py::array(input), data_type); } -Tensor::Tensor(const py::tuple& input, const TypePtr& data_type) { init(py::array(input), data_type); } +Tensor::Tensor(const py::tuple &input, const TypePtr &data_type) { init(py::array(input), data_type); } -Tensor::Tensor(const py::float_& input, const TypePtr& data_type) { init(py::array(input), data_type); } +Tensor::Tensor(const py::float_ &input, const TypePtr &data_type) { init(py::array(input), data_type); } -Tensor::Tensor(const py::int_& input, const TypePtr& data_type) { init(py::array(input), data_type); } +Tensor::Tensor(const py::int_ &input, const TypePtr &data_type) { init(py::array(input), data_type); } -Tensor::Tensor(const Tensor& tensor, const TypePtr& data_type) +Tensor::Tensor(const Tensor &tensor, const TypePtr &data_type) : MetaTensor(tensor), device_address_(tensor.device_address()) { init(tensor.data_, data_type); } -Tensor& Tensor::operator=(const Tensor& tensor) { +Tensor &Tensor::operator=(const Tensor &tensor) { if (this != &tensor) { MetaTensor::operator=(tensor); dirty_ = tensor.is_dirty(); @@ -178,11 +178,11 @@ Tensor& Tensor::operator=(const Tensor& tensor) { return *this; } -bool Tensor::operator==(const Tensor& tensor) const { +bool Tensor::operator==(const Tensor &tensor) const { return (MetaTensor::operator==(tensor) && data_ == tensor.data_); } -bool Tensor::ValueEqualPy(const py::object& other) const { +bool Tensor::ValueEqualPy(const py::object &other) const { if (!py::isinstance(other)) { MS_LOG(WARNING) << "compare other not a tensor"; return false; @@ -190,7 +190,7 @@ bool Tensor::ValueEqualPy(const py::object& other) const { return ValueEqual(py::cast(other)); } -bool Tensor::ValueEqual(const Tensor& other) const { +bool Tensor::ValueEqual(const Tensor &other) const { auto equal = [&other, this]() -> bool { auto np = py::module::import("numpy"); auto equal = np.attr("equal")(data_, other.data_); @@ -218,7 +218,7 @@ int Tensor::data_type_c() const { return static_cast(data_type_); } std::vector Tensor::shape_c(void) const { return shape(); } -void* Tensor::data_c(bool writable) { +void *Tensor::data_c(bool writable) { // operand of bit operation should be unsigned int. unsigned int flags = ((unsigned int)data_.flags()) & pybind11::detail::npy_api::NPY_ARRAY_C_CONTIGUOUS_; bool is_c_contiguous = (flags != 0) ? true : false; @@ -231,7 +231,7 @@ void* Tensor::data_c(bool writable) { return data_.request(writable).ptr; } -TypeId Tensor::GetDataType(const py::buffer_info& buf) const { +TypeId Tensor::GetDataType(const py::buffer_info &buf) const { TypeId data_type = TypeId::kTypeUnknown; if (buf.format.compare("e") == 0) { data_type = TypeId::kNumberTypeFloat16; @@ -263,7 +263,7 @@ TypeId Tensor::GetDataType(const py::buffer_info& buf) const { return data_type; } -void Tensor::init(const py::array& input, const TypePtr& type_ptr) { +void Tensor::init(const py::array &input, const TypePtr &type_ptr) { TypeId data_type = TypeId::kTypeUnknown; if (type_ptr != nullptr) { data_type = type_ptr->type_id(); @@ -271,7 +271,7 @@ void Tensor::init(const py::array& input, const TypePtr& type_ptr) { init(input, data_type); } -void Tensor::init(const py::array& input, const TypeId& data_type) { +void Tensor::init(const py::array &input, const TypeId &data_type) { py::buffer_info buf = input.request(); data_type_ = GetDataType(buf); @@ -301,7 +301,7 @@ void Tensor::init(const py::array& input, const TypeId& data_type) { } } -void Tensor::init(TypeId data_type, const std::vector& shape, py::array* const data) { +void Tensor::init(TypeId data_type, const std::vector &shape, py::array *const data) { data_type_ = data_type; shape_ = shape; switch (data_type) { @@ -368,7 +368,7 @@ TypeId Tensor::set_data_type(const TypeId data_type) { return data_type_; } -bool Tensor::convert_data(const py::array& in, const TypeId in_data_type, py::array* const out, +bool Tensor::convert_data(const py::array &in, const TypeId in_data_type, py::array *const out, const TypeId out_data_type) { if (out == nullptr) { return false; @@ -458,7 +458,7 @@ py::array Tensor::data_sync() { return data_; } -REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module* m) { +REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) { // dtype should define before Tensor, because Tensor init depend dtype (void)py::class_>(*m, "Tensor") .def(py::init(), py::arg("dtype"), py::arg("shape")) @@ -541,11 +541,11 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module* m) { .def("__repr__", &Tensor::ToStringRepr) .def("__eq__", &Tensor::ValueEqualPy) .def(py::pickle( - [](const Tensor& t) { // __getstate__ + [](const Tensor &t) { // __getstate__ /* Return a tuple that fully encodes the state of the object */ return py::make_tuple(t.data()); }, - [](const py::tuple& t) { // __setstate__ + [](const py::tuple &t) { // __setstate__ if (t.size() != 1) { throw std::runtime_error("Invalid state!"); } diff --git a/mindspore/ccsrc/ir/meta_tensor.h b/mindspore/ccsrc/ir/meta_tensor.h index 3e28f29f37..1f6c866f11 100644 --- a/mindspore/ccsrc/ir/meta_tensor.h +++ b/mindspore/ccsrc/ir/meta_tensor.h @@ -131,16 +131,16 @@ class MetaTensor : public Value { // information of a Tensor. The following codes will create a 2x3 float // param data_type The data type of the tensor. // param shape The shape of the tensor. - MetaTensor(const TypeId data_type, const std::vector& shape); + MetaTensor(const TypeId data_type, const std::vector &shape); - MetaTensor(const TypePtr& type_ptr, const py::tuple& shape); + MetaTensor(const TypePtr &type_ptr, const py::tuple &shape); // brief Constructs a MetaTensor object from an existing MetaTensor instance. // // The constructed MetaTensor object will have the same data type and shape as the // meta_tensor. // // param meta_tensor An existing MetaTensor object. - MetaTensor(const MetaTensor& meta_tensor); + MetaTensor(const MetaTensor &meta_tensor); ~MetaTensor() override = default; MS_DECLARE_PARENT(MetaTensor, Value) @@ -149,7 +149,7 @@ class MetaTensor : public Value { // The constructed MetaTensor object has the same type and shape with meta_tensor. // // param meta_tensor An existing MetaTensor object. - virtual MetaTensor& operator=(const MetaTensor& meta_tensor); + virtual MetaTensor &operator=(const MetaTensor &meta_tensor); // brief Compares two MetaTensor objects. // @@ -157,7 +157,7 @@ class MetaTensor : public Value { // // param meta_tensor The MetaTensor object to be compared. // return true: If having same type and shape, return true, or return false. - virtual bool operator==(const MetaTensor& meta_tensor) const; + virtual bool operator==(const MetaTensor &meta_tensor) const; // brief Returns the data type of the tensor in its MetaTensor. // @@ -193,7 +193,7 @@ class MetaTensor : public Value { // // param shape The shape of the tensor. // return The shape's size. - size_t set_shape(const std::vector& shape) { + size_t set_shape(const std::vector &shape) { this->shape_ = shape; return shape_.size(); } @@ -202,9 +202,9 @@ class MetaTensor : public Value { DeviceInfo device_info() const { return device_info_; } // Set tensor's device info. - void set_device_info(const DeviceInfo& device_info) { device_info_ = device_info; } + void set_device_info(const DeviceInfo &device_info) { device_info_ = device_info; } - void SetDeviceInfo(const std::string& format, const TypePtr& data_type); + void SetDeviceInfo(const std::string &format, const TypePtr &data_type); // Get the size of a given dimension by its index number. int DimensionSize(size_t index) const; @@ -222,9 +222,9 @@ class MetaTensor : public Value { } return hash_value; } - bool operator==(const Value& other) const override { + bool operator==(const Value &other) const override { if (other.isa()) { - auto other_ = static_cast(other); + auto other_ = static_cast(other); return *this == other_; } else { return false; @@ -262,49 +262,49 @@ class Tensor : public MetaTensor { // // param type_ptr [TypePty] Data type of the tensor. // param py_shape [py::tuple] The shape represented by py::tuple of the tensor. - Tensor(const TypePtr& type_ptr, const py::tuple& shape); + Tensor(const TypePtr &type_ptr, const py::tuple &shape); // brief Constructor for C++. // // param data_type [TypeId] Data type of the tensor. // param shape The shape represented by std::vector of the tensor. - Tensor(TypeId data_type, const std::vector& shape); + Tensor(TypeId data_type, const std::vector &shape); // brief Constructor for Python. // // param input [py::array] Data value of the tensor. // param data_type [TypeId] Data type of the tensor. - explicit Tensor(const py::array& input, const TypePtr& data_type = nullptr); + explicit Tensor(const py::array &input, const TypePtr &data_type = nullptr); // brief Constructor // // param input [py::list] the data for tensor // param data_type [TypeId] data type - explicit Tensor(const py::list& input, const TypePtr& data_type = nullptr); + explicit Tensor(const py::list &input, const TypePtr &data_type = nullptr); // brief Constructor // // param input [py::tuple] the data for tensor // param data_type [TypeId] data type - explicit Tensor(const py::tuple& input, const TypePtr& data_type = nullptr); + explicit Tensor(const py::tuple &input, const TypePtr &data_type = nullptr); // brief Constructor // // param input [py::float_] the data for tensor // param data_type [TypeId] data type - explicit Tensor(const py::float_& input, const TypePtr& data_type = nullptr); + explicit Tensor(const py::float_ &input, const TypePtr &data_type = nullptr); // brief Constructor // // param input [py::int_] the data for tensor // param data_type [TypeId] data type - explicit Tensor(const py::int_& input, const TypePtr& data_type = nullptr); + explicit Tensor(const py::int_ &input, const TypePtr &data_type = nullptr); // brief Constructor // // param input [Tensor] the data for tensor // param data_type [TypeId] data type - Tensor(const Tensor& tensor, const TypePtr& data_type = nullptr); + Tensor(const Tensor &tensor, const TypePtr &data_type = nullptr); ~Tensor() override = default; @@ -315,7 +315,7 @@ class Tensor : public MetaTensor { // The constructed Tensor object has the same type and shape with tensor. // // param tensor An existing Tensor object. - Tensor& operator=(const Tensor& tensor); + Tensor &operator=(const Tensor &tensor); // brief Compares two Tensor objects. // @@ -324,17 +324,17 @@ class Tensor : public MetaTensor { // // param tensor The Tensor object to be compared. // return true: If having same type, shape and data, return true, or return false. - bool operator==(const Tensor& tensor) const; + bool operator==(const Tensor &tensor) const; // It is different from 'operator==' which just compare shape/type/address, it do real value comparison. - bool ValueEqual(const Tensor& other) const; + bool ValueEqual(const Tensor &other) const; // It is different from 'operator==' which just compare shape/type/address, it do real value comparison. - bool ValueEqualPy(const py::object& other) const; + bool ValueEqualPy(const py::object &other) const; - bool operator==(const Value& other) const override { + bool operator==(const Value &other) const override { if (other.isa()) { - auto other_ = static_cast(other); + auto other_ = static_cast(other); return *this == other_; } else { return false; @@ -375,13 +375,13 @@ class Tensor : public MetaTensor { // // param writable true if writable, false if read only // return The pointer to the object - void* data_c(bool writable = false); + void *data_c(bool writable = false); // brief Get data type from tensor data. // // param buf The buffer info of the py::array data. // return The [TypeId] of the tensor data. - TypeId GetDataType(const py::buffer_info& buf) const; + TypeId GetDataType(const py::buffer_info &buf) const; // brief Sets the data type of a tensor. // @@ -401,23 +401,23 @@ class Tensor : public MetaTensor { // param input [py::array] the data for tensor // param data_type [TypeId] data type // return true if succeed, false if failed. - void init(const py::array& input, const TypeId& data_type); - void init(const py::array& input, const TypePtr& type_ptr); + void init(const py::array &input, const TypeId &data_type); + void init(const py::array &input, const TypePtr &type_ptr); // brief init tensor attribute // // param data_type [TypeId] Data type of the tensor. // param shape [py::array] The shape of the tensor. // return true if succeed, false if failed. - void init(TypeId data_type, const std::vector& shape, py::array* data); + void init(TypeId data_type, const std::vector &shape, py::array *data); - bool convert_data(const py::array& in, const TypeId in_data_type, py::array* out, const TypeId out_data_type); + bool convert_data(const py::array &in, const TypeId in_data_type, py::array *out, const TypeId out_data_type); public: bool is_dirty() const { return dirty_; } void set_dirty(const bool dirty) { dirty_ = dirty; } DeviceAddressPtr device_address() const { return device_address_; } - void set_device_address(const DeviceAddressPtr& device_address) { device_address_ = device_address; } + void set_device_address(const DeviceAddressPtr &device_address) { device_address_ = device_address; } py::array data_sync(); private: diff --git a/mindspore/ccsrc/ir/named.cc b/mindspore/ccsrc/ir/named.cc index 3d12e8a453..67e11c64d3 100644 --- a/mindspore/ccsrc/ir/named.cc +++ b/mindspore/ccsrc/ir/named.cc @@ -18,9 +18,9 @@ #include "pipeline/static_analysis/abstract_value.h" namespace mindspore { -bool Named::operator==(const Value& other) const { +bool Named::operator==(const Value &other) const { if (other.isa()) { - auto other_named = static_cast(other); + auto other_named = static_cast(other); return *this == other_named; } else { return false; diff --git a/mindspore/ccsrc/ir/named.h b/mindspore/ccsrc/ir/named.h index 0651307a91..76136fb298 100644 --- a/mindspore/ccsrc/ir/named.h +++ b/mindspore/ccsrc/ir/named.h @@ -27,18 +27,18 @@ namespace mindspore { class Named : public Value { public: - explicit Named(const std::string& name) : name_(name) { hash_id_ = std::hash{}(name); } - Named(const Named& other) : Value(other) { + explicit Named(const std::string &name) : name_(name) { hash_id_ = std::hash{}(name); } + Named(const Named &other) : Value(other) { this->name_ = other.name_; hash_id_ = std::hash{}(other.name_); } ~Named() override = default; MS_DECLARE_PARENT(Named, Value); - const std::string& name() const { return name_; } - virtual bool operator==(const Named& other) const { return name_ == other.name(); } - bool operator==(const Value& other) const override; - Named& operator=(const Named& other) { + const std::string &name() const { return name_; } + virtual bool operator==(const Named &other) const { return name_ == other.name(); } + bool operator==(const Value &other) const override; + Named &operator=(const Named &other) { if (&other != this) { this->type_ = other.type_; this->name_ = other.name_; @@ -50,7 +50,7 @@ class Named : public Value { std::size_t Hash() const { return hash_id_; } std::size_t hash() const override { return hash_id_; } - friend std::ostream& operator<<(std::ostream& os, const Named& nmd) { + friend std::ostream &operator<<(std::ostream &os, const Named &nmd) { os << nmd.name(); return os; } diff --git a/mindspore/ccsrc/ir/primitive.cc b/mindspore/ccsrc/ir/primitive.cc index a576c1e76b..d40f8a265d 100644 --- a/mindspore/ccsrc/ir/primitive.cc +++ b/mindspore/ccsrc/ir/primitive.cc @@ -31,7 +31,7 @@ namespace mindspore { using mindspore::abstract::AbstractFunction; -abstract::AbstractBasePtr Primitive::ToPrimAbstract(const AnfNodePtr& anf_node) { +abstract::AbstractBasePtr Primitive::ToPrimAbstract(const AnfNodePtr &anf_node) { auto prim_func = std::make_shared(shared_from_base(), anf_node); return prim_func; } @@ -63,23 +63,23 @@ py::function Primitive::GetComputeFunction() { return fn; } -bool Primitive::operator==(const Value& other) const { +bool Primitive::operator==(const Value &other) const { if (other.isa()) { - auto other_prim = static_cast(other); + auto other_prim = static_cast(other); return *this == other_prim; } else { return false; } } -bool Primitive::operator==(const Primitive& other) const { +bool Primitive::operator==(const Primitive &other) const { if (name() != other.name()) { return false; } if (attrs_.size() != other.attrs_.size()) { return false; } - auto all = std::all_of(attrs_.begin(), attrs_.end(), [&other](const std::pair& item) -> bool { + auto all = std::all_of(attrs_.begin(), attrs_.end(), [&other](const std::pair &item) -> bool { if (item.second == nullptr) { return false; } @@ -95,7 +95,7 @@ bool Primitive::operator==(const Primitive& other) const { void Primitive::set_signatures( std::vector> signatures) { signatures_.clear(); - for (auto& signature : signatures) { + for (auto &signature : signatures) { std::string name; SignatureEnumRW rw; SignatureEnumKind kind; @@ -114,7 +114,7 @@ std::string Primitive::GetAttrsText() const { std::ostringstream oss; oss << "["; bool is_first = true; - for (auto& attr : attrs_) { + for (auto &attr : attrs_) { if (is_first) { is_first = false; } else { @@ -128,7 +128,7 @@ std::string Primitive::GetAttrsText() const { } py::function PrimitivePy::GetBpropFunction() { - static const char* const get_bprop_func_name = "get_bprop"; + static const char *const get_bprop_func_name = "get_bprop"; if (py::hasattr(python_obj_, get_bprop_func_name)) { py::function fn = python_obj_.attr(get_bprop_func_name)().cast(); return fn; @@ -142,7 +142,7 @@ py::function PrimitivePy::GetBpropFunction() { } py::function PrimitivePy::GetComputeFunction() { - static const char* const compute_func_name = "vm_impl"; + static const char *const compute_func_name = "vm_impl"; if (py::hasattr(python_obj_, compute_func_name)) { MS_LOG(INFO) << "" << name() << " compute_func_name"; @@ -163,7 +163,7 @@ py::function PrimitivePy::GetComputeFunction() { return vm_fn; } -void PrimitivePy::AddPyAttr(const py::str& name, const py::object& obj) { +void PrimitivePy::AddPyAttr(const py::str &name, const py::object &obj) { std::string attr_name = name; ValuePtr converted_ret = nullptr; if (py::isinstance(obj)) { @@ -178,13 +178,13 @@ void PrimitivePy::AddPyAttr(const py::str& name, const py::object& obj) { py::dict PrimitivePy::GetAttrDict() { py::dict attr_dict; - for (auto& attr : attrs_) { + for (auto &attr : attrs_) { attr_dict[py::str(attr.first)] = ValuePtrToPyData(attr.second); } return attr_dict; } -REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module* m) { +REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) { (void)py::enum_(*m, "prim_type", py::arithmetic()) .value("unknown", PrimType::kPrimTypeUnknown) .value("builtin", PrimType::kPrimTypeBuiltIn) @@ -192,7 +192,7 @@ REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module* m) { .value("user_custom", PrimType::kPrimTypeUserCustom); (void)py::class_>(*m, "Primitive_") .def_readonly(PYTHON_PRIMITIVE_FLAG, &PrimitivePy::parse_info_) - .def(py::init()) + .def(py::init()) .def("add_attr", &PrimitivePy::AddPyAttr, "add primitive attr") .def("get_attr_dict", &PrimitivePy::GetAttrDict, "get primitive attr") .def("set_prim_type", &PrimitivePy::set_prim_type, "Set primitive type.") diff --git a/mindspore/ccsrc/ir/primitive.h b/mindspore/ccsrc/ir/primitive.h index 7dd37eb15f..d16a524f69 100644 --- a/mindspore/ccsrc/ir/primitive.h +++ b/mindspore/ccsrc/ir/primitive.h @@ -48,25 +48,25 @@ enum PrimType { class Primitive : public Named { public: - explicit Primitive(const std::string& name, const PrimType prim_type = kPrimTypeBuiltIn) + explicit Primitive(const std::string &name, const PrimType prim_type = kPrimTypeBuiltIn) : Named(name), signatures_(), prim_type_(prim_type) {} - Primitive(const Primitive& prim) + Primitive(const Primitive &prim) : Named(prim), attrs_(prim.attrs_), signatures_(prim.signatures_), prim_type_(prim.prim_type_) {} MS_DECLARE_PARENT(Primitive, Named); - abstract::AbstractBasePtr ToPrimAbstract(const AnfNodePtr& anf_node); + abstract::AbstractBasePtr ToPrimAbstract(const AnfNodePtr &anf_node); std::string ToString() const override { return name(); } virtual py::function GetBpropFunction(); virtual py::function GetComputeFunction(); - Primitive& AddAttr(const std::string& name, const ValuePtr& attr) { + Primitive &AddAttr(const std::string &name, const ValuePtr &attr) { attrs_[name] = attr; return *this; } - Primitive& SetAttrs(const std::unordered_map& attrs) { - for (auto& attr : attrs) { + Primitive &SetAttrs(const std::unordered_map &attrs) { + for (auto &attr : attrs) { attrs_[attr.first] = attr.second; } return *this; @@ -76,21 +76,21 @@ class Primitive : public Named { std::vector> signatures); - const std::vector& signatures() const { return signatures_; } + const std::vector &signatures() const { return signatures_; } - void set_attr(const std::string& attrName, const ValuePtr& attr) { attrs_[attrName] = attr; } - void EraseAttr(const std::string& attrName) { (void)attrs_.erase(attrName); } + void set_attr(const std::string &attrName, const ValuePtr &attr) { attrs_[attrName] = attr; } + void EraseAttr(const std::string &attrName) { (void)attrs_.erase(attrName); } - ValuePtr GetAttr(const std::string& attrName) const { + ValuePtr GetAttr(const std::string &attrName) const { auto iter = attrs_.find(attrName); return iter == attrs_.cend() ? nullptr : iter->second; } - const std::unordered_map& attrs() const { return attrs_; } + const std::unordered_map &attrs() const { return attrs_; } // if Primitive has any attribute, for Primitives like scalar_add, return, etc, don't have any attribute. bool HasAttr() const { return !attrs_.empty(); } - bool HasAttr(const std::string& attrName) const { + bool HasAttr(const std::string &attrName) const { auto iter = attrs_.find(attrName); return !(iter == attrs_.cend()); } @@ -103,8 +103,8 @@ class Primitive : public Named { PrimType prim_type() const { return prim_type_; } std::string instance_name() const { return instance_name_; } std::string GetAttrsText() const; - bool operator==(const Value& other) const override; - bool operator==(const Primitive& other) const; + bool operator==(const Value &other) const override; + bool operator==(const Primitive &other) const; ~Primitive() override = default; protected: @@ -118,18 +118,18 @@ class Primitive : public Named { class PrimitivePy : public Primitive { public: - PrimitivePy(const py::str& name, const py::object& python_obj) : Primitive(name), python_obj_(python_obj) {} + PrimitivePy(const py::str &name, const py::object &python_obj) : Primitive(name), python_obj_(python_obj) {} ~PrimitivePy() override = default; MS_DECLARE_PARENT(PrimitivePy, Primitive); py::function GetBpropFunction() override; py::function GetComputeFunction() override; - void AddPyAttr(const py::str& name, const py::object& obj); + void AddPyAttr(const py::str &name, const py::object &obj); py::dict GetAttrDict(); const bool parse_info_ = true; - const py::object& GetPyObj() const { return python_obj_; } + const py::object &GetPyObj() const { return python_obj_; } bool is_tuple_input_ = false; private: @@ -138,13 +138,13 @@ class PrimitivePy : public Primitive { using PrimitivePyPtr = std::shared_ptr; -inline std::ostream& operator<<(std::ostream& os, const PrimitivePtr& p) { +inline std::ostream &operator<<(std::ostream &os, const PrimitivePtr &p) { os << *p; return os; } struct PrimitiveEqual { - bool operator()(PrimitivePtr const& t1, PrimitivePtr const& t2) const { + bool operator()(PrimitivePtr const &t1, PrimitivePtr const &t2) const { MS_EXCEPTION_IF_NULL(t1); MS_EXCEPTION_IF_NULL(t2); return t1->name() == t2->name(); @@ -152,7 +152,7 @@ struct PrimitiveEqual { }; struct PrimitiveHasher { - std::size_t operator()(PrimitivePtr const& prim) const { + std::size_t operator()(PrimitivePtr const &prim) const { std::size_t hash = std::hash()(prim->name()); return hash; } diff --git a/mindspore/ccsrc/ir/scalar.h b/mindspore/ccsrc/ir/scalar.h index 3e0a827b07..ab6c485540 100644 --- a/mindspore/ccsrc/ir/scalar.h +++ b/mindspore/ccsrc/ir/scalar.h @@ -55,8 +55,8 @@ class BoolImm : public Scalar { bool value() const { return v_; } bool IsZero() override { return v_ == false; } bool IsOne() override { return v_ == true; } - bool operator==(const Value& other) const override; - bool operator==(const BoolImm& other) const; + bool operator==(const Value &other) const override; + bool operator==(const BoolImm &other) const; std::string ToString() const override { if (v_) { return "true"; @@ -80,7 +80,7 @@ IMM_TRAITS(BoolImmPtr, bool) class IntergerImm : public Scalar { public: IntergerImm() = default; - explicit IntergerImm(const TypePtr& t) : Scalar(t) {} + explicit IntergerImm(const TypePtr &t) : Scalar(t) {} ~IntergerImm() override = default; MS_DECLARE_PARENT(IntergerImm, Scalar) }; @@ -95,8 +95,8 @@ class Int8Imm : public IntergerImm { bool IsZero() override { return v_ == 0; } bool IsOne() override { return v_ == 1; } int8_t value() const { return v_; } - bool operator==(const Value& other) const override; - bool operator==(const Int8Imm& other) const; + bool operator==(const Value &other) const override; + bool operator==(const Int8Imm &other) const; std::string ToString() const override { return std::to_string(v_); } std::string DumpText() const override { @@ -121,8 +121,8 @@ class Int16Imm : public IntergerImm { bool IsZero() override { return v_ == 0; } bool IsOne() override { return v_ == 1; } int16_t value() const { return v_; } - bool operator==(const Value& other) const override; - bool operator==(const Int16Imm& other) const; + bool operator==(const Value &other) const override; + bool operator==(const Int16Imm &other) const; std::string ToString() const override { return std::to_string(v_); } std::string DumpText() const override { @@ -147,8 +147,8 @@ class Int32Imm : public IntergerImm { bool IsZero() override { return v_ == 0; } bool IsOne() override { return v_ == 1; } int32_t value() const { return v_; } - bool operator==(const Value& other) const override; - bool operator==(const Int32Imm& other) const; + bool operator==(const Value &other) const override; + bool operator==(const Int32Imm &other) const; std::string ToString() const override { return std::to_string(v_); } std::string DumpText() const override { @@ -173,8 +173,8 @@ class Int64Imm : public IntergerImm { bool IsZero() override { return v_ == 0; } bool IsOne() override { return v_ == 1; } int64_t value() const { return v_; } - bool operator==(const Value& other) const override; - bool operator==(const Int64Imm& other) const; + bool operator==(const Value &other) const override; + bool operator==(const Int64Imm &other) const; std::string ToString() const override { return std::to_string(v_); } std::string DumpText() const override { @@ -199,8 +199,8 @@ class UInt8Imm : public IntergerImm { bool IsZero() override { return v_ == 0; } bool IsOne() override { return v_ == 1; } uint8_t value() const { return v_; } - bool operator==(const Value& other) const override; - bool operator==(const UInt8Imm& other) const; + bool operator==(const Value &other) const override; + bool operator==(const UInt8Imm &other) const; std::string ToString() const override { return std::to_string(v_); } std::string DumpText() const override { @@ -225,8 +225,8 @@ class UInt16Imm : public IntergerImm { bool IsZero() override { return v_ == 0; } bool IsOne() override { return v_ == 1; } uint16_t value() const { return v_; } - bool operator==(const Value& other) const override; - bool operator==(const UInt16Imm& other) const; + bool operator==(const Value &other) const override; + bool operator==(const UInt16Imm &other) const; std::string ToString() const override { return std::to_string(v_); } std::string DumpText() const override { @@ -251,8 +251,8 @@ class UInt32Imm : public IntergerImm { bool IsZero() override { return v_ == 0; } bool IsOne() override { return v_ == 1; } uint32_t value() const { return v_; } - bool operator==(const Value& other) const override; - bool operator==(const UInt32Imm& other) const; + bool operator==(const Value &other) const override; + bool operator==(const UInt32Imm &other) const; std::string ToString() const override { return std::to_string(v_); } std::string DumpText() const override { @@ -277,8 +277,8 @@ class UInt64Imm : public IntergerImm { bool IsZero() override { return v_ == 0; } bool IsOne() override { return v_ == 1; } uint64_t value() const { return v_; } - bool operator==(const Value& other) const override; - bool operator==(const UInt64Imm& other) const; + bool operator==(const Value &other) const override; + bool operator==(const UInt64Imm &other) const; std::string ToString() const override { return std::to_string(v_); } std::string DumpText() const override { @@ -296,7 +296,7 @@ IMM_TRAITS(UInt64ImmPtr, uint64_t); class FloatImm : public Scalar { public: FloatImm() = default; - explicit FloatImm(const TypePtr& t) : Scalar(t) {} + explicit FloatImm(const TypePtr &t) : Scalar(t) {} ~FloatImm() override = default; MS_DECLARE_PARENT(FloatImm, Scalar) }; @@ -312,8 +312,8 @@ class FP32Imm : public FloatImm { bool IsZero() override { return fabs(v_) <= FLT_EPSILON; } bool IsOne() override { return fabs(v_ - 1.0) <= FLT_EPSILON; } float value() const { return v_; } - bool operator==(const Value& other) const override; - bool operator==(const FP32Imm& other) const; + bool operator==(const Value &other) const override; + bool operator==(const FP32Imm &other) const; std::string ToString() const override { return std::to_string(v_); } std::string DumpText() const override { @@ -338,8 +338,8 @@ class FP64Imm : public FloatImm { bool IsZero() override { return fabs(v_) <= DBL_EPSILON; } bool IsOne() override { return fabs(v_ - 1.0) <= DBL_EPSILON; } double value() const { return v_; } - bool operator==(const Value& other) const override; - bool operator==(const FP64Imm& other) const; + bool operator==(const Value &other) const override; + bool operator==(const FP64Imm &other) const; std::string ToString() const override { return std::to_string(v_); } std::string DumpText() const override { diff --git a/mindspore/ccsrc/ir/signature.cc b/mindspore/ccsrc/ir/signature.cc index b7eec921d4..8f312d5b98 100644 --- a/mindspore/ccsrc/ir/signature.cc +++ b/mindspore/ccsrc/ir/signature.cc @@ -21,8 +21,8 @@ #include "pipeline/parse/data_converter.h" namespace mindspore { -Signature::Signature(const std::string& arg_name, const SignatureEnumRW& rw_tag, const SignatureEnumKind& arg_kind, - const py::object& arg_default, const SignatureEnumDType& arg_dtype) +Signature::Signature(const std::string &arg_name, const SignatureEnumRW &rw_tag, const SignatureEnumKind &arg_kind, + const py::object &arg_default, const SignatureEnumDType &arg_dtype) : name(arg_name), rw(rw_tag), kind(arg_kind), dtype(arg_dtype) { if (py::isinstance(arg_default) && py::cast(arg_default) == SignatureEnumKind::kKindEmptyDefaultValue) { @@ -32,14 +32,14 @@ Signature::Signature(const std::string& arg_name, const SignatureEnumRW& rw_tag, } } -Signature::Signature(const std::string& arg_name, const SignatureEnumRW& rw_tag, const SignatureEnumKind& arg_kind) +Signature::Signature(const std::string &arg_name, const SignatureEnumRW &rw_tag, const SignatureEnumKind &arg_kind) : name(arg_name), rw(rw_tag), kind(arg_kind), default_value(nullptr), dtype(SignatureEnumDType::kDTypeEmptyDefaultValue) {} -REGISTER_PYBIND_DEFINE(SignatureEnumRW, ([](const py::module* m) { +REGISTER_PYBIND_DEFINE(SignatureEnumRW, ([](const py::module *m) { (void)py::enum_(*m, "signature_rw", py::arithmetic()) .value("RW_READ", SignatureEnumRW::kRWRead) .value("RW_WRITE", SignatureEnumRW::kRWWrite) diff --git a/mindspore/ccsrc/ir/signature.h b/mindspore/ccsrc/ir/signature.h index 8e7409ab26..48be7e0f31 100644 --- a/mindspore/ccsrc/ir/signature.h +++ b/mindspore/ccsrc/ir/signature.h @@ -61,9 +61,9 @@ struct Signature { SignatureEnumKind kind; ValuePtr default_value; // nullptr for no default value SignatureEnumDType dtype; - Signature(const std::string& arg_name, const SignatureEnumRW& rw_tag, const SignatureEnumKind& arg_kind, - const py::object& arg_default, const SignatureEnumDType& arg_dtype); - Signature(const std::string& arg_name, const SignatureEnumRW& rw_tag, const SignatureEnumKind& arg_kind); + Signature(const std::string &arg_name, const SignatureEnumRW &rw_tag, const SignatureEnumKind &arg_kind, + const py::object &arg_default, const SignatureEnumDType &arg_dtype); + Signature(const std::string &arg_name, const SignatureEnumRW &rw_tag, const SignatureEnumKind &arg_kind); }; } // namespace mindspore diff --git a/mindspore/ccsrc/ir/value.cc b/mindspore/ccsrc/ir/value.cc index f9e8abaee9..e386e1ffd2 100644 --- a/mindspore/ccsrc/ir/value.cc +++ b/mindspore/ccsrc/ir/value.cc @@ -24,7 +24,7 @@ #include "pipeline/static_analysis/abstract_value.h" namespace mindspore { -const ValuePtr ValueSequeue::operator[](const std::size_t& dim) const { +const ValuePtr ValueSequeue::operator[](const std::size_t &dim) const { if (dim >= size()) { MS_LOG(EXCEPTION) << "List index [" << dim << "] is out of range [" << size() << "]."; } @@ -40,125 +40,125 @@ bool ValueSequeue::erase(size_t idx) { } } -bool BoolImm::operator==(const Value& other) const { +bool BoolImm::operator==(const Value &other) const { if (other.isa()) { - auto other_ = static_cast(other); + auto other_ = static_cast(other); return *this == other_; } else { return false; } } -bool BoolImm::operator==(const BoolImm& other) const { return v_ == other.v_; } +bool BoolImm::operator==(const BoolImm &other) const { return v_ == other.v_; } -bool Int8Imm::operator==(const Value& other) const { +bool Int8Imm::operator==(const Value &other) const { if (other.isa()) { - auto other_ = static_cast(other); + auto other_ = static_cast(other); return *this == other_; } else { return false; } } -bool Int8Imm::operator==(const Int8Imm& other) const { return v_ == other.v_; } -bool Int16Imm::operator==(const Value& other) const { +bool Int8Imm::operator==(const Int8Imm &other) const { return v_ == other.v_; } +bool Int16Imm::operator==(const Value &other) const { if (other.isa()) { - auto other_ = static_cast(other); + auto other_ = static_cast(other); return *this == other_; } else { return false; } } -bool Int16Imm::operator==(const Int16Imm& other) const { return v_ == other.v_; } -bool Int32Imm::operator==(const Value& other) const { +bool Int16Imm::operator==(const Int16Imm &other) const { return v_ == other.v_; } +bool Int32Imm::operator==(const Value &other) const { if (other.isa()) { - auto other_ = static_cast(other); + auto other_ = static_cast(other); return *this == other_; } else { return false; } } -bool Int32Imm::operator==(const Int32Imm& other) const { return v_ == other.v_; } -bool Int64Imm::operator==(const Value& other) const { +bool Int32Imm::operator==(const Int32Imm &other) const { return v_ == other.v_; } +bool Int64Imm::operator==(const Value &other) const { if (other.isa()) { - auto other_ = static_cast(other); + auto other_ = static_cast(other); return *this == other_; } else { return false; } } -bool Int64Imm::operator==(const Int64Imm& other) const { return v_ == other.v_; } -bool UInt8Imm::operator==(const Value& other) const { +bool Int64Imm::operator==(const Int64Imm &other) const { return v_ == other.v_; } +bool UInt8Imm::operator==(const Value &other) const { if (other.isa()) { - auto other_ = static_cast(other); + auto other_ = static_cast(other); return *this == other_; } else { return false; } } -bool UInt8Imm::operator==(const UInt8Imm& other) const { return v_ == other.v_; } -bool UInt16Imm::operator==(const Value& other) const { +bool UInt8Imm::operator==(const UInt8Imm &other) const { return v_ == other.v_; } +bool UInt16Imm::operator==(const Value &other) const { if (other.isa()) { - auto other_ = static_cast(other); + auto other_ = static_cast(other); return *this == other_; } else { return false; } } -bool UInt16Imm::operator==(const UInt16Imm& other) const { return v_ == other.v_; } -bool UInt32Imm::operator==(const Value& other) const { +bool UInt16Imm::operator==(const UInt16Imm &other) const { return v_ == other.v_; } +bool UInt32Imm::operator==(const Value &other) const { if (other.isa()) { - auto other_ = static_cast(other); + auto other_ = static_cast(other); return *this == other_; } else { return false; } } -bool UInt32Imm::operator==(const UInt32Imm& other) const { return v_ == other.v_; } -bool UInt64Imm::operator==(const Value& other) const { +bool UInt32Imm::operator==(const UInt32Imm &other) const { return v_ == other.v_; } +bool UInt64Imm::operator==(const Value &other) const { if (other.isa()) { - auto other_ = static_cast(other); + auto other_ = static_cast(other); return *this == other_; } else { return false; } } -bool UInt64Imm::operator==(const UInt64Imm& other) const { return v_ == other.v_; } -bool FP32Imm::operator==(const Value& other) const { +bool UInt64Imm::operator==(const UInt64Imm &other) const { return v_ == other.v_; } +bool FP32Imm::operator==(const Value &other) const { if (other.isa()) { - auto other_ = static_cast(other); + auto other_ = static_cast(other); return *this == other_; } else { return false; } } -bool FP32Imm::operator==(const FP32Imm& other) const { return fabs(v_ - other.v_) < FLT_EPSILON; } -bool FP64Imm::operator==(const Value& other) const { +bool FP32Imm::operator==(const FP32Imm &other) const { return fabs(v_ - other.v_) < FLT_EPSILON; } +bool FP64Imm::operator==(const Value &other) const { if (other.isa()) { - auto other_ = static_cast(other); + auto other_ = static_cast(other); return *this == other_; } else { return false; } } -bool ValueSequeue::operator==(const Value& other) const { +bool ValueSequeue::operator==(const Value &other) const { if (other.isa()) { - auto other_ = static_cast(other); + auto other_ = static_cast(other); return *this == other_; } else { return false; } } -bool ValueSequeue::operator==(const ValueSequeue& other) const { +bool ValueSequeue::operator==(const ValueSequeue &other) const { if (other.elements_.size() != elements_.size()) { return false; } return std::equal(elements_.begin(), elements_.end(), other.elements_.begin(), - [](const ValuePtr& lhs, const ValuePtr& rhs) { return *lhs == *rhs; }); + [](const ValuePtr &lhs, const ValuePtr &rhs) { return *lhs == *rhs; }); } std::string ValueSequeue::ToString() const { std::ostringstream buffer; bool begin = true; - for (auto& attr : elements_) { + for (auto &attr : elements_) { if (!begin) { buffer << ", "; } else { @@ -179,28 +179,28 @@ std::string ValueSequeue::DumpText() const { return oss.str(); } -bool FP64Imm::operator==(const FP64Imm& other) const { return fabs(v_ - other.v_) < DBL_EPSILON; } -bool StringImm::operator==(const Value& other) const { +bool FP64Imm::operator==(const FP64Imm &other) const { return fabs(v_ - other.v_) < DBL_EPSILON; } +bool StringImm::operator==(const Value &other) const { if (other.isa()) { - auto other_ = static_cast(other); + auto other_ = static_cast(other); return *this == other_; } else { return false; } } -bool StringImm::operator==(const StringImm& other) const { return str_ == other.str_; } +bool StringImm::operator==(const StringImm &other) const { return str_ == other.str_; } -bool RefKey::operator==(const Value& other) const { +bool RefKey::operator==(const Value &other) const { if (other.isa()) { - auto other_ = static_cast(other); + auto other_ = static_cast(other); return *this == other_; } else { return false; } } -bool RefKey::operator==(const RefKey& other) const { return tag_ == other.tag_; } +bool RefKey::operator==(const RefKey &other) const { return tag_ == other.tag_; } -bool AnyValue::operator==(const Value& other) const { +bool AnyValue::operator==(const Value &other) const { if (other.isa()) { return true; } else { @@ -228,7 +228,7 @@ abstract::AbstractBasePtr AnyValue::ToAbstract() { return std::make_sharedToAbstract(); }); @@ -237,7 +237,7 @@ abstract::AbstractBasePtr ValueTuple::ToAbstract() { abstract::AbstractBasePtr ValueList::ToAbstract() { abstract::AbstractBasePtrList a_list; - (void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(a_list), [](const ValuePtr& ele) { + (void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(a_list), [](const ValuePtr &ele) { MS_EXCEPTION_IF_NULL(ele); return ele->ToAbstract(); }); @@ -251,16 +251,16 @@ std::size_t ValueSlice::hash() const { return hash_combine({tid(), start_->hash(), stop_->hash(), step_->hash()}); } -bool ValueSlice::operator==(const Value& other) const { +bool ValueSlice::operator==(const Value &other) const { if (other.isa()) { - auto other_ = static_cast(other); + auto other_ = static_cast(other); return *this == other_; } else { return false; } } -bool ValueSlice::operator==(const ValueSlice& other) const { +bool ValueSlice::operator==(const ValueSlice &other) const { MS_EXCEPTION_IF_NULL(start_); MS_EXCEPTION_IF_NULL(stop_); MS_EXCEPTION_IF_NULL(step_); @@ -295,16 +295,16 @@ std::size_t KeywordArg::hash() const { return hash_combine({tid(), std::hash{}(key_), value_->hash()}); } -bool KeywordArg::operator==(const Value& other) const { +bool KeywordArg::operator==(const Value &other) const { if (other.isa()) { - auto other_ = static_cast(other); + auto other_ = static_cast(other); return *this == other_; } else { return false; } } -bool KeywordArg::operator==(const KeywordArg& other) const { return (other.key_ == key_ && *other.value_ == *value_); } +bool KeywordArg::operator==(const KeywordArg &other) const { return (other.key_ == key_ && *other.value_ == *value_); } std::string KeywordArg::ToString() const { std::ostringstream buffer; @@ -322,25 +322,25 @@ abstract::AbstractBasePtr KeywordArg::ToAbstract() { return std::make_shared(key_, argument); } -const ValuePtr ValueDictionary::operator[](const std::string& key) const { +const ValuePtr ValueDictionary::operator[](const std::string &key) const { auto it = std::find_if(key_values_.begin(), key_values_.end(), - [key](const std::pair& item) { return item.first == key; }); + [key](const std::pair &item) { return item.first == key; }); if (it == key_values_.end()) { MS_LOG(EXCEPTION) << "The key " << key << " is not in the map"; } return it->second; } -bool ValueDictionary::operator==(const Value& other) const { +bool ValueDictionary::operator==(const Value &other) const { if (other.isa()) { - auto other_ = static_cast(other); + auto other_ = static_cast(other); return *this == other_; } else { return false; } } -bool ValueDictionary::operator==(const ValueDictionary& other) const { +bool ValueDictionary::operator==(const ValueDictionary &other) const { if (key_values_.size() != other.key_values_.size()) { return false; } @@ -359,12 +359,12 @@ abstract::AbstractBasePtr ValueDictionary::ToAbstract() { std::vector> kv; (void)std::transform( key_values_.begin(), key_values_.end(), std::back_inserter(kv), - [](const std::pair& item) { return std::make_pair(item.first, item.second->ToAbstract()); }); + [](const std::pair &item) { return std::make_pair(item.first, item.second->ToAbstract()); }); return std::make_shared(kv); } REGISTER_PYBIND_DEFINE( - RefKey, ([](const py::module* m) { + RefKey, ([](const py::module *m) { (void)py::class_>(*m, "RefKey").def(py::init(), py::arg("tag")); })); } // namespace mindspore diff --git a/mindspore/ccsrc/ir/value.h b/mindspore/ccsrc/ir/value.h index 85f514b57b..c80e22f735 100644 --- a/mindspore/ccsrc/ir/value.h +++ b/mindspore/ccsrc/ir/value.h @@ -35,19 +35,19 @@ namespace mindspore { class ValueSequeue : public Value { public: - explicit ValueSequeue(const ValuePtrList& elements) : elements_(elements) { + explicit ValueSequeue(const ValuePtrList &elements) : elements_(elements) { TypePtrList t_list; - (void)std::transform(elements.begin(), elements.end(), std::back_inserter(t_list), [](const ValuePtr& ele) { + (void)std::transform(elements.begin(), elements.end(), std::back_inserter(t_list), [](const ValuePtr &ele) { MS_EXCEPTION_IF_NULL(ele); return ele->type(); }); TypePtr t = std::make_shared(t_list); type_ = t; } - ValueSequeue(const std::initializer_list& elements) : elements_(elements.begin(), elements.end()) { + ValueSequeue(const std::initializer_list &elements) : elements_(elements.begin(), elements.end()) { TypePtrList t_list; (void)std::transform(elements_.begin(), elements_.end(), std::back_inserter(t_list), - [](const ValuePtr& ele) { return ele->type(); }); + [](const ValuePtr &ele) { return ele->type(); }); TypePtr t = std::make_shared(t_list); type_ = t; } @@ -56,10 +56,10 @@ class ValueSequeue : public Value { std::size_t hash() const override { return hash_combine(tid(), std::hash{}(elements_.size())); } std::size_t size() const { return elements_.size(); } bool erase(size_t idx); - const ValuePtr operator[](const std::size_t& dim) const; - const ValuePtrList& value() const { return elements_; } - bool operator==(const Value& other) const override; - bool operator==(const ValueSequeue& other) const; + const ValuePtr operator[](const std::size_t &dim) const; + const ValuePtrList &value() const { return elements_; } + bool operator==(const Value &other) const override; + bool operator==(const ValueSequeue &other) const; std::string ToString() const override; std::string DumpText() const override; @@ -70,8 +70,8 @@ using ValueSequeuePtr = std::shared_ptr; class ValueTuple : public ValueSequeue { public: - explicit ValueTuple(const std::vector& elements) : ValueSequeue(elements) {} - ValueTuple(const std::initializer_list& elements) : ValueSequeue(elements) {} + explicit ValueTuple(const std::vector &elements) : ValueSequeue(elements) {} + ValueTuple(const std::initializer_list &elements) : ValueSequeue(elements) {} ~ValueTuple() override = default; MS_DECLARE_PARENT(ValueTuple, ValueSequeue) abstract::AbstractBasePtr ToAbstract() override; @@ -83,8 +83,8 @@ using ValueTuplePtr = std::shared_ptr; class ValueList : public ValueSequeue { public: - explicit ValueList(const std::vector& elements) : ValueSequeue(elements) {} - ValueList(const std::initializer_list& elements) : ValueSequeue(elements) {} + explicit ValueList(const std::vector &elements) : ValueSequeue(elements) {} + ValueList(const std::initializer_list &elements) : ValueSequeue(elements) {} ~ValueList() override = default; MS_DECLARE_PARENT(ValueList, ValueSequeue) abstract::AbstractBasePtr ToAbstract() override; @@ -94,7 +94,7 @@ class ValueList : public ValueSequeue { }; using ValueListPtr = std::shared_ptr; -inline ValuePtr MakeValue(const std::vector& v) { return std::make_shared(v); } +inline ValuePtr MakeValue(const std::vector &v) { return std::make_shared(v); } inline ValuePtr MakeValue(std::initializer_list v) { return std::make_shared(v); } template @@ -103,7 +103,7 @@ template struct is_vector> : public std::true_type {}; template ::value, typename T::value_type>::type> -ValuePtr MakeValue(const T& vec) { +ValuePtr MakeValue(const T &vec) { std::vector list; (void)std::transform(vec.begin(), vec.end(), std::back_inserter(list), [](U ele) { return MakeValue(ele); }); return std::make_shared(list); @@ -111,13 +111,13 @@ ValuePtr MakeValue(const T& vec) { class ValueSlice : public Value { public: - ValueSlice(const ValuePtr& start, const ValuePtr& stop, const ValuePtr& step) + ValueSlice(const ValuePtr &start, const ValuePtr &stop, const ValuePtr &step) : start_(start), stop_(stop), step_(step) {} ~ValueSlice() override = default; MS_DECLARE_PARENT(ValueSlice, Value) std::size_t hash() const override; - bool operator==(const Value& other) const override; - bool operator==(const ValueSlice& other) const; + bool operator==(const Value &other) const override; + bool operator==(const ValueSlice &other) const; std::string ToString() const override; @@ -133,13 +133,13 @@ using ValueSlicePtr = std::shared_ptr; class KeywordArg : public Value { public: - KeywordArg(const std::string& key, const ValuePtr& value) : key_(key), value_(value) {} + KeywordArg(const std::string &key, const ValuePtr &value) : key_(key), value_(value) {} ~KeywordArg() override = default; MS_DECLARE_PARENT(KeywordArg, Value) std::size_t hash() const override; ValuePtr get_value() const { return value_; } - bool operator==(const Value& other) const override; - bool operator==(const KeywordArg& other) const; + bool operator==(const Value &other) const override; + bool operator==(const KeywordArg &other) const; std::string ToString() const override; @@ -154,31 +154,31 @@ using KeywordArgPtr = std::shared_ptr; class ValueDictionary : public Value { public: - explicit ValueDictionary(const std::vector>& key_values) : key_values_(key_values) {} + explicit ValueDictionary(const std::vector> &key_values) : key_values_(key_values) {} ~ValueDictionary() override = default; MS_DECLARE_PARENT(ValueDictionary, Value) std::size_t hash() const override { return hash_combine(tid(), std::hash{}(key_values_.size())); } std::size_t size() const { return key_values_.size(); } - const ValuePtr operator[](const std::string& key) const; - const std::vector>& value() const { return key_values_; } - bool operator==(const Value& other) const override; - bool operator==(const ValueDictionary& other) const; + const ValuePtr operator[](const std::string &key) const; + const std::vector> &value() const { return key_values_; } + bool operator==(const Value &other) const override; + bool operator==(const ValueDictionary &other) const; std::string ToString() const override { std::ostringstream buffer; std::vector keys; std::vector values; - for (const auto& kv : key_values_) { + for (const auto &kv : key_values_) { keys.push_back(kv.first); values.push_back(kv.second); } buffer << "(Dict: " << " keys:("; - for (const auto& key : keys) { + for (const auto &key : keys) { buffer << key << ", "; } buffer << ") values:("; - for (const auto& value : values) { + for (const auto &value : values) { MS_EXCEPTION_IF_NULL(value); buffer << value->DumpText() << ", "; } @@ -195,14 +195,14 @@ using ValueDictionaryPtr = std::shared_ptr; class StringImm : public Value { public: - explicit StringImm(const std::string& str) : Value(kString), str_(str), hash_(std::hash{}(str_)) {} + explicit StringImm(const std::string &str) : Value(kString), str_(str), hash_(std::hash{}(str_)) {} ~StringImm() override = default; MS_DECLARE_PARENT(StringImm, Value) std::size_t hash() const override { return hash_; } - const std::string& value() const { return str_; } - bool operator==(const Value& other) const override; - bool operator==(const StringImm& other) const; + const std::string &value() const { return str_; } + bool operator==(const Value &other) const override; + bool operator==(const StringImm &other) const; abstract::AbstractBasePtr ToAbstract() override; std::string ToString() const override { return str_; } @@ -218,18 +218,18 @@ class StringImm : public Value { }; using StringImmPtr = std::shared_ptr; IMM_TRAITS(StringImmPtr, std::string) -IMM_TRAITS(StringImmPtr, const char*) +IMM_TRAITS(StringImmPtr, const char *) class RefKey : public Value { public: - explicit RefKey(const std::string& tag) : Value(kRefKeyType), tag_(tag), hash_(std::hash{}(tag)) {} + explicit RefKey(const std::string &tag) : Value(kRefKeyType), tag_(tag), hash_(std::hash{}(tag)) {} ~RefKey() override = default; MS_DECLARE_PARENT(RefKey, Value) std::size_t hash() const override { return hash_; } - const std::string& tag() const { return tag_; } - bool operator==(const Value& other) const override; - bool operator==(const RefKey& other) const; + const std::string &tag() const { return tag_; } + bool operator==(const Value &other) const override; + bool operator==(const RefKey &other) const; abstract::AbstractBasePtr ToAbstract() override; std::string ToString() const override { return "RefKey[" + tag_ + "]"; } @@ -251,13 +251,13 @@ class AnyValue : public Value { ~AnyValue() override = default; MS_DECLARE_PARENT(AnyValue, Value) std::size_t hash() const override { return tid(); } - bool operator==(const Value& other) const override; + bool operator==(const Value &other) const override; abstract::AbstractBasePtr ToAbstract() override; }; extern const ValuePtr kAnyValue; template <> -inline const char* GetValue(const ValuePtr& value) { +inline const char *GetValue(const ValuePtr &value) { if (value == nullptr) { MS_LOG(EXCEPTION) << "Value is nullptr"; } @@ -270,7 +270,7 @@ inline const char* GetValue(const ValuePtr& value) { template ::type, typename U = typename std::enable_if::value, typename S::value_type>::type> -std::vector GetValue(const ValuePtr& value) { +std::vector GetValue(const ValuePtr &value) { if (value == nullptr) { MS_LOG(EXCEPTION) << "Value is nullptr"; } @@ -280,21 +280,21 @@ std::vector GetValue(const ValuePtr& value) { << ">"; } std::vector rets; - const std::vector& vals = value->cast()->value(); + const std::vector &vals = value->cast()->value(); (void)std::transform(vals.begin(), vals.end(), std::back_inserter(rets), - [](const ValuePtr& v) { return GetValue(v); }); + [](const ValuePtr &v) { return GetValue(v); }); return rets; } -inline ValueNodePtr NewValueNode(const ValuePtr& t) { return std::make_shared(t); } +inline ValueNodePtr NewValueNode(const ValuePtr &t) { return std::make_shared(t); } template ::value>::type> -inline ValueNodePtr NewValueNode(const std::shared_ptr& x) { +inline ValueNodePtr NewValueNode(const std::shared_ptr &x) { return NewValueNode(MakeValue(x)); } template ::value>::type> -inline ValueNodePtr NewValueNode(const T& x) { +inline ValueNodePtr NewValueNode(const T &x) { return NewValueNode(MakeValue(x)); } } // namespace mindspore diff --git a/mindspore/ccsrc/ir/visitor.h b/mindspore/ccsrc/ir/visitor.h index 5305d1fe85..e771f7ad28 100644 --- a/mindspore/ccsrc/ir/visitor.h +++ b/mindspore/ccsrc/ir/visitor.h @@ -22,15 +22,15 @@ #include "optimizer/opt.h" namespace mindspore { -using VisitFuncType = std::function; +using VisitFuncType = std::function; class AnfVisitor { public: - virtual AnfNodePtr operator()(const opt::OptimizerPtr&, const AnfNodePtr&); - virtual void Visit(const AnfNodePtr&); - virtual void Visit(const CNodePtr&); - virtual void Visit(const ValueNodePtr&); - virtual void Visit(const ParameterPtr&); - VisitFuncType Match(const PrimitivePtr&, const std::vector& = {}); + virtual AnfNodePtr operator()(const opt::OptimizerPtr &, const AnfNodePtr &); + virtual void Visit(const AnfNodePtr &); + virtual void Visit(const CNodePtr &); + virtual void Visit(const ValueNodePtr &); + virtual void Visit(const ParameterPtr &); + VisitFuncType Match(const PrimitivePtr &, const std::vector & = {}); virtual ~AnfVisitor() = default; }; } // namespace mindspore diff --git a/mindspore/ccsrc/kernel/kernel_query.cc b/mindspore/ccsrc/kernel/kernel_query.cc index 7934bd0a5c..3d3282e7b5 100755 --- a/mindspore/ccsrc/kernel/kernel_query.cc +++ b/mindspore/ccsrc/kernel/kernel_query.cc @@ -26,12 +26,12 @@ namespace mindspore { namespace kernel { namespace { -void FilterInvaildKernelInfo(const CNodePtr& kernel_node, - std::vector>* kernel_info_list) { +void FilterInvaildKernelInfo(const CNodePtr &kernel_node, + std::vector> *kernel_info_list) { MS_EXCEPTION_IF_NULL(kernel_info_list); std::vector> filtered_list; (void)std::copy_if(kernel_info_list->begin(), kernel_info_list->end(), std::back_inserter(filtered_list), - [&](const std::shared_ptr& kernel_build_info) { + [&](const std::shared_ptr &kernel_build_info) { return AnfAlgo::GetOutputTensorNum(kernel_node) == kernel_build_info->GetOutputNum() && AnfAlgo::GetInputTensorNum(kernel_node) == kernel_build_info->GetInputNum(); }); @@ -46,7 +46,7 @@ void FilterInvaildKernelInfo(const CNodePtr& kernel_node, } } } // namespace -void KernelQuery(const CNodePtr& kernel_node, std::vector>* kernel_info_list) { +void KernelQuery(const CNodePtr &kernel_node, std::vector> *kernel_info_list) { MS_EXCEPTION_IF_NULL(kernel_node); MS_EXCEPTION_IF_NULL(kernel_info_list); TbeMetadataInfo(kernel_node, kernel_info_list); diff --git a/mindspore/ccsrc/kernel/oplib/opinfo.h b/mindspore/ccsrc/kernel/oplib/opinfo.h index 215df21776..670830a8b1 100644 --- a/mindspore/ccsrc/kernel/oplib/opinfo.h +++ b/mindspore/ccsrc/kernel/oplib/opinfo.h @@ -38,11 +38,11 @@ class OpAttr { std::string value() const { return value_; } std::string default_value() const { return default_value_; } - void set_name(const std::string& name) { name_ = name; } - void set_param_type(const std::string& param_type) { param_type_ = param_type; } - void set_type(const std::string& type) { type_ = type; } - void set_value(const std::string& value) { value_ = value; } - void set_default_value(const std::string& default_value) { default_value_ = default_value; } + void set_name(const std::string &name) { name_ = name; } + void set_param_type(const std::string ¶m_type) { param_type_ = param_type; } + void set_type(const std::string &type) { type_ = type; } + void set_value(const std::string &value) { value_ = value; } + void set_default_value(const std::string &default_value) { default_value_ = default_value; } private: std::string name_; @@ -67,13 +67,13 @@ class OpIOInfo { std::vector formats() const { return formats_; } void set_index(const int index) { index_ = index; } - void set_name(const std::string& name) { name_ = name; } + void set_name(const std::string &name) { name_ = name; } void set_need_compile(const bool need_compile) { need_compile_ = need_compile; } - void set_param_type(const std::string& param_type) { param_type_ = param_type; } - void set_reshape_type(const std::string& reshape_type) { reshape_type_ = reshape_type; } - void set_shape(const std::string& shape) { shape_ = shape; } - void set_dtypes(const std::vector& dtype) { dtypes_ = dtype; } - void set_formats(const std::vector& formats) { formats_ = formats; } + void set_param_type(const std::string ¶m_type) { param_type_ = param_type; } + void set_reshape_type(const std::string &reshape_type) { reshape_type_ = reshape_type; } + void set_shape(const std::string &shape) { shape_ = shape; } + void set_dtypes(const std::vector &dtype) { dtypes_ = dtype; } + void set_formats(const std::vector &formats) { formats_ = formats; } private: int index_ = 0; @@ -104,24 +104,24 @@ class OpInfo { std::vector> attrs_ptr() const { return attrs_ptr_; } std::vector> inputs_ptr() const { return inputs_ptr_; } std::vector> outputs_ptr() const { return outputs_ptr_; } - const std::unordered_map& ref_infos() const { return ref_infos_; } + const std::unordered_map &ref_infos() const { return ref_infos_; } - void set_op_name(const std::string& op_name) { op_name_ = op_name; } + void set_op_name(const std::string &op_name) { op_name_ = op_name; } void set_imply_type(const OpImplyType imply_type) { imply_type_ = imply_type; } - void set_impl_path(const std::string& impl_path) { impl_path_ = impl_path; } - void set_fusion_type(const std::string& fusion_type) { fusion_type_ = fusion_type; } + void set_impl_path(const std::string &impl_path) { impl_path_ = impl_path; } + void set_fusion_type(const std::string &fusion_type) { fusion_type_ = fusion_type; } void set_async_flag(const bool async_flag) { async_flag_ = async_flag; } - void set_binfile_name(const std::string& binfile_name) { binfile_name_ = binfile_name; } + void set_binfile_name(const std::string &binfile_name) { binfile_name_ = binfile_name; } void set_compute_cost(const int compute_cost) { compute_cost_ = compute_cost; } - void set_kernel_name(const std::string& kernel_name) { kernel_name_ = kernel_name; } + void set_kernel_name(const std::string &kernel_name) { kernel_name_ = kernel_name; } void set_partial_flag(const bool partial_flag) { partial_flag_ = partial_flag; } void set_dynamic_format(const bool dynamic_format) { dynamic_format_ = dynamic_format; } void set_op_pattern(const std::string op_pattern) { op_pattern_ = op_pattern; } - void add_attrs_ptr(const std::shared_ptr& attr) { attrs_ptr_.push_back(attr); } - void add_inputs_ptr(const std::shared_ptr& input) { inputs_ptr_.push_back(input); } - void add_outputs_ptr(const std::shared_ptr& output) { outputs_ptr_.push_back(output); } - void set_inputs_ptr(const std::vector>& inputs) { inputs_ptr_ = inputs; } - void set_outputs_ptr(const std::vector>& outputs) { outputs_ptr_ = outputs; } + void add_attrs_ptr(const std::shared_ptr &attr) { attrs_ptr_.push_back(attr); } + void add_inputs_ptr(const std::shared_ptr &input) { inputs_ptr_.push_back(input); } + void add_outputs_ptr(const std::shared_ptr &output) { outputs_ptr_.push_back(output); } + void set_inputs_ptr(const std::vector> &inputs) { inputs_ptr_ = inputs; } + void set_outputs_ptr(const std::vector> &outputs) { outputs_ptr_ = outputs; } bool is_ref() const { return !ref_infos_.empty(); } bool has_ref_index(size_t out_index) const { return ref_infos_.find(out_index) != ref_infos_.end(); } void add_ref_pair(size_t out_index, size_t in_index) { (void)ref_infos_.emplace(out_index, in_index); } diff --git a/mindspore/ccsrc/kernel/oplib/oplib.cc b/mindspore/ccsrc/kernel/oplib/oplib.cc index c8cc1530ce..cd0f843867 100644 --- a/mindspore/ccsrc/kernel/oplib/oplib.cc +++ b/mindspore/ccsrc/kernel/oplib/oplib.cc @@ -67,7 +67,7 @@ std::string ImplTypeToStr(OpImplyType impl_type) { return "unknow"; } } -bool OpLib::RegOp(const std::string& json_string, const std::string& impl_path) { +bool OpLib::RegOp(const std::string &json_string, const std::string &impl_path) { bool ret = false; try { auto op_json = nlohmann::json::parse(json_string); @@ -88,13 +88,13 @@ bool OpLib::RegOp(const std::string& json_string, const std::string& impl_path) if (!ret) { MS_LOG(DEBUG) << "RegOp failed: opname:" << op_name << "imply_type" << imply_type_string; } - } catch (const std::exception& e) { + } catch (const std::exception &e) { MS_LOG(DEBUG) << "get op_json elements failed:" << e.what(); } return ret; } -void OpLib::DecodeTBESpecificInfo(const nlohmann::json& obj, const std::shared_ptr& op_info) { +void OpLib::DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_ptr &op_info) { op_info->set_async_flag(obj.at(kAsyncFlag)); op_info->set_binfile_name(obj.at(kBinfileName)); op_info->set_compute_cost(obj.at(kComputeCost)); @@ -108,8 +108,8 @@ void OpLib::DecodeTBESpecificInfo(const nlohmann::json& obj, const std::shared_p } } -bool OpLib::DecodeOpInfo(const nlohmann::json& obj, const mindspore::kernel::OpImplyType imply_type, - const std::string& impl_path) { +bool OpLib::DecodeOpInfo(const nlohmann::json &obj, const mindspore::kernel::OpImplyType imply_type, + const std::string &impl_path) { std::shared_ptr op_info = std::make_shared(); MS_EXCEPTION_IF_NULL(op_info); op_info->set_op_name(obj.at(kOpName)); @@ -120,7 +120,7 @@ bool OpLib::DecodeOpInfo(const nlohmann::json& obj, const mindspore::kernel::OpI DecodeTBESpecificInfo(obj, op_info); } auto attrs = obj.at(kAttr); - for (const auto& attr : attrs) { + for (const auto &attr : attrs) { if (!DecodeAttr(attr, imply_type, op_info)) { MS_LOG(DEBUG) << "DecodeAttr Failed"; return false; @@ -131,14 +131,14 @@ bool OpLib::DecodeOpInfo(const nlohmann::json& obj, const mindspore::kernel::OpI dtype_format = obj.at(kDtypeFormat); } auto inputs = obj.at(kIputs); - for (const auto& input : inputs) { + for (const auto &input : inputs) { if (!DecodeInputOutput(input, imply_type, kInput, op_info, dtype_format)) { MS_LOG(DEBUG) << "DecodeInputOutput Failed"; return false; } } auto outputs = obj.at(kOutputs); - for (const auto& output : outputs) { + for (const auto &output : outputs) { if (!DecodeInputOutput(output, imply_type, kOutput, op_info, dtype_format)) { MS_LOG(DEBUG) << "DecodeInputOutput Failed"; return false; @@ -156,8 +156,8 @@ bool OpLib::DecodeOpInfo(const nlohmann::json& obj, const mindspore::kernel::OpI return true; } -bool OpLib::DecodeAttr(const nlohmann::json& obj, const OpImplyType imply_type, - const std::shared_ptr& op_info) { +bool OpLib::DecodeAttr(const nlohmann::json &obj, const OpImplyType imply_type, + const std::shared_ptr &op_info) { MS_EXCEPTION_IF_NULL(op_info); bool ret = true; try { @@ -175,34 +175,34 @@ bool OpLib::DecodeAttr(const nlohmann::json& obj, const OpImplyType imply_type, op_attr->set_default_value(obj.at(kDefaultValue)); } op_info->add_attrs_ptr(op_attr); - } catch (const std::exception& e) { + } catch (const std::exception &e) { MS_LOG(DEBUG) << "DecodeAttr failed:" << e.what(); ret = false; } return ret; } -bool OpLib::DecodeDtypeFormat(const nlohmann::json& dtype_format, const std::shared_ptr& op_io, +bool OpLib::DecodeDtypeFormat(const nlohmann::json &dtype_format, const std::shared_ptr &op_io, size_t index) { bool ret = true; try { std::vector dtype; std::vector format; - for (const auto& it : dtype_format) { + for (const auto &it : dtype_format) { dtype.emplace_back(it[index][0]); format.emplace_back(it[index][1]); } op_io->set_dtypes(dtype); op_io->set_formats(format); - } catch (const std::exception& e) { + } catch (const std::exception &e) { MS_LOG(ERROR) << "DecodeDtypeFormat falied" << e.what(); ret = false; } return ret; } -bool OpLib::DecodeInputOutput(const nlohmann::json& obj, const OpImplyType imply_type, const OpIOType io_type, - const std::shared_ptr& op_info, const nlohmann::json& dtype_format) { +bool OpLib::DecodeInputOutput(const nlohmann::json &obj, const OpImplyType imply_type, const OpIOType io_type, + const std::shared_ptr &op_info, const nlohmann::json &dtype_format) { bool ret = true; try { std::shared_ptr op_io = std::make_shared(); @@ -243,14 +243,14 @@ bool OpLib::DecodeInputOutput(const nlohmann::json& obj, const OpImplyType imply } else if (io_type == kOutput) { op_info->add_outputs_ptr(op_io); } - } catch (const std::exception& e) { + } catch (const std::exception &e) { MS_LOG(DEBUG) << "DecodeInputOutput failed" << e.what(); ret = false; } return ret; } -std::shared_ptr OpLib::FindOp(const std::string& op_name, OpImplyType imply_type) { +std::shared_ptr OpLib::FindOp(const std::string &op_name, OpImplyType imply_type) { auto context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context); bool is_gpu = (context->device_target() == kGPUDevice); @@ -260,7 +260,7 @@ std::shared_ptr OpLib::FindOp(const std::string& op_name, OpImplyType im << ", current op num:" << op_info_.size(); return nullptr; } - for (const auto& op_info : op_info_) { + for (const auto &op_info : op_info_) { MS_EXCEPTION_IF_NULL(op_info); if (op_info->op_name() == op_name && op_info->imply_type() == imply_type) { return op_info; @@ -271,14 +271,14 @@ std::shared_ptr OpLib::FindOp(const std::string& op_name, OpImplyType im return nullptr; } -bool OpLib::GetRefInfo(const std::shared_ptr& op_info) { +bool OpLib::GetRefInfo(const std::shared_ptr &op_info) { MS_EXCEPTION_IF_NULL(op_info); - const auto& output_infos = op_info->outputs_ptr(); - const auto& input_infos = op_info->inputs_ptr(); + const auto &output_infos = op_info->outputs_ptr(); + const auto &input_infos = op_info->inputs_ptr(); for (size_t out_index = 0; out_index < output_infos.size(); out_index++) { - const auto& out_name = output_infos[out_index]->name(); + const auto &out_name = output_infos[out_index]->name(); for (size_t in_index = 0; in_index < input_infos.size(); in_index++) { - const auto& in_name = input_infos[in_index]->name(); + const auto &in_name = input_infos[in_index]->name(); if (out_name == in_name) { if (op_info->has_ref_index(out_index)) { MS_LOG(DEBUG) << "The out_index" << out_index << "is already in ref_info"; @@ -293,9 +293,9 @@ bool OpLib::GetRefInfo(const std::shared_ptr& op_info) { return true; } -bool OpLib::CheckRepetition(const std::shared_ptr& op_info) { +bool OpLib::CheckRepetition(const std::shared_ptr &op_info) { MS_EXCEPTION_IF_NULL(op_info); - for (const auto& exist_op_info : op_info_) { + for (const auto &exist_op_info : op_info_) { MS_EXCEPTION_IF_NULL(exist_op_info); if (exist_op_info->op_name() == op_info->op_name() && exist_op_info->imply_type() == op_info->imply_type() && exist_op_info->impl_path() != op_info->impl_path()) { diff --git a/mindspore/ccsrc/kernel/oplib/oplib.h b/mindspore/ccsrc/kernel/oplib/oplib.h index 0e11e28d58..3d4dcad908 100644 --- a/mindspore/ccsrc/kernel/oplib/oplib.h +++ b/mindspore/ccsrc/kernel/oplib/oplib.h @@ -28,23 +28,23 @@ class OpLib { public: OpLib() = default; virtual ~OpLib() = default; - bool RegOp(const std::string& json_string, const std::string& impl_path); - static std::shared_ptr FindOp(const std::string& op_name, OpImplyType imply_type); + bool RegOp(const std::string &json_string, const std::string &impl_path); + static std::shared_ptr FindOp(const std::string &op_name, OpImplyType imply_type); protected: static std::vector> op_info_; private: - static bool DecodeOpInfo(const nlohmann::json& obj, const OpImplyType imply_type, const std::string& impl_path); - static bool DecodeAttr(const nlohmann::json& obj, const OpImplyType imply_type, - const std::shared_ptr& op_info); - static bool DecodeDtypeFormat(const nlohmann::json& dtype_format, const std::shared_ptr& op_io, + static bool DecodeOpInfo(const nlohmann::json &obj, const OpImplyType imply_type, const std::string &impl_path); + static bool DecodeAttr(const nlohmann::json &obj, const OpImplyType imply_type, + const std::shared_ptr &op_info); + static bool DecodeDtypeFormat(const nlohmann::json &dtype_format, const std::shared_ptr &op_io, size_t index); - static void DecodeTBESpecificInfo(const nlohmann::json& obj, const std::shared_ptr& op_info); - static bool DecodeInputOutput(const nlohmann::json& obj, const OpImplyType imply_type, const OpIOType io_type, - const std::shared_ptr& op_info, const nlohmann::json& dtype_format); - static bool GetRefInfo(const std::shared_ptr& op_info); - static bool CheckRepetition(const std::shared_ptr& op_info); + static void DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_ptr &op_info); + static bool DecodeInputOutput(const nlohmann::json &obj, const OpImplyType imply_type, const OpIOType io_type, + const std::shared_ptr &op_info, const nlohmann::json &dtype_format); + static bool GetRefInfo(const std::shared_ptr &op_info); + static bool CheckRepetition(const std::shared_ptr &op_info); }; } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/mindspore.cc b/mindspore/ccsrc/mindspore.cc index 542814016f..c98f67b51e 100644 --- a/mindspore/ccsrc/mindspore.cc +++ b/mindspore/ccsrc/mindspore.cc @@ -19,6 +19,6 @@ namespace mindspore { // cppcheck-suppress unusedFunction -std::string set_version(const std::string& version) { return version; } +std::string set_version(const std::string &version) { return version; } } // namespace mindspore diff --git a/mindspore/ccsrc/onnx/onnx_exporter.cc b/mindspore/ccsrc/onnx/onnx_exporter.cc index 80661a4539..772986d714 100644 --- a/mindspore/ccsrc/onnx/onnx_exporter.cc +++ b/mindspore/ccsrc/onnx/onnx_exporter.cc @@ -42,11 +42,11 @@ struct OpMergedInfo { }; using GenAttrFuncType = - std::function; + std::function; template -void SetAttrValueToProto(const ValuePtr& value, onnx::AttributeProto_AttributeType attr_type, - onnx::AttributeProto* const attr_proto, const PrimitivePtr&) { +void SetAttrValueToProto(const ValuePtr &value, onnx::AttributeProto_AttributeType attr_type, + onnx::AttributeProto *const attr_proto, const PrimitivePtr &) { auto casted_value = dyn_cast(value); if (casted_value == nullptr) { MS_LOG(EXCEPTION) << "Cast value " << value->ToString() << " to type T failed."; @@ -76,8 +76,8 @@ void SetAttrValueToProto(const ValuePtr& value, onnx::AttributeProto_AttributeTy } template -void SetAttrTupleValueToProto(const ValuePtr& value, onnx::AttributeProto_AttributeType attr_type, - onnx::AttributeProto* const attr_proto, const PrimitivePtr&) { +void SetAttrTupleValueToProto(const ValuePtr &value, onnx::AttributeProto_AttributeType attr_type, + onnx::AttributeProto *const attr_proto, const PrimitivePtr &) { auto tuple_ptr = dyn_cast(value); if (tuple_ptr == nullptr) { MS_LOG(EXCEPTION) << "Cast value from type " << value->type_name() << " to ValueTuple failed."; @@ -99,8 +99,8 @@ void SetAttrTupleValueToProto(const ValuePtr& value, onnx::AttributeProto_Attrib attr_proto->set_type(attr_type); } -void SetPoolingPadMode(const ValuePtr& value, onnx::AttributeProto_AttributeType, - onnx::AttributeProto* const attr_proto, const PrimitivePtr&) { +void SetPoolingPadMode(const ValuePtr &value, onnx::AttributeProto_AttributeType, + onnx::AttributeProto *const attr_proto, const PrimitivePtr &) { attr_proto->set_type(onnx::AttributeProto_AttributeType_STRING); auto attr_value = GetValue(value); if (attr_value == "VALID") { @@ -112,16 +112,16 @@ void SetPoolingPadMode(const ValuePtr& value, onnx::AttributeProto_AttributeType class OpAttrInfo { public: - OpAttrInfo(const std::string& attr_name, const string& onnx_attr_name, - onnx::AttributeProto_AttributeType onnx_attr_type, const GenAttrFuncType& fn_gen_attr) + OpAttrInfo(const std::string &attr_name, const string &onnx_attr_name, + onnx::AttributeProto_AttributeType onnx_attr_type, const GenAttrFuncType &fn_gen_attr) : attr_name_(attr_name), onnx_attr_name_(onnx_attr_name), onnx_attr_type_(onnx_attr_type), fn_gen_attr_(fn_gen_attr) {} ~OpAttrInfo() {} - const std::string& attr_name() const { return attr_name_; } - const std::string& onnx_attr_name() const { return onnx_attr_name_; } + const std::string &attr_name() const { return attr_name_; } + const std::string &onnx_attr_name() const { return onnx_attr_name_; } onnx::AttributeProto_AttributeType onnx_attr_type() const { return onnx_attr_type_; } GenAttrFuncType fn_gen_attr() const { return fn_gen_attr_; } @@ -134,27 +134,27 @@ class OpAttrInfo { class OpNameInfo { public: - OpNameInfo& set_op_type(const std::string& op_type) { + OpNameInfo &set_op_type(const std::string &op_type) { op_type_ = op_type; return *this; } - const std::string& op_type() const { return op_type_; } + const std::string &op_type() const { return op_type_; } - OpNameInfo& set_onnx_type(const std::string& onnx_type) { + OpNameInfo &set_onnx_type(const std::string &onnx_type) { onnx_type_ = onnx_type; return *this; } - const std::string& onnx_type() const { return onnx_type_; } + const std::string &onnx_type() const { return onnx_type_; } - OpNameInfo& Attr(const std::string& attr_name, const std::string& onnx_attr_name, - onnx::AttributeProto_AttributeType onnx_attr_type, const GenAttrFuncType& fn_gen_attr) { + OpNameInfo &Attr(const std::string &attr_name, const std::string &onnx_attr_name, + onnx::AttributeProto_AttributeType onnx_attr_type, const GenAttrFuncType &fn_gen_attr) { op_attrs_.emplace_back(OpAttrInfo(attr_name, onnx_attr_name, onnx_attr_type, fn_gen_attr)); return *this; } - const std::vector& op_attrs() const { return op_attrs_; } + const std::vector &op_attrs() const { return op_attrs_; } private: std::string op_type_; // operator type of MindSpore @@ -183,8 +183,8 @@ OPERATOR_ONNX_CONVERT_DEFINE( .Attr("group", "group", onnx::AttributeProto_AttributeType_INT, SetAttrValueToProto) .Attr("kernel_size", "kernel_shape", onnx::AttributeProto_AttributeType_INTS, SetAttrTupleValueToProto<0>) .Attr("pad_mode", "auto_pad", onnx::AttributeProto_AttributeType_STRING, - [](ValuePtr value, onnx::AttributeProto_AttributeType, onnx::AttributeProto* const attr_proto, - const PrimitivePtr& prim) { + [](ValuePtr value, onnx::AttributeProto_AttributeType, onnx::AttributeProto *const attr_proto, + const PrimitivePtr &prim) { attr_proto->set_type(onnx::AttributeProto_AttributeType_STRING); auto attr_value = GetValue(value); if (attr_value == "valid") { @@ -220,7 +220,7 @@ OPERATOR_ONNX_CONVERT_DEFINE(Argmax, ArgMax, SetAttrValueToProto) .Attr("", "keepdims", onnx::AttributeProto_AttributeType_INT, [](ValuePtr, onnx::AttributeProto_AttributeType, - onnx::AttributeProto* const attr_proto, const PrimitivePtr&) { + onnx::AttributeProto *const attr_proto, const PrimitivePtr &) { attr_proto->set_type(onnx::AttributeProto_AttributeType_INT); attr_proto->set_i(0); })) @@ -242,7 +242,7 @@ OPERATOR_ONNX_CONVERT_DEFINE( #define OP_CONVERT_FUNCTION_NAME(name) GetOpOnnxConvertInfo_##name -void RegisterOpConverters(const std::function& fn) { +void RegisterOpConverters(const std::function &fn) { fn(OP_CONVERT_FUNCTION_NAME(TensorAdd)()); fn(OP_CONVERT_FUNCTION_NAME(Mul)()); @@ -265,16 +265,16 @@ class OpConvertRegistry { public: ~OpConvertRegistry() { Clear(); } - static void RegisterOneOpConverter(OpNameInfo&& op_info) { GetSingleton().op_map_[op_info.op_type()] = op_info; } + static void RegisterOneOpConverter(OpNameInfo &&op_info) { GetSingleton().op_map_[op_info.op_type()] = op_info; } static void RegisterAllOpConverters() { RegisterOpConverters(RegisterOneOpConverter); } - static OpConvertRegistry& GetSingleton() { + static OpConvertRegistry &GetSingleton() { static OpConvertRegistry registry = OpConvertRegistry(); return registry; } - static const std::unordered_map& GetOpConvertMap() { return GetSingleton().op_map_; } + static const std::unordered_map &GetOpConvertMap() { return GetSingleton().op_map_; } void Clear() noexcept { op_map_.clear(); } @@ -289,59 +289,59 @@ class OnnxExporter { OnnxExporter() {} ~OnnxExporter() {} - std::string GetOnnxProtoString(const FuncGraphPtr& func_graph); + std::string GetOnnxProtoString(const FuncGraphPtr &func_graph); private: void InitModelInfo(); - void ExportFuncGraph(const FuncGraphPtr& func_graph, onnx::GraphProto* graph_proto); - void ExportParameters(const FuncGraphPtr& func_graph, onnx::GraphProto* graph_proto); + void ExportFuncGraph(const FuncGraphPtr &func_graph, onnx::GraphProto *graph_proto); + void ExportParameters(const FuncGraphPtr &func_graph, onnx::GraphProto *graph_proto); - size_t ExportPrimitive(const FuncGraphPtr& func_graph, std::map* node_map_ptr, - const PrimitivePtr& prim, const std::vector& inputs, - onnx::GraphProto* graph_proto); + size_t ExportPrimitive(const FuncGraphPtr &func_graph, std::map *node_map_ptr, + const PrimitivePtr &prim, const std::vector &inputs, + onnx::GraphProto *graph_proto); static onnx::TensorProto_DataType GetOnnxDataType(TypeId type_id); - void SetValueInfoType(const AnfNodePtr& node, onnx::ValueInfoProto* value_proto, bool is_output = false); - void SetTensorProtoInfo(const ParameterPtr& param, onnx::TensorProto* tensor_proto); - - void MatchAndMark(const FuncGraphPtr& func_graph, const std::vector& nodes, - std::unordered_map* op_merged_infos_ptr); - void ExportNodes(const FuncGraphPtr& func_graph, std::map* node_map_ptr, - onnx::GraphProto* graph_proto); - - void ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr& node, std::map* node_map_ptr, - onnx::GraphProto* graph_proto); - - void ExportPrimReshape(const FuncGraphPtr& func_graph, const CNodePtr& node, - std::map* node_map_ptr, onnx::GraphProto* graph_proto); - void ExportPrimReduceMean(const FuncGraphPtr& func_graph, const CNodePtr& node, - std::map* node_map_ptr, onnx::GraphProto* graph_proto); - void ExportPrimCast(const FuncGraphPtr& func_graph, const CNodePtr& node, std::map* node_map_ptr, - onnx::GraphProto* graph_proto); - void ExportPrimPReLU(const FuncGraphPtr& func_graph, const CNodePtr& node, std::map* node_map_ptr, - onnx::GraphProto* graph_proto); - - void ExportMergeConv(const FuncGraphPtr& func_graph, const CNodePtr& node, std::map* node_map_ptr, - onnx::GraphProto* graph_proto); - void ExportMergeGemm(const FuncGraphPtr& func_graph, const CNodePtr& node, std::map* node_map_ptr, - onnx::GraphProto* graph_proto); - void ExportMergeBatchNorm(const FuncGraphPtr& func_graph, const CNodePtr& node, - std::map* node_map_ptr, onnx::GraphProto* graph_proto); - - void ExportOutput(const FuncGraphPtr& func_graph, const CNodePtr& node, std::map* node_map_ptr, - onnx::GraphProto* graph_proto); - std::string GetNodeInputName(const AnfNodePtr& node, std::map* node_map_ptr, - onnx::GraphProto* const graph_proto); - - void ConvertTupleToTensor(const ValuePtr& value, onnx::TensorProto* tensor_proto); - void SetNodeAttribute(const ValuePtr& value, onnx::NodeProto* node_proto); + void SetValueInfoType(const AnfNodePtr &node, onnx::ValueInfoProto *value_proto, bool is_output = false); + void SetTensorProtoInfo(const ParameterPtr ¶m, onnx::TensorProto *tensor_proto); + + void MatchAndMark(const FuncGraphPtr &func_graph, const std::vector &nodes, + std::unordered_map *op_merged_infos_ptr); + void ExportNodes(const FuncGraphPtr &func_graph, std::map *node_map_ptr, + onnx::GraphProto *graph_proto); + + void ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, + onnx::GraphProto *graph_proto); + + void ExportPrimReshape(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *graph_proto); + void ExportPrimReduceMean(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *graph_proto); + void ExportPrimCast(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, + onnx::GraphProto *graph_proto); + void ExportPrimPReLU(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, + onnx::GraphProto *graph_proto); + + void ExportMergeConv(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, + onnx::GraphProto *graph_proto); + void ExportMergeGemm(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, + onnx::GraphProto *graph_proto); + void ExportMergeBatchNorm(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *graph_proto); + + void ExportOutput(const FuncGraphPtr &func_graph, const CNodePtr &node, std::map *node_map_ptr, + onnx::GraphProto *graph_proto); + std::string GetNodeInputName(const AnfNodePtr &node, std::map *node_map_ptr, + onnx::GraphProto *const graph_proto); + + void ConvertTupleToTensor(const ValuePtr &value, onnx::TensorProto *tensor_proto); + void SetNodeAttribute(const ValuePtr &value, onnx::NodeProto *node_proto); size_t AllocateNodeIndex() { return ++onnx_node_index_; } void ResetNodeIndex() { onnx_node_index_ = 0; } - static int GetInt32Value(const AnfNodePtr& node) { + static int GetInt32Value(const AnfNodePtr &node) { auto value_node_ptr = dyn_cast(node); MS_EXCEPTION_IF_NULL(value_node_ptr); return GetValue(value_node_ptr->value()); @@ -352,7 +352,7 @@ class OnnxExporter { size_t onnx_node_index_ = 0; }; -std::string OnnxExporter::GetOnnxProtoString(const FuncGraphPtr& func_graph) { +std::string OnnxExporter::GetOnnxProtoString(const FuncGraphPtr &func_graph) { if (func_graph == nullptr) { return ""; } @@ -360,7 +360,7 @@ std::string OnnxExporter::GetOnnxProtoString(const FuncGraphPtr& func_graph) { OpConvertRegistry::GetSingleton().Clear(); OpConvertRegistry::RegisterAllOpConverters(); InitModelInfo(); - onnx::GraphProto* graph_proto = model_.mutable_graph(); + onnx::GraphProto *graph_proto = model_.mutable_graph(); ExportFuncGraph(func_graph, graph_proto); return model_.SerializeAsString(); } @@ -369,11 +369,11 @@ void OnnxExporter::InitModelInfo() { model_.set_ir_version(onnx::IR_VERSION_2019_1_22); model_.set_producer_name("MindSpore"); model_.set_producer_version("1.0"); - onnx::OperatorSetIdProto* opset_proto = model_.add_opset_import(); + onnx::OperatorSetIdProto *opset_proto = model_.add_opset_import(); opset_proto->set_version(9); } -void OnnxExporter::ExportFuncGraph(const FuncGraphPtr& func_graph, onnx::GraphProto* const graph_proto) { +void OnnxExporter::ExportFuncGraph(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) { std::map node_map; onnx_node_index_ = func_graph->parameters().size(); @@ -390,14 +390,14 @@ void OnnxExporter::ExportFuncGraph(const FuncGraphPtr& func_graph, onnx::GraphPr ExportNodes(func_graph, &node_map, graph_proto); } -void OnnxExporter::ExportParameters(const FuncGraphPtr& func_graph, onnx::GraphProto* const graph_proto) { - for (auto& param : func_graph->parameters()) { +void OnnxExporter::ExportParameters(const FuncGraphPtr &func_graph, onnx::GraphProto *const graph_proto) { + for (auto ¶m : func_graph->parameters()) { const ParameterPtr param_ptr = dyn_cast(param); if (param_ptr == nullptr) { MS_LOG(EXCEPTION) << "Parameter '" << param->ToString() << "' could not cast to parameter."; } - onnx::ValueInfoProto* input_proto = graph_proto->add_input(); + onnx::ValueInfoProto *input_proto = graph_proto->add_input(); input_proto->set_name(param_ptr->ToString()); SetValueInfoType(param_ptr, input_proto); @@ -405,7 +405,7 @@ void OnnxExporter::ExportParameters(const FuncGraphPtr& func_graph, onnx::GraphP continue; } // parameter with default value is an ONNX initializer - onnx::TensorProto* initializer_proto = graph_proto->add_initializer(); + onnx::TensorProto *initializer_proto = graph_proto->add_initializer(); initializer_proto->set_name(param_ptr->ToString()); SetTensorProtoInfo(param_ptr, initializer_proto); // set value for initializer @@ -445,25 +445,25 @@ onnx::TensorProto_DataType OnnxExporter::GetOnnxDataType(TypeId type_id) { return iter->second; } -void OnnxExporter::SetValueInfoType(const AnfNodePtr& node, onnx::ValueInfoProto* const value_proto, bool is_output) { +void OnnxExporter::SetValueInfoType(const AnfNodePtr &node, onnx::ValueInfoProto *const value_proto, bool is_output) { auto dtype = node->Type(); auto shape = node->Shape(); - onnx::TypeProto* type_proto = value_proto->mutable_type(); + onnx::TypeProto *type_proto = value_proto->mutable_type(); if (dtype->isa() && shape->isa()) { auto tensor = dyn_cast(dtype); auto elem_type = tensor->element(); - const auto& dims = dyn_cast(shape)->shape(); + const auto &dims = dyn_cast(shape)->shape(); // output type of 'Argmax' of MindSpore is int32, output type of 'ArgMax' of ONNX is int64 auto type = is_output ? onnx::TensorProto_DataType_INT64 : GetOnnxDataType(elem_type->type_id()); type_proto->mutable_tensor_type()->set_elem_type(type); - for (const auto& dim : dims) { + for (const auto &dim : dims) { type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(dim); } } } -void OnnxExporter::SetTensorProtoInfo(const ParameterPtr& param, onnx::TensorProto* const tensor_proto) { +void OnnxExporter::SetTensorProtoInfo(const ParameterPtr ¶m, onnx::TensorProto *const tensor_proto) { auto dtype = param->Type(); auto shape = param->Shape(); if (!dtype->isa() || !shape->isa()) { @@ -472,18 +472,18 @@ void OnnxExporter::SetTensorProtoInfo(const ParameterPtr& param, onnx::TensorPro auto tensor = dyn_cast(dtype); auto elem_type = tensor->element(); - const auto& dims = dyn_cast(shape)->shape(); + const auto &dims = dyn_cast(shape)->shape(); tensor_proto->set_data_type(GetOnnxDataType(elem_type->type_id())); - for (const auto& dim : dims) { + for (const auto &dim : dims) { tensor_proto->add_dims(dim); } } -void OnnxExporter::MatchAndMark(const FuncGraphPtr& func_graph, const std::vector& nodes, - std::unordered_map* op_merged_infos_ptr) { - std::unordered_map& op_merged_infos = *op_merged_infos_ptr; +void OnnxExporter::MatchAndMark(const FuncGraphPtr &func_graph, const std::vector &nodes, + std::unordered_map *op_merged_infos_ptr) { + std::unordered_map &op_merged_infos = *op_merged_infos_ptr; - for (auto& node : nodes) { + for (auto &node : nodes) { if (!node->isa()) { continue; } @@ -492,7 +492,7 @@ void OnnxExporter::MatchAndMark(const FuncGraphPtr& func_graph, const std::vecto // if the key `input` does not exist, just create a new one op_merged_infos[cnode].referred_count += 1; } - for (auto& input : cnode->inputs()) { + for (auto &input : cnode->inputs()) { if (!input->isa()) { continue; } @@ -527,14 +527,14 @@ void OnnxExporter::MatchAndMark(const FuncGraphPtr& func_graph, const std::vecto * | +-- Parameter * | `-- ValueNode */ -void OnnxExporter::ExportNodes(const FuncGraphPtr& func_graph, std::map* node_map_ptr, - onnx::GraphProto* const graph_proto) { +void OnnxExporter::ExportNodes(const FuncGraphPtr &func_graph, std::map *node_map_ptr, + onnx::GraphProto *const graph_proto) { std::vector nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude); std::unordered_map op_merged_infos; MatchAndMark(func_graph, nodes, &op_merged_infos); - for (const AnfNodePtr& node : nodes) { + for (const AnfNodePtr &node : nodes) { if (!node->isa()) { continue; } @@ -570,20 +570,20 @@ void OnnxExporter::ExportNodes(const FuncGraphPtr& func_graph, std::map* node_map_ptr, onnx::GraphProto* const graph_proto) { +void OnnxExporter::ExportPrimReshape(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { auto name_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); auto input_shape = node->input(2); std::string name_shape; if (input_shape->isa()) { auto const_node_idx = AllocateNodeIndex(); (*node_map_ptr)[input_shape] = const_node_idx; - onnx::NodeProto* node_proto = graph_proto->add_node(); + onnx::NodeProto *node_proto = graph_proto->add_node(); name_shape = std::to_string(const_node_idx); node_proto->add_output(name_shape); node_proto->set_op_type("Constant"); - onnx::AttributeProto* attr_proto = node_proto->add_attribute(); + onnx::AttributeProto *attr_proto = node_proto->add_attribute(); attr_proto->set_name("value"); attr_proto->set_type(onnx::AttributeProto_AttributeType_TENSOR); @@ -595,28 +595,28 @@ void OnnxExporter::ExportPrimReshape(const FuncGraphPtr& /*func_graph*/, const C auto node_idx = AllocateNodeIndex(); (*node_map_ptr)[node] = node_idx; - onnx::NodeProto* node_proto = graph_proto->add_node(); + onnx::NodeProto *node_proto = graph_proto->add_node(); node_proto->set_op_type(prim::kPrimReshape->name()); node_proto->add_output(std::to_string(node_idx)); node_proto->add_input(name_x); node_proto->add_input(name_shape); } -void OnnxExporter::ExportPrimReduceMean(const FuncGraphPtr& /*func_graph*/, const CNodePtr& node, - std::map* node_map_ptr, - onnx::GraphProto* const graph_proto) { +void OnnxExporter::ExportPrimReduceMean(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, + std::map *node_map_ptr, + onnx::GraphProto *const graph_proto) { auto input_data = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); auto input_axis = node->input(2); auto node_idx = AllocateNodeIndex(); (*node_map_ptr)[node] = node_idx; - onnx::NodeProto* node_proto = graph_proto->add_node(); + onnx::NodeProto *node_proto = graph_proto->add_node(); node_proto->set_op_type(prim::kPrimReduceMean->name()); node_proto->add_output(std::to_string(node_idx)); node_proto->add_input(input_data); if (input_axis->isa()) { - onnx::AttributeProto* attr_proto = node_proto->add_attribute(); + onnx::AttributeProto *attr_proto = node_proto->add_attribute(); attr_proto->set_name("axes"); attr_proto->set_type(onnx::AttributeProto_AttributeType_INTS); auto axis_value = dyn_cast(input_axis)->value(); @@ -630,20 +630,20 @@ void OnnxExporter::ExportPrimReduceMean(const FuncGraphPtr& /*func_graph*/, cons } } -void OnnxExporter::ExportPrimCast(const FuncGraphPtr& /*func_graph*/, const CNodePtr& node, - std::map* node_map_ptr, onnx::GraphProto* const graph_proto) { +void OnnxExporter::ExportPrimCast(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { auto input_data = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); auto input_type = node->input(2); auto node_idx = AllocateNodeIndex(); (*node_map_ptr)[node] = node_idx; - onnx::NodeProto* node_proto = graph_proto->add_node(); + onnx::NodeProto *node_proto = graph_proto->add_node(); node_proto->set_op_type(prim::kPrimCast->name()); node_proto->add_output(std::to_string(node_idx)); node_proto->add_input(input_data); if (input_type->isa()) { - onnx::AttributeProto* attr_proto = node_proto->add_attribute(); + onnx::AttributeProto *attr_proto = node_proto->add_attribute(); attr_proto->set_name("to"); attr_proto->set_type(onnx::AttributeProto_AttributeType_INT); auto type_value = dyn_cast(input_type)->value(); @@ -655,8 +655,8 @@ void OnnxExporter::ExportPrimCast(const FuncGraphPtr& /*func_graph*/, const CNod } } -void OnnxExporter::ExportPrimPReLU(const FuncGraphPtr& /*func_graph*/, const CNodePtr& node, - std::map* node_map_ptr, onnx::GraphProto* const graph_proto) { +void OnnxExporter::ExportPrimPReLU(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { auto input_x = GetNodeInputName(node->input(1), node_map_ptr, graph_proto); auto input_slope = GetNodeInputName(node->input(2), node_map_ptr, graph_proto); @@ -668,11 +668,11 @@ void OnnxExporter::ExportPrimPReLU(const FuncGraphPtr& /*func_graph*/, const CNo // format of x is NCHW, input format is NCHW, if length of input_slope is 1, insert Unsqueeze [1,2] if (x_shape->shape().size() == 4 && slope_shape->shape().size() == 1) { auto node_idx = AllocateNodeIndex(); - onnx::NodeProto* node_proto = graph_proto->add_node(); + onnx::NodeProto *node_proto = graph_proto->add_node(); node_proto->set_op_type("Unsqueeze"); node_proto->add_output(std::to_string(node_idx)); - onnx::AttributeProto* attr_proto = node_proto->add_attribute(); + onnx::AttributeProto *attr_proto = node_proto->add_attribute(); attr_proto->set_type(onnx::AttributeProto_AttributeType_INTS); attr_proto->set_name("axes"); attr_proto->add_ints(1); @@ -684,15 +684,15 @@ void OnnxExporter::ExportPrimPReLU(const FuncGraphPtr& /*func_graph*/, const CNo auto node_idx = AllocateNodeIndex(); (*node_map_ptr)[node] = node_idx; - onnx::NodeProto* node_proto = graph_proto->add_node(); + onnx::NodeProto *node_proto = graph_proto->add_node(); node_proto->set_op_type("PRelu"); node_proto->add_output(std::to_string(node_idx)); node_proto->add_input(input_x); node_proto->add_input(input_slope); } -void OnnxExporter::ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr& node, - std::map* node_map_ptr, onnx::GraphProto* const graph_proto) { +void OnnxExporter::ExportCNode(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { // Type of the 2nd input of 'Reshape' of MindSpore is tuple, but ONNX's is tensor, need to do some convert if (node->IsApply(prim::kPrimReshape)) { return ExportPrimReshape(func_graph, node, node_map_ptr, graph_proto); @@ -735,31 +735,31 @@ void OnnxExporter::ExportCNode(const FuncGraphPtr& func_graph, const CNodePtr& n (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim, op_inputs, graph_proto); } -size_t OnnxExporter::ExportPrimitive(const FuncGraphPtr& /*func_graph*/, std::map* node_map_ptr, - const PrimitivePtr& prim, const std::vector& inputs, - onnx::GraphProto* const graph_proto) { +size_t OnnxExporter::ExportPrimitive(const FuncGraphPtr & /*func_graph*/, std::map *node_map_ptr, + const PrimitivePtr &prim, const std::vector &inputs, + onnx::GraphProto *const graph_proto) { auto op_map = OpConvertRegistry::GetOpConvertMap(); auto op_iter = op_map.find(prim->name()); if (op_iter == op_map.end()) { MS_LOG(EXCEPTION) << "Can not find key " << prim->name() << " in convert map"; } - const OpNameInfo& op_convert_info = op_iter->second; + const OpNameInfo &op_convert_info = op_iter->second; auto node_idx = AllocateNodeIndex(); - onnx::NodeProto* node_proto = graph_proto->add_node(); + onnx::NodeProto *node_proto = graph_proto->add_node(); node_proto->add_output(std::to_string(node_idx)); node_proto->set_op_type(op_convert_info.onnx_type()); // Set inputs - for (const auto& input : inputs) { + for (const auto &input : inputs) { auto input_name = GetNodeInputName(input, node_map_ptr, graph_proto); node_proto->add_input(input_name); } // Set node attribute - for (const OpAttrInfo& attr : op_convert_info.op_attrs()) { - const std::string& attr_name = attr.attr_name(); + for (const OpAttrInfo &attr : op_convert_info.op_attrs()) { + const std::string &attr_name = attr.attr_name(); ValuePtr attr_value = nullptr; if (!attr_name.empty()) { attr_value = prim->GetAttr(attr_name); @@ -767,15 +767,15 @@ size_t OnnxExporter::ExportPrimitive(const FuncGraphPtr& /*func_graph*/, std::ma MS_LOG(EXCEPTION) << "Primitive " << prim->name() << " does not have attribute " << attr_name; } } - onnx::AttributeProto* onnx_attr_proto = node_proto->add_attribute(); + onnx::AttributeProto *onnx_attr_proto = node_proto->add_attribute(); onnx_attr_proto->set_name(attr.onnx_attr_name()); attr.fn_gen_attr()(attr_value, attr.onnx_attr_type(), onnx_attr_proto, prim); } return node_idx; } -void OnnxExporter::ExportMergeConv(const FuncGraphPtr& func_graph, const CNodePtr& node, - std::map* node_map_ptr, onnx::GraphProto* const graph_proto) { +void OnnxExporter::ExportMergeConv(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { auto conv_node = dyn_cast(node->input(1)); auto input_x = conv_node->input(1); // conv input x auto input_w = conv_node->input(2); // conv weight(filter) @@ -786,8 +786,8 @@ void OnnxExporter::ExportMergeConv(const FuncGraphPtr& func_graph, const CNodePt (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_conv, inputs, graph_proto); } -void OnnxExporter::ExportMergeGemm(const FuncGraphPtr& func_graph, const CNodePtr& node, - std::map* node_map_ptr, onnx::GraphProto* const graph_proto) { +void OnnxExporter::ExportMergeGemm(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { auto matmul_node = dyn_cast(node->input(1)); auto input_x = matmul_node->input(1); // matmul input x auto input_y = matmul_node->input(2); // matmul input y @@ -798,9 +798,9 @@ void OnnxExporter::ExportMergeGemm(const FuncGraphPtr& func_graph, const CNodePt (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_matmul, inputs, graph_proto); } -void OnnxExporter::ExportMergeBatchNorm(const FuncGraphPtr& func_graph, const CNodePtr& node, - std::map* node_map_ptr, - onnx::GraphProto* const graph_proto) { +void OnnxExporter::ExportMergeBatchNorm(const FuncGraphPtr &func_graph, const CNodePtr &node, + std::map *node_map_ptr, + onnx::GraphProto *const graph_proto) { auto batch_norm_node = dyn_cast(node->input(1)); PrimitivePtr prim_batch_norm = dyn_cast((dyn_cast(batch_norm_node->input(0)))->value()); @@ -811,20 +811,20 @@ void OnnxExporter::ExportMergeBatchNorm(const FuncGraphPtr& func_graph, const CN (*node_map_ptr)[node] = ExportPrimitive(func_graph, node_map_ptr, prim_batch_norm, inputs, graph_proto); } -void OnnxExporter::ExportOutput(const FuncGraphPtr& /*func_graph*/, const CNodePtr& node, - std::map* node_map_ptr, onnx::GraphProto* const graph_proto) { +void OnnxExporter::ExportOutput(const FuncGraphPtr & /*func_graph*/, const CNodePtr &node, + std::map *node_map_ptr, onnx::GraphProto *const graph_proto) { if (node->inputs().size() != 2) { MS_LOG(EXCEPTION) << "Number of inputs of return node is not equal to 2."; } AnfNodePtr arg = node->input(1); std::string name = GetNodeInputName(arg, node_map_ptr, graph_proto); - onnx::ValueInfoProto* output_proto = graph_proto->add_output(); + onnx::ValueInfoProto *output_proto = graph_proto->add_output(); output_proto->set_name(name); SetValueInfoType(arg, output_proto, false); } -std::string OnnxExporter::GetNodeInputName(const AnfNodePtr& node, std::map* node_map_ptr, - onnx::GraphProto* const graph_proto) { +std::string OnnxExporter::GetNodeInputName(const AnfNodePtr &node, std::map *node_map_ptr, + onnx::GraphProto *const graph_proto) { if (node->isa()) { auto iter = node_map_ptr->find(node); if (iter == node_map_ptr->end()) { @@ -848,7 +848,7 @@ std::string OnnxExporter::GetNodeInputName(const AnfNodePtr& node, std::mapadd_node(); + onnx::NodeProto *node_proto = graph_proto->add_node(); node_proto->add_output(node_name); SetNodeAttribute(node->cast()->value(), node_proto); @@ -859,7 +859,7 @@ std::string OnnxExporter::GetNodeInputName(const AnfNodePtr& node, std::maptype_name(); } -void OnnxExporter::ConvertTupleToTensor(const ValuePtr& value, onnx::TensorProto* const tensor_proto) { +void OnnxExporter::ConvertTupleToTensor(const ValuePtr &value, onnx::TensorProto *const tensor_proto) { auto tuple_ptr = dyn_cast(value); MS_EXCEPTION_IF_NULL(tuple_ptr); if (tuple_ptr->size() == 0) { @@ -891,14 +891,14 @@ void OnnxExporter::ConvertTupleToTensor(const ValuePtr& value, onnx::TensorProto } } -void OnnxExporter::SetNodeAttribute(const ValuePtr& value, onnx::NodeProto* const node_proto) { +void OnnxExporter::SetNodeAttribute(const ValuePtr &value, onnx::NodeProto *const node_proto) { node_proto->set_op_type("Constant"); - onnx::AttributeProto* attr_proto = node_proto->add_attribute(); + onnx::AttributeProto *attr_proto = node_proto->add_attribute(); attr_proto->set_name("value"); MS_LOG(EXCEPTION) << "Need to set value " << value->ToString() << " attribute for Constant node"; } -std::string GetOnnxProtoString(const FuncGraphPtr& func_graph) { +std::string GetOnnxProtoString(const FuncGraphPtr &func_graph) { OnnxExporter exporter; return exporter.GetOnnxProtoString(func_graph); } diff --git a/mindspore/ccsrc/operator/cc_implementations.cc b/mindspore/ccsrc/operator/cc_implementations.cc index 49dc3ab791..2a3429ca52 100644 --- a/mindspore/ccsrc/operator/cc_implementations.cc +++ b/mindspore/ccsrc/operator/cc_implementations.cc @@ -32,12 +32,12 @@ enum class DataType { kInt, kFloat, kDouble, kUnknown }; // Whether has a T type data in AnyPtrList. template -bool HasType(const AnyPtrList& list) { - bool ret = std::any_of(list.begin(), list.end(), [](const AnyPtr& ptr) { return ptr->is(); }); +bool HasType(const AnyPtrList &list) { + bool ret = std::any_of(list.begin(), list.end(), [](const AnyPtr &ptr) { return ptr->is(); }); return ret; } -DataType InferType(const AnyPtrList& list) { +DataType InferType(const AnyPtrList &list) { if (HasType(list)) { return DataType::kDouble; } else if (HasType(list)) { @@ -180,7 +180,7 @@ bool InnerScalarGe(T x, U y) { } #define SCALAR_OP(op_t) \ - ValuePtr Scalar##op_t(const ValuePtrList& list) { \ + ValuePtr Scalar##op_t(const ValuePtrList &list) { \ do { \ if (list.size() < 2) { \ MS_LOG(EXCEPTION) << "length of input list for Scalar" << #op_t << " is less than 2."; \ @@ -223,7 +223,7 @@ SCALAR_OP(Pow) SCALAR_OP(Floordiv) #define LOGIC_OP(op_t) \ - ValuePtr Scalar##op_t(const ValuePtrList& list) { \ + ValuePtr Scalar##op_t(const ValuePtrList &list) { \ if (list.size() < 2) { \ MS_LOG(EXCEPTION) << "length of input list for Scalar" << #op_t << " is less than 2."; \ } \ @@ -274,7 +274,7 @@ LOGIC_OP(Ne) LOGIC_OP(Le) LOGIC_OP(Ge) -ValuePtr ScalarUAdd(const ValuePtrList& list) { +ValuePtr ScalarUAdd(const ValuePtrList &list) { if (list.size() != 1) { MS_LOG(EXCEPTION) << "Input number of ScalarUAdd should be 1, but got " << list.size(); } @@ -283,7 +283,7 @@ ValuePtr ScalarUAdd(const ValuePtrList& list) { return x; } -ValuePtr ScalarUSub(const ValuePtrList& list) { +ValuePtr ScalarUSub(const ValuePtrList &list) { if (list.size() != 1) { MS_LOG(EXCEPTION) << "Input number of ScalarUSub should be 1, but got " << list.size(); } @@ -302,7 +302,7 @@ ValuePtr ScalarUSub(const ValuePtrList& list) { MS_LOG(EXCEPTION) << "Unsported Value for ScalarUSub, x: " << x->ToString() << "."; } -ValuePtr ScalarLog(const ValuePtrList& list) { +ValuePtr ScalarLog(const ValuePtrList &list) { if (list.empty()) { MS_LOG(EXCEPTION) << "Input list of ScalarLog is empty."; } @@ -321,7 +321,7 @@ ValuePtr ScalarLog(const ValuePtrList& list) { MS_LOG(EXCEPTION) << "Unsported Value for ScalarLog, x: " << x->ToString(); } -ValuePtr BoolNot(const ValuePtrList& list) { +ValuePtr BoolNot(const ValuePtrList &list) { if (list.empty()) { MS_LOG(EXCEPTION) << "value list of BoolNot is empty"; } @@ -337,7 +337,7 @@ ValuePtr BoolNot(const ValuePtrList& list) { MS_LOG(EXCEPTION) << "Unsported Value for BoolNot, x: " << x->ToString(); } -ValuePtr BoolAnd(const ValuePtrList& list) { +ValuePtr BoolAnd(const ValuePtrList &list) { if (list.size() < 2) { MS_LOG(EXCEPTION) << "Input number " << list.size() << " of BoolAnd is less then 2."; } @@ -356,7 +356,7 @@ ValuePtr BoolAnd(const ValuePtrList& list) { MS_LOG(EXCEPTION) << "Unsported Value for BoolAnd, x: " << x->ToString() << "."; } -ValuePtr BoolOr(const ValuePtrList& list) { +ValuePtr BoolOr(const ValuePtrList &list) { if (list.size() < 2) { MS_LOG(EXCEPTION) << "Input number " << list.size() << " of BoolOr is less then 2."; } @@ -375,7 +375,7 @@ ValuePtr BoolOr(const ValuePtrList& list) { MS_LOG(EXCEPTION) << "Unsported Value for BoolOr, x: " << x->ToString() << "."; } -ValuePtr BoolEq(const ValuePtrList& list) { +ValuePtr BoolEq(const ValuePtrList &list) { if (list.size() < 2) { MS_LOG(EXCEPTION) << "Input number " << list.size() << " of BoolEq is less than 2."; } diff --git a/mindspore/ccsrc/operator/cc_implementations.h b/mindspore/ccsrc/operator/cc_implementations.h index 69981cea7d..cef34da4f4 100644 --- a/mindspore/ccsrc/operator/cc_implementations.h +++ b/mindspore/ccsrc/operator/cc_implementations.h @@ -29,29 +29,29 @@ namespace prim { using Any = mindspore::Any; using AnyPtrList = std::vector>; using ValuePtrList = std::vector; -using OpsFunction = std::function; -using AnfNodeOpsFunction = std::function&)>; +using OpsFunction = std::function; +using AnfNodeOpsFunction = std::function &)>; -ValuePtr ScalarAdd(const ValuePtrList& list); -ValuePtr ScalarSub(const ValuePtrList& list); -ValuePtr ScalarMul(const ValuePtrList& list); -ValuePtr ScalarDiv(const ValuePtrList& list); -ValuePtr ScalarMod(const ValuePtrList& list); -ValuePtr ScalarPow(const ValuePtrList& list); -ValuePtr ScalarFloordiv(const ValuePtrList& list); -ValuePtr ScalarUAdd(const ValuePtrList& list); -ValuePtr ScalarUSub(const ValuePtrList& list); -ValuePtr ScalarLog(const ValuePtrList& list); -ValuePtr ScalarEq(const ValuePtrList& list); -ValuePtr ScalarLt(const ValuePtrList& list); -ValuePtr ScalarGt(const ValuePtrList& list); -ValuePtr ScalarNe(const ValuePtrList& list); -ValuePtr ScalarLe(const ValuePtrList& list); -ValuePtr ScalarGe(const ValuePtrList& list); -ValuePtr BoolNot(const ValuePtrList& list); -ValuePtr BoolAnd(const ValuePtrList& list); -ValuePtr BoolOr(const ValuePtrList& list); -ValuePtr BoolEq(const ValuePtrList& list); +ValuePtr ScalarAdd(const ValuePtrList &list); +ValuePtr ScalarSub(const ValuePtrList &list); +ValuePtr ScalarMul(const ValuePtrList &list); +ValuePtr ScalarDiv(const ValuePtrList &list); +ValuePtr ScalarMod(const ValuePtrList &list); +ValuePtr ScalarPow(const ValuePtrList &list); +ValuePtr ScalarFloordiv(const ValuePtrList &list); +ValuePtr ScalarUAdd(const ValuePtrList &list); +ValuePtr ScalarUSub(const ValuePtrList &list); +ValuePtr ScalarLog(const ValuePtrList &list); +ValuePtr ScalarEq(const ValuePtrList &list); +ValuePtr ScalarLt(const ValuePtrList &list); +ValuePtr ScalarGt(const ValuePtrList &list); +ValuePtr ScalarNe(const ValuePtrList &list); +ValuePtr ScalarLe(const ValuePtrList &list); +ValuePtr ScalarGe(const ValuePtrList &list); +ValuePtr BoolNot(const ValuePtrList &list); +ValuePtr BoolAnd(const ValuePtrList &list); +ValuePtr BoolOr(const ValuePtrList &list); +ValuePtr BoolEq(const ValuePtrList &list); std::vector BroadcastShape_(std::vector s1, std::vector s2); } // namespace prim } // namespace mindspore diff --git a/mindspore/ccsrc/operator/composite/composite.cc b/mindspore/ccsrc/operator/composite/composite.cc index 9a665e8a30..bf0dcf37d4 100644 --- a/mindspore/ccsrc/operator/composite/composite.cc +++ b/mindspore/ccsrc/operator/composite/composite.cc @@ -66,7 +66,7 @@ const MetaFuncGraphPtr kTail = std::make_shared("tail"); // Apply a function of two arguments cumulatively to the items of a sequence, // from left to right, so as to reduce the sequence to a single value.For example, // reduce(lambda x, y: x + y, [ 1, 2, 3, 4, 5 ]) calculates ((((1 + 2) + 3) + 4) + 5). -AnyPtr Reduce(const OpsFunction& func, const AnyPtrList& list) { +AnyPtr Reduce(const OpsFunction &func, const AnyPtrList &list) { std::shared_ptr ret; size_t size = list.size(); if (size < 2) { @@ -88,7 +88,7 @@ AnyPtr Reduce(const OpsFunction& func, const AnyPtrList& list) { return ret; } -AnfNodePtr Reduce(const AnfNodeOpsFunction& func, const std::vector& list) { +AnfNodePtr Reduce(const AnfNodeOpsFunction &func, const std::vector &list) { size_t size = list.size(); if (size < 2) { MS_LOG(EXCEPTION) << "length of inputs of Reduce is less than 2"; @@ -121,7 +121,7 @@ void HyperMap::Init() { {"args", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindVarPositional}}); } -HyperMap::HyperMap(const std::shared_ptr& fn_leaf) +HyperMap::HyperMap(const std::shared_ptr &fn_leaf) : MetaFuncGraph("hyper_map"), fn_leaf_(fn_leaf), broadcast_(false), @@ -129,13 +129,13 @@ HyperMap::HyperMap(const std::shared_ptr& fn_leaf) Init(); } -HyperMap::HyperMap(const HyperMap& h) +HyperMap::HyperMap(const HyperMap &h) : MetaFuncGraph("hyper_map"), fn_leaf_(h.fn_leaf_), broadcast_(h.broadcast_), nonleaf_(h.nonleaf_) { Init(); } -AnfNodePtr HyperMap::FullMake(TypePtr, const FuncGraphPtr& func_graph, const AnfNodePtr& fn_arg, - const ArgsPairList& arg_map) { +AnfNodePtr HyperMap::FullMake(TypePtr, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, + const ArgsPairList &arg_map) { MS_EXCEPTION_IF_NULL(func_graph); std::vector inputs; if (fn_arg != nullptr) { @@ -145,17 +145,17 @@ AnfNodePtr HyperMap::FullMake(TypePtr, const FuncGraphPtr& func_graph, const Anf } (void)std::transform(arg_map.begin(), arg_map.end(), std::back_inserter(inputs), - [](const std::pair& item) { return item.first; }); + [](const std::pair &item) { return item.first; }); return func_graph->NewCNode(inputs); } -AnfNodePtr HyperMap::FullMake(const std::shared_ptr& type, const FuncGraphPtr& func_graph, - const AnfNodePtr& fn_arg, const ArgsPairList& arg_map) { +AnfNodePtr HyperMap::FullMake(const std::shared_ptr &type, const FuncGraphPtr &func_graph, + const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(type); std::size_t size = type->elements().size(); - bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [size](const std::pair& item) { + bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [size](const std::pair &item) { auto lhs = std::static_pointer_cast(item.second); MS_EXCEPTION_IF_NULL(lhs); return lhs->elements().size() != size; @@ -179,7 +179,7 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr& type, const FuncGraph (void)std::transform( arg_map.begin(), arg_map.end(), std::back_inserter(inputs2), - [&func_graph, i](const std::pair& item) { + [&func_graph, i](const std::pair &item) { return func_graph->NewCNode({NewValueNode(prim::kPrimListGetItem), item.first, NewValueNode(i)}); }); @@ -188,13 +188,13 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr& type, const FuncGraph return func_graph->NewCNode(inputs); } -AnfNodePtr HyperMap::FullMake(const std::shared_ptr& type, const FuncGraphPtr& func_graph, - const AnfNodePtr& fn_arg, const ArgsPairList& arg_map) { +AnfNodePtr HyperMap::FullMake(const std::shared_ptr &type, const FuncGraphPtr &func_graph, + const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(type); std::size_t size = type->elements().size(); - bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [size](const std::pair& item) { + bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [size](const std::pair &item) { auto lhs = std::static_pointer_cast(item.second); MS_EXCEPTION_IF_NULL(lhs); return lhs->elements().size() != size; @@ -226,8 +226,8 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr& type, const FuncGrap return func_graph->NewCNode(inputs); } -AnfNodePtr HyperMap::FullMake(const std::shared_ptr& type, const FuncGraphPtr& func_graph, - const AnfNodePtr& fn_arg, const ArgsPairList& arg_map) { +AnfNodePtr HyperMap::FullMake(const std::shared_ptr &type, const FuncGraphPtr &func_graph, + const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) { MS_EXCEPTION_IF_NULL(type); MS_EXCEPTION_IF_NULL(func_graph); @@ -257,11 +257,11 @@ AnfNodePtr HyperMap::FullMake(const std::shared_ptr& type, const FuncGrap return func_graph->NewCNode(inputs); } -AnfNodePtr HyperMap::Make(const FuncGraphPtr& func_graph, const AnfNodePtr& fn_arg, const ArgsPairList& arg_map) { +AnfNodePtr HyperMap::Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) { bool found = false; TypeId id = kObjectTypeEnd; std::pair pair; - for (auto& item : arg_map) { + for (auto &item : arg_map) { pair = item; id = item.second->type_id(); if (nonleaf_.count(id)) { @@ -272,7 +272,7 @@ AnfNodePtr HyperMap::Make(const FuncGraphPtr& func_graph, const AnfNodePtr& fn_a if (found) { // In a nonleaf situation, all arguments must have the same generic. - bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [pair](const std::pair& item) { + bool is_not_same = std::any_of(arg_map.begin(), arg_map.end(), [pair](const std::pair &item) { if (item.first != pair.first) { return item.second->type_id() != pair.second->type_id(); } @@ -283,7 +283,7 @@ AnfNodePtr HyperMap::Make(const FuncGraphPtr& func_graph, const AnfNodePtr& fn_a oss << "There are " << arg_map.size() << " inputs of `" << name_ << "`, corresponding type info:\n" << trace::GetDebugInfo(func_graph->debug_info()) << "\n"; int idx = 0; - for (auto& item : arg_map) { + for (auto &item : arg_map) { oss << ++idx << ": " << item.second->ToString() << "\n"; } MS_LOG(EXCEPTION) << "HyperMap cannot match up all input types of arguments.\n" << oss.str(); @@ -308,14 +308,14 @@ AnfNodePtr HyperMap::Make(const FuncGraphPtr& func_graph, const AnfNodePtr& fn_a } } -ArgsPairList HyperMap::Harmonize(const FuncGraphPtr& func_graph, const ArgsPairList& args_spec_list) { +ArgsPairList HyperMap::Harmonize(const FuncGraphPtr &func_graph, const ArgsPairList &args_spec_list) { TypePtr type_tensor = std::make_shared(); bool flag = std::any_of( args_spec_list.begin(), args_spec_list.end(), - [type_tensor](const std::pair& item) { return IsSubType(item.second, type_tensor); }); + [type_tensor](const std::pair &item) { return IsSubType(item.second, type_tensor); }); if (flag && broadcast_) { ArgsPairList ret; - for (auto& item : args_spec_list) { + for (auto &item : args_spec_list) { if (!IsSubType(item.second, type_tensor)) { TypePtr type_tensor_ele = std::make_shared(item.second); ret.push_back( @@ -329,7 +329,7 @@ ArgsPairList HyperMap::Harmonize(const FuncGraphPtr& func_graph, const ArgsPairL return args_spec_list; } -FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList& args_spec_list) { +FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList &args_spec_list) { FuncGraphPtr ptrGraph = std::make_shared(); ptrGraph->set_flags(FUNC_GRAPH_FLAG_CORE, true); ptrGraph->debug_info()->set_name("hyper_map"); @@ -353,7 +353,7 @@ FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList& args_spec_list) { return ptrGraph; } -abstract::AbstractBasePtrList HyperMap::NormalizeArgs(const AbstractBasePtrList& args_spec_list) const { +abstract::AbstractBasePtrList HyperMap::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const { if (fn_leaf_ == nullptr) { MS_EXCEPTION_IF_NULL(args_spec_list[0]); // Assert that hypermap's function param does not contain free variables @@ -368,20 +368,20 @@ abstract::AbstractBasePtrList HyperMap::NormalizeArgs(const AbstractBasePtrList& AbstractBasePtrList broadened; (void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(broadened), - [](const AbstractBasePtr& arg) -> AbstractBasePtr { + [](const AbstractBasePtr &arg) -> AbstractBasePtr { MS_EXCEPTION_IF_NULL(arg); return arg->Broaden(); }); return broadened; } -REGISTER_PYBIND_DEFINE(HyperMap_, ([](const py::module* m) { +REGISTER_PYBIND_DEFINE(HyperMap_, ([](const py::module *m) { (void)py::class_>(*m, "HyperMap_") .def(py::init>(), py::arg("leaf")) .def(py::init<>()); })); -FuncGraphPtr Tail::GenerateTupleFuncGraph(const abstract::AbstractTuplePtr& a_tuple) { +FuncGraphPtr Tail::GenerateTupleFuncGraph(const abstract::AbstractTuplePtr &a_tuple) { MS_EXCEPTION_IF_NULL(a_tuple); FuncGraphPtr ret = std::make_shared(); @@ -401,7 +401,7 @@ FuncGraphPtr Tail::GenerateTupleFuncGraph(const abstract::AbstractTuplePtr& a_tu return ret; } -FuncGraphPtr Tail::GenerateListFuncGraph(const abstract::AbstractListPtr& a_list) { +FuncGraphPtr Tail::GenerateListFuncGraph(const abstract::AbstractListPtr &a_list) { MS_EXCEPTION_IF_NULL(a_list); FuncGraphPtr ret = std::make_shared(); @@ -421,7 +421,7 @@ FuncGraphPtr Tail::GenerateListFuncGraph(const abstract::AbstractListPtr& a_list return ret; } -FuncGraphPtr Tail::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) { +FuncGraphPtr Tail::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { if (args_spec_list.size() != 1) { MS_LOG(EXCEPTION) << "tail requires a non-empty tuple."; } @@ -441,11 +441,11 @@ FuncGraphPtr Tail::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) } REGISTER_PYBIND_DEFINE( - Tail_, ([](const py::module* m) { - (void)py::class_>(*m, "Tail_").def(py::init()); + Tail_, ([](const py::module *m) { + (void)py::class_>(*m, "Tail_").def(py::init()); })); -FuncGraphPtr MakeTupleGradient::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) { +FuncGraphPtr MakeTupleGradient::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { int tuple_size = SizeToInt(args_spec_list.size()); std::ostringstream ss; @@ -486,7 +486,7 @@ FuncGraphPtr MakeTupleGradient::GenerateFuncGraph(const AbstractBasePtrList& arg return fg; } -GradOperation::GradOperation(const std::string& name, bool get_all, bool get_by_list, bool sens_param) +GradOperation::GradOperation(const std::string &name, bool get_all, bool get_by_list, bool sens_param) : MetaFuncGraph(name), get_all_(get_all), get_by_list_(get_by_list), sens_param_(sens_param) { if (get_by_list) { signatures_ = @@ -496,8 +496,8 @@ GradOperation::GradOperation(const std::string& name, bool get_all, bool get_by_ } } -FuncGraphPtr GradOperation::GetGrad(AnfNodePtr node, const AnfNodePtr& weights, - const std::vector& params_list, bool applyJ) { +FuncGraphPtr GradOperation::GetGrad(AnfNodePtr node, const AnfNodePtr &weights, + const std::vector ¶ms_list, bool applyJ) { FuncGraphPtr ret = std::make_shared(); ret->set_flags(FUNC_GRAPH_FLAG_CORE, true); @@ -537,7 +537,7 @@ FuncGraphPtr GradOperation::GetGrad(AnfNodePtr node, const AnfNodePtr& weights, return ret; } -void GradOperation::doGetGrad(const FuncGraphPtr& func_graph, AnfNodePtr out, AnfNodePtr ptrBprop, AnfNodePtr weights, +void GradOperation::doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr out, AnfNodePtr ptrBprop, AnfNodePtr weights, ValueNodePtr opsTupleItem) { MS_EXCEPTION_IF_NULL(func_graph); @@ -590,7 +590,7 @@ void GradOperation::doGetGrad(const FuncGraphPtr& func_graph, AnfNodePtr out, An } // Generate the graph. -FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) { +FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { if (args_spec_list.size() < 1) { MS_LOG(EXCEPTION) << "GenerateGraph requires at least 1 parameters, while the input size is " << args_spec_list.size() << "."; @@ -637,21 +637,21 @@ FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList& args_sp return dfBuilder; } -REGISTER_PYBIND_DEFINE(GradOperation_, ([](const py::module* m) { +REGISTER_PYBIND_DEFINE(GradOperation_, ([](const py::module *m) { (void)py::class_>( *m, "GradOperation_") - .def(py::init(), py::arg("fn")) - .def(py::init(), py::arg("fn"), py::arg("get_all"), + .def(py::init(), py::arg("fn")) + .def(py::init(), py::arg("fn"), py::arg("get_all"), py::arg("get_by_list"), py::arg("sens_param")); })); -MultitypeFuncGraph::MultitypeFuncGraph(const std::string& name) : MetaFuncGraph(name) { +MultitypeFuncGraph::MultitypeFuncGraph(const std::string &name) : MetaFuncGraph(name) { fn_cache_.clear(); signatures_ = std::vector({// def multitype(*args:ref): {"args", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindVarPositional}}); } -void MultitypeFuncGraph::Register(const TypePtrList& types, specialize_fn s_fn) { +void MultitypeFuncGraph::Register(const TypePtrList &types, specialize_fn s_fn) { MS_LOG(DEBUG) << "Register type (" << ::mindspore::ToString(types) << "."; auto fn = fn_cache_.find(types); if (fn != fn_cache_.end()) { @@ -660,7 +660,7 @@ void MultitypeFuncGraph::Register(const TypePtrList& types, specialize_fn s_fn) fn_cache_[types] = s_fn; } -void MultitypeFuncGraph::Register(const TypePtrList& types, const py::function& py_fn) { +void MultitypeFuncGraph::Register(const TypePtrList &types, const py::function &py_fn) { MS_LOG(DEBUG) << "Register type (" << ::mindspore::ToString(types) << ", " << std::string(py_fn.str()) << ")."; auto fn = fn_cache_.find(types); if (fn != fn_cache_.end()) { @@ -669,9 +669,9 @@ void MultitypeFuncGraph::Register(const TypePtrList& types, const py::function& fn_cache_py_[types] = py_fn; } -void MultitypeFuncGraph::Register(const std::vector& types_name, const py::function& py_fn) { +void MultitypeFuncGraph::Register(const std::vector &types_name, const py::function &py_fn) { TypePtrList types; - for (auto& type_name : types_name) { + for (auto &type_name : types_name) { auto type_ptr = StringToType(type_name); if (type_ptr == nullptr) { MS_LOG(EXCEPTION) << "" << type_name << " convert from string error "; @@ -681,7 +681,7 @@ void MultitypeFuncGraph::Register(const std::vector& types_name, co Register(types, py_fn); } -void MultitypeFuncGraph::PyRegister(const py::tuple& tuple, const py::function& py_fn) { +void MultitypeFuncGraph::PyRegister(const py::tuple &tuple, const py::function &py_fn) { std::vector types_name; for (size_t it = 0; it < tuple.size(); ++it) { py::object name_py = tuple[it]; @@ -693,16 +693,16 @@ void MultitypeFuncGraph::PyRegister(const py::tuple& tuple, const py::function& } Register(types_name, py_fn); } -static TypePtr UnwrapRef(const TypePtr& type) { +static TypePtr UnwrapRef(const TypePtr &type) { if (type->isa()) { return type->cast()->subtype(); } return type; } -FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList& types) { +FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) { bool find_fn = false; py::function py_fn; - for (auto& item : fn_cache_py_) { + for (auto &item : fn_cache_py_) { TypePtrList sign = item.first; if (sign.size() != types.size()) { continue; @@ -735,7 +735,7 @@ FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList& types) { oss << "There are " << fn_cache_py_.size() << " prototypes for overload function `" << name_ << "`, corresponding location info:\n"; int idx = 0; - for (auto& item : fn_cache_py_) { + for (auto &item : fn_cache_py_) { FuncGraphPtr func_graph = parse::ParsePythonCode(item.second); if (func_graph == nullptr) { MS_LOG(WARNING) << "Fail to parse Python code for function `" << name_ << "`."; @@ -747,15 +747,15 @@ FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList& types) { << oss.str(); } -REGISTER_PYBIND_DEFINE(MultitypeFuncGraph_, ([](const py::module* m) { +REGISTER_PYBIND_DEFINE(MultitypeFuncGraph_, ([](const py::module *m) { (void)py::class_>( *m, "MultitypeFuncGraph_") - .def(py::init()) + .def(py::init()) .def("register_fn", &MultitypeFuncGraph::PyRegister); })); // Generate the ListMap func graph. -FuncGraphPtr ListMap::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) { +FuncGraphPtr ListMap::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { size_t args_num = args_spec_list.size(); // args: fn, list1, list2, ... if (args_num < 2) { @@ -821,8 +821,8 @@ FuncGraphPtr ListMap::GenerateFuncGraph(const AbstractBasePtrList& args_spec_lis return fg_ptr; } -void ListMap::MakeCond(const std::vector& lists, const FuncGraphPtr& fgnext_ptr, - const FuncGraphPtr& fg_ptr) { +void ListMap::MakeCond(const std::vector &lists, const FuncGraphPtr &fgnext_ptr, + const FuncGraphPtr &fg_ptr) { MS_EXCEPTION_IF_NULL(fg_ptr); AnfNodePtr fn = fg_ptr->add_parameter(); @@ -858,8 +858,8 @@ void ListMap::MakeCond(const std::vector& lists, const FuncGraphPtr& fgtrue_ptr->set_output(output_cnode); } -void ListMap::MakeNext(const std::vector& lists, const FuncGraphPtr& fgcond_ptr, - const FuncGraphPtr& fg_ptr) { +void ListMap::MakeNext(const std::vector &lists, const FuncGraphPtr &fgcond_ptr, + const FuncGraphPtr &fg_ptr) { MS_EXCEPTION_IF_NULL(fg_ptr); AnfNodePtr fn = fg_ptr->add_parameter(); @@ -893,7 +893,7 @@ void ListMap::MakeNext(const std::vector& lists, const FuncGraphPtr& fg_ptr->set_output(output_cnode); } -FuncGraphPtr TupleAdd::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) { +FuncGraphPtr TupleAdd::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { // args: tuple1, tuple2 abstract::CheckArgsSize("TupleAdd", args_spec_list, 2); AbstractBasePtr abs_a = args_spec_list[0]; @@ -928,7 +928,7 @@ FuncGraphPtr TupleAdd::GenerateFuncGraph(const AbstractBasePtrList& args_spec_li return ret; } -int GetArgScalarValue(const abstract::AbstractScalarPtr& scalar, const std::string&) { +int GetArgScalarValue(const abstract::AbstractScalarPtr &scalar, const std::string &) { MS_EXCEPTION_IF_NULL(scalar); return GetValue(scalar->BuildValue()); } @@ -942,7 +942,7 @@ int GetPositiveIndex(int index, int length) { return index; } -int CheckSliceMember(const AbstractBasePtr& member, int default_value, const std::string& member_name) { +int CheckSliceMember(const AbstractBasePtr &member, int default_value, const std::string &member_name) { MS_EXCEPTION_IF_NULL(member); if (member->isa()) { @@ -957,8 +957,8 @@ int CheckSliceMember(const AbstractBasePtr& member, int default_value, const std << member->ToString(); } -void GenerateTupleSliceParameter(const AbstractTuplePtr& tuple, const AbstractSlicePtr& slice, int* start_index, - int* stop_index, int* step_value) { +void GenerateTupleSliceParameter(const AbstractTuplePtr &tuple, const AbstractSlicePtr &slice, int *start_index, + int *stop_index, int *step_value) { MS_EXCEPTION_IF_NULL(tuple); MS_EXCEPTION_IF_NULL(slice); MS_EXCEPTION_IF_NULL(start_index); @@ -998,7 +998,7 @@ void GenerateTupleSliceParameter(const AbstractTuplePtr& tuple, const AbstractSl } } -FuncGraphPtr TupleSlice::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) { +FuncGraphPtr TupleSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { // slice a tuple // args: tuple, start index, end index, step const std::string op_name("TupleSlice"); @@ -1032,7 +1032,7 @@ FuncGraphPtr TupleSlice::GenerateFuncGraph(const AbstractBasePtrList& args_spec_ return ret; } -int ConvertBinaryToDecimal(const std::vector& number_bin) { +int ConvertBinaryToDecimal(const std::vector &number_bin) { unsigned int number_dec = 0; for (size_t index = 0; index < number_bin.size(); index++) { number_dec |= number_bin[index] << index; @@ -1040,8 +1040,8 @@ int ConvertBinaryToDecimal(const std::vector& number_bin) { return static_cast(number_dec); } -void ParseSlice(const AbstractSlicePtr& slice, std::vector* begin, std::vector* end, - std::vector* strides, int length) { +void ParseSlice(const AbstractSlicePtr &slice, std::vector *begin, std::vector *end, + std::vector *strides, int length) { MS_EXCEPTION_IF_NULL(slice); MS_EXCEPTION_IF_NULL(begin); MS_EXCEPTION_IF_NULL(end); @@ -1064,8 +1064,8 @@ void ParseSlice(const AbstractSlicePtr& slice, std::vector* begin, std::vec strides->push_back(step_value); } -int GenerateStridedSliceParametersFromTuple(const AbstractTuplePtr& slice_tuple, const std::vector& shape, - std::vector* begin, std::vector* end, std::vector* strides) { +int GenerateStridedSliceParametersFromTuple(const AbstractTuplePtr &slice_tuple, const std::vector &shape, + std::vector *begin, std::vector *end, std::vector *strides) { MS_EXCEPTION_IF_NULL(slice_tuple); MS_EXCEPTION_IF_NULL(begin); MS_EXCEPTION_IF_NULL(end); @@ -1111,8 +1111,8 @@ int GenerateStridedSliceParametersFromTuple(const AbstractTuplePtr& slice_tuple, return ConvertBinaryToDecimal(shrink); } -int GenerateStridedSliceParametersFromSlice(const AbstractSlicePtr& slice, const std::vector& shape, - std::vector* begin, std::vector* end, std::vector* strides) { +int GenerateStridedSliceParametersFromSlice(const AbstractSlicePtr &slice, const std::vector &shape, + std::vector *begin, std::vector *end, std::vector *strides) { MS_EXCEPTION_IF_NULL(begin); MS_EXCEPTION_IF_NULL(end); MS_EXCEPTION_IF_NULL(strides); @@ -1132,9 +1132,9 @@ int GenerateStridedSliceParametersFromSlice(const AbstractSlicePtr& slice, const return 0; } -int GenerateStridedSliceParametersFromNumber(const AbstractScalarPtr& scalar, const std::vector& shape, - std::vector* begin, std::vector* end, - std::vector* strides) { +int GenerateStridedSliceParametersFromNumber(const AbstractScalarPtr &scalar, const std::vector &shape, + std::vector *begin, std::vector *end, + std::vector *strides) { MS_EXCEPTION_IF_NULL(begin); MS_EXCEPTION_IF_NULL(end); MS_EXCEPTION_IF_NULL(strides); @@ -1153,7 +1153,7 @@ int GenerateStridedSliceParametersFromNumber(const AbstractScalarPtr& scalar, co return 1; } -FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) { +FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { // slice a tensor // args: tensor, slice or slice tuple const std::string op_name = std::string("TensorSlice"); @@ -1177,7 +1177,7 @@ FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList& args_spec shrink_axis_mask = GenerateStridedSliceParametersFromNumber(scalar_ptr, shape, &begin, &end, &strides); } else { std::ostringstream args_info; - for (const auto& arg : args_spec_list) { + for (const auto &arg : args_spec_list) { MS_EXCEPTION_IF_NULL(arg); args_info << arg->ToString() << "\n"; } @@ -1199,19 +1199,19 @@ FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList& args_spec return ret_graph; } -REGISTER_PYBIND_DEFINE( - TupleAdd_, ([](const py::module* m) { - (void)py::class_>(*m, "TupleAdd_").def(py::init()); - })); +REGISTER_PYBIND_DEFINE(TupleAdd_, ([](const py::module *m) { + (void)py::class_>(*m, "TupleAdd_") + .def(py::init()); + })); -REGISTER_PYBIND_DEFINE(TupleSlice_, ([](const py::module* m) { +REGISTER_PYBIND_DEFINE(TupleSlice_, ([](const py::module *m) { (void)py::class_>(*m, "TupleSlice_") - .def(py::init()); + .def(py::init()); })); -REGISTER_PYBIND_DEFINE(TensorSlice_, ([](const py::module* m) { +REGISTER_PYBIND_DEFINE(TensorSlice_, ([](const py::module *m) { (void)py::class_>(*m, "TensorSlice_") - .def(py::init()); + .def(py::init()); })); } // namespace prim } // namespace mindspore diff --git a/mindspore/ccsrc/operator/composite/composite.h b/mindspore/ccsrc/operator/composite/composite.h index dc8627ba61..1dad2e08cf 100644 --- a/mindspore/ccsrc/operator/composite/composite.h +++ b/mindspore/ccsrc/operator/composite/composite.h @@ -47,20 +47,20 @@ using ArgsPairList = std::vector>; class MultitypeFuncGraph : public MetaFuncGraph { public: - explicit MultitypeFuncGraph(const std::string& name); + explicit MultitypeFuncGraph(const std::string &name); ~MultitypeFuncGraph() override = default; MS_DECLARE_PARENT(MultitypeFuncGraph, MetaFuncGraph) - using specialize_fn = FuncGraph* (*)(TypePtrList); + using specialize_fn = FuncGraph *(*)(TypePtrList); // Register a method which specialize based on types vectors; - virtual void Register(const TypePtrList& types, specialize_fn s_fn); - virtual void Register(const TypePtrList& types, const py::function& py_fn); - virtual void Register(const std::vector& types_name, const py::function& py_fn); - virtual void PyRegister(const py::tuple& tuple, const py::function& py_fn); + virtual void Register(const TypePtrList &types, specialize_fn s_fn); + virtual void Register(const TypePtrList &types, const py::function &py_fn); + virtual void Register(const std::vector &types_name, const py::function &py_fn); + virtual void PyRegister(const py::tuple &tuple, const py::function &py_fn); - FuncGraphPtr GenerateFromTypes(const TypePtrList& types) override; + FuncGraphPtr GenerateFromTypes(const TypePtrList &types) override; size_t GetPyFnCacheSize() const { return fn_cache_py_.size(); } - const std::unordered_map& GetPyFunctions() const { + const std::unordered_map &GetPyFunctions() const { return fn_cache_py_; } @@ -72,10 +72,10 @@ using MultitypeFuncGraphPtr = std::shared_ptr; class HyperMap : public MetaFuncGraph { public: - explicit HyperMap(const std::shared_ptr& fn_leaf = nullptr); - HyperMap(const HyperMap& h); + explicit HyperMap(const std::shared_ptr &fn_leaf = nullptr); + HyperMap(const HyperMap &h); void Init(); - HyperMap& operator=(const HyperMap& h) { + HyperMap &operator=(const HyperMap &h) { if (this != &h) { fn_leaf_ = h.fn_leaf_; broadcast_ = h.broadcast_; @@ -89,21 +89,21 @@ class HyperMap : public MetaFuncGraph { ~HyperMap() override = default; MS_DECLARE_PARENT(HyperMap, MetaFuncGraph) - abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList& args_spec_list) const override; - FuncGraphPtr GenerateFromTypes(const TypePtrList& args_spec_list) override; + abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList &args_spec_list) const override; + FuncGraphPtr GenerateFromTypes(const TypePtrList &args_spec_list) override; MetaFuncGraphPtr GetFnLeaf() { return fn_leaf_; } private: - AnfNodePtr FullMake(TypePtr type, const FuncGraphPtr& func_graph, const AnfNodePtr& fn_arg, - const ArgsPairList& arg_map); - AnfNodePtr FullMake(const std::shared_ptr& type, const FuncGraphPtr& func_graph, const AnfNodePtr& fn_arg, - const ArgsPairList& arg_map); - AnfNodePtr FullMake(const std::shared_ptr& type, const FuncGraphPtr& func_graph, const AnfNodePtr& fn_arg, - const ArgsPairList& arg_map); - AnfNodePtr FullMake(const std::shared_ptr& type, const FuncGraphPtr& func_graph, const AnfNodePtr& fn_arg, - const ArgsPairList& arg_map); - AnfNodePtr Make(const FuncGraphPtr& graph, const AnfNodePtr& fn_arg, const ArgsPairList& arg_map); - ArgsPairList Harmonize(const FuncGraphPtr& graph, const ArgsPairList& args_spec_list); + AnfNodePtr FullMake(TypePtr type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, + const ArgsPairList &arg_map); + AnfNodePtr FullMake(const std::shared_ptr &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, + const ArgsPairList &arg_map); + AnfNodePtr FullMake(const std::shared_ptr &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, + const ArgsPairList &arg_map); + AnfNodePtr FullMake(const std::shared_ptr &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, + const ArgsPairList &arg_map); + AnfNodePtr Make(const FuncGraphPtr &graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_map); + ArgsPairList Harmonize(const FuncGraphPtr &graph, const ArgsPairList &args_spec_list); MultitypeFuncGraphPtr fn_leaf_; bool broadcast_; @@ -113,7 +113,7 @@ using HyperMapPtr = std::shared_ptr; class HyperMapPy : public HyperMap { public: - explicit HyperMapPy(const std::shared_ptr& fn_leaf = nullptr) : HyperMap(fn_leaf) {} + explicit HyperMapPy(const std::shared_ptr &fn_leaf = nullptr) : HyperMap(fn_leaf) {} ~HyperMapPy() override = default; MS_DECLARE_PARENT(HyperMapPy, HyperMap) }; @@ -123,56 +123,56 @@ extern ValuePtr kCompositeHyperMap; class Tail : public MetaFuncGraph { public: - explicit Tail(const std::string& name) : MetaFuncGraph(name) {} + explicit Tail(const std::string &name) : MetaFuncGraph(name) {} ~Tail() override = default; MS_DECLARE_PARENT(Tail, MetaFuncGraph) - FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override; - FuncGraphPtr GenerateTupleFuncGraph(const abstract::AbstractTuplePtr& a_tuple); - FuncGraphPtr GenerateListFuncGraph(const abstract::AbstractListPtr& a_list); + FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; + FuncGraphPtr GenerateTupleFuncGraph(const abstract::AbstractTuplePtr &a_tuple); + FuncGraphPtr GenerateListFuncGraph(const abstract::AbstractListPtr &a_list); - friend bool operator==(const Tail& lhs, const Tail& rhs) { return lhs.name_ == rhs.name_; } + friend bool operator==(const Tail &lhs, const Tail &rhs) { return lhs.name_ == rhs.name_; } }; using TailPtr = std::shared_ptr; class MakeTupleGradient : public MetaFuncGraph { public: - explicit MakeTupleGradient(const std::string& name) : MetaFuncGraph(name) {} + explicit MakeTupleGradient(const std::string &name) : MetaFuncGraph(name) {} ~MakeTupleGradient() override = default; MS_DECLARE_PARENT(MakeTupleGradient, MetaFuncGraph) - FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override; - friend bool operator==(const MakeTupleGradient& lhs, const MakeTupleGradient& rhs) { return lhs.name_ == rhs.name_; } + FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; + friend bool operator==(const MakeTupleGradient &lhs, const MakeTupleGradient &rhs) { return lhs.name_ == rhs.name_; } }; using MakeTupleGradientPtr = std::shared_ptr; class GradOperation : public MetaFuncGraph { public: - explicit GradOperation(const std::string& name, bool get_all = false, bool get_by_list = false, + explicit GradOperation(const std::string &name, bool get_all = false, bool get_by_list = false, bool sens_param = false); ~GradOperation() override = default; MS_DECLARE_PARENT(GradOperation, MetaFuncGraph) - FuncGraphPtr GetGrad(AnfNodePtr ptrNode, const AnfNodePtr& weights, const std::vector& ptrParams, + FuncGraphPtr GetGrad(AnfNodePtr ptrNode, const AnfNodePtr &weights, const std::vector &ptrParams, bool applyJ = false); - FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override; + FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; bool sens_param() const { return sens_param_; } bool get_all_; bool get_by_list_; bool sens_param_; private: - void doGetGrad(const FuncGraphPtr& func_graph, AnfNodePtr ptrOut, AnfNodePtr ptrBprop, AnfNodePtr weights, + void doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr ptrOut, AnfNodePtr ptrBprop, AnfNodePtr weights, ValueNodePtr opsTupleItem); }; using GradOperationPtr = std::shared_ptr; class ListMap { public: - explicit ListMap(const std::string& name) : name_(name) { cache_.clear(); } + explicit ListMap(const std::string &name) : name_(name) { cache_.clear(); } ~ListMap() = default; - void MakeCond(const std::vector& lists, const FuncGraphPtr& gnext_ptr, const FuncGraphPtr& graph_ptr); - void MakeNext(const std::vector& lists, const FuncGraphPtr& gcond_ptr, const FuncGraphPtr& graph_ptr); - FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list); + void MakeCond(const std::vector &lists, const FuncGraphPtr &gnext_ptr, const FuncGraphPtr &graph_ptr); + void MakeNext(const std::vector &lists, const FuncGraphPtr &gcond_ptr, const FuncGraphPtr &graph_ptr); + FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list); private: std::string name_; @@ -181,31 +181,31 @@ class ListMap { class TupleAdd : public MetaFuncGraph { public: - explicit TupleAdd(const std::string& name) : MetaFuncGraph(name) {} + explicit TupleAdd(const std::string &name) : MetaFuncGraph(name) {} ~TupleAdd() override = default; MS_DECLARE_PARENT(TupleAdd, MetaFuncGraph) - FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override; - friend bool operator==(const TupleAdd& lhs, const TupleAdd& rhs) { return lhs.name_ == rhs.name_; } + FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; + friend bool operator==(const TupleAdd &lhs, const TupleAdd &rhs) { return lhs.name_ == rhs.name_; } }; using TupleAddPtr = std::shared_ptr; class TupleSlice : public MetaFuncGraph { public: - explicit TupleSlice(const std::string& name) : MetaFuncGraph(name) {} + explicit TupleSlice(const std::string &name) : MetaFuncGraph(name) {} ~TupleSlice() override = default; MS_DECLARE_PARENT(TupleSlice, MetaFuncGraph) - FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override; - friend bool operator==(const TupleSlice& lhs, const TupleSlice& rhs) { return lhs.name_ == rhs.name_; } + FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; + friend bool operator==(const TupleSlice &lhs, const TupleSlice &rhs) { return lhs.name_ == rhs.name_; } }; using TupleSlicePtr = std::shared_ptr; class TensorSlice : public MetaFuncGraph { public: - explicit TensorSlice(const std::string& name) : MetaFuncGraph(name) {} + explicit TensorSlice(const std::string &name) : MetaFuncGraph(name) {} ~TensorSlice() override = default; MS_DECLARE_PARENT(TensorSlice, MetaFuncGraph) - FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override; - friend bool operator==(const TensorSlice& lhs, const TensorSlice& rhs) { return lhs.name_ == rhs.name_; } + FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; + friend bool operator==(const TensorSlice &lhs, const TensorSlice &rhs) { return lhs.name_ == rhs.name_; } }; using TensorSlicePtr = std::shared_ptr; diff --git a/mindspore/ccsrc/operator/composite/do_signature.cc b/mindspore/ccsrc/operator/composite/do_signature.cc index a4a26377f5..95e38247d9 100644 --- a/mindspore/ccsrc/operator/composite/do_signature.cc +++ b/mindspore/ccsrc/operator/composite/do_signature.cc @@ -34,7 +34,7 @@ namespace prim { namespace { using PatternListType = std::initializer_list; -const std::vector& GetSignature(const ValuePtr& function) { +const std::vector &GetSignature(const ValuePtr &function) { static const auto empty = std::vector(); if (function->isa()) { return function->cast()->signatures(); @@ -44,8 +44,8 @@ const std::vector& GetSignature(const ValuePtr& function) { return empty; } -void ProcessDefault(const std::string& func_name, const AbstractBasePtrList& args_spec_list, - const std::vector& signature, bool has_var, std::vector* op_inputs) { +void ProcessDefault(const std::string &func_name, const AbstractBasePtrList &args_spec_list, + const std::vector &signature, bool has_var, std::vector *op_inputs) { std::size_t sig_size = signature.size(); auto positional_size = sig_size; if (has_var) { @@ -64,8 +64,8 @@ void ProcessDefault(const std::string& func_name, const AbstractBasePtrList& arg } // Get the largest type of index in the same SignatureEnumDType of arguments. -std::map GetMaxDtypeIndex(const std::vector& dtypes, - const abstract::AbstractBasePtrList& args_spec_list) { +std::map GetMaxDtypeIndex(const std::vector &dtypes, + const abstract::AbstractBasePtrList &args_spec_list) { // record index for signature.dtypes of the same type // eg. [T, T1, T, T2, T, T1, T3] -> {{T:(0,2,4)}, {T1:(1,5)}, {T2:(3)}, {T3:(6)}} std::map> type_indexs; @@ -89,7 +89,7 @@ std::map GetMaxDtypeIndex(const std::vectorisa()) { arg_value = arg_value->cast()->ref(); @@ -104,7 +104,7 @@ std::map GetMaxDtypeIndex(const std::vector& signature, const abstract::AbstractBasePtrList& args_spec_list, - const FuncGraphPtr& graph, std::vector* op_inputs) { +void DoAutoCast(const std::vector &signature, const abstract::AbstractBasePtrList &args_spec_list, + const FuncGraphPtr &graph, std::vector *op_inputs) { std::vector dtypes; (void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes), - [](const Signature& sig) { return sig.dtype; }); + [](const Signature &sig) { return sig.dtype; }); int empty_dtype_count = std::count(dtypes.begin(), dtypes.end(), SignatureEnumDType::kDTypeEmptyDefaultValue); if (dtypes.empty() || static_cast(dtypes.size()) == empty_dtype_count) { return; @@ -143,10 +143,10 @@ void DoAutoCast(const std::vector& signature, const abstract::Abstrac } } -AnfNodePtr BuildNewCNode(const FuncGraphPtr& func_graph, const std::string& func_name, const ValuePtr& function, - const AbstractBasePtrList& args_spec_list, const std::vector& params_list) { +AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func_name, const ValuePtr &function, + const AbstractBasePtrList &args_spec_list, const std::vector ¶ms_list) { // args: original inputs - auto& signature = GetSignature(function); + auto &signature = GetSignature(function); std::size_t sig_size = signature.size(); auto has_var = (sig_size > 0 && signature[sig_size - 1].kind == SignatureEnumKind::kKindVarPositional); if (sig_size > 0) { @@ -196,13 +196,13 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr& func_graph, const std::string& func } } // namespace -AnfNodePtr GenerateCNode(const FuncGraphPtr& func_graph, const std::string& func_name, const ValuePtr& function, - const AbstractBasePtrList& args_spec_list, const AnfNodePtrList& old_node_inputs) { +AnfNodePtr GenerateCNode(const FuncGraphPtr &func_graph, const std::string &func_name, const ValuePtr &function, + const AbstractBasePtrList &args_spec_list, const AnfNodePtrList &old_node_inputs) { auto new_cnode = BuildNewCNode(func_graph, func_name, function, args_spec_list, old_node_inputs); return new_cnode; } -FuncGraphPtr DoSignatureMetaFuncGraph::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) { +FuncGraphPtr DoSignatureMetaFuncGraph::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { FuncGraphPtr func_graph = std::make_shared(); for (size_t i = 0; i < args_spec_list.size(); ++i) { diff --git a/mindspore/ccsrc/operator/composite/do_signature.h b/mindspore/ccsrc/operator/composite/do_signature.h index b88053e224..3e1596d63f 100644 --- a/mindspore/ccsrc/operator/composite/do_signature.h +++ b/mindspore/ccsrc/operator/composite/do_signature.h @@ -37,17 +37,17 @@ namespace mindspore { namespace prim { class DoSignatureMetaFuncGraph : public MetaFuncGraph { public: - explicit DoSignatureMetaFuncGraph(const std::string& name, const ValuePtr& function) + explicit DoSignatureMetaFuncGraph(const std::string &name, const ValuePtr &function) : MetaFuncGraph("S-" + name), function_(function) {} ~DoSignatureMetaFuncGraph() override = default; MS_DECLARE_PARENT(DoSignatureMetaFuncGraph, MetaFuncGraph) - FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList& args_spec_list) override; + FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList &args_spec_list) override; const ValuePtr function() const { return function_; } - friend bool operator==(const DoSignatureMetaFuncGraph& lhs, const DoSignatureMetaFuncGraph& rhs) { + friend bool operator==(const DoSignatureMetaFuncGraph &lhs, const DoSignatureMetaFuncGraph &rhs) { return &lhs == &rhs; } @@ -56,8 +56,8 @@ class DoSignatureMetaFuncGraph : public MetaFuncGraph { }; using RWSignaturePtr = std::shared_ptr; -AnfNodePtr GenerateCNode(const FuncGraphPtr& func_graph, const std::string& func_name, const ValuePtr& function, - const AbstractBasePtrList& args_spec_list, const AnfNodePtrList& old_node_inputs); +AnfNodePtr GenerateCNode(const FuncGraphPtr &func_graph, const std::string &func_name, const ValuePtr &function, + const AbstractBasePtrList &args_spec_list, const AnfNodePtrList &old_node_inputs); } // namespace prim } // namespace mindspore diff --git a/mindspore/ccsrc/operator/composite/list_append_operation.cc b/mindspore/ccsrc/operator/composite/list_append_operation.cc index 8621a8a8ba..b5a4fc626e 100644 --- a/mindspore/ccsrc/operator/composite/list_append_operation.cc +++ b/mindspore/ccsrc/operator/composite/list_append_operation.cc @@ -27,7 +27,7 @@ namespace mindspore { // namespace to support composite operators definition namespace prim { -FuncGraphPtr ListAppend::GenerateFuncGraph(const abstract::AbstractBasePtrList& args_list) { +FuncGraphPtr ListAppend::GenerateFuncGraph(const abstract::AbstractBasePtrList &args_list) { abstract::CheckArgsSize("ListAppend", args_list, 2); AbstractBasePtr arg0 = args_list[0]; @@ -52,9 +52,9 @@ FuncGraphPtr ListAppend::GenerateFuncGraph(const abstract::AbstractBasePtrList& return ret; } -REGISTER_PYBIND_DEFINE(ListAppend_, ([](const py::module* m) { +REGISTER_PYBIND_DEFINE(ListAppend_, ([](const py::module *m) { (void)py::class_>(*m, "ListAppend_") - .def(py::init()); + .def(py::init()); })); } // namespace prim } // namespace mindspore diff --git a/mindspore/ccsrc/operator/composite/list_append_operation.h b/mindspore/ccsrc/operator/composite/list_append_operation.h index f34b6b864e..1da3f9a009 100644 --- a/mindspore/ccsrc/operator/composite/list_append_operation.h +++ b/mindspore/ccsrc/operator/composite/list_append_operation.h @@ -28,15 +28,15 @@ namespace mindspore { namespace prim { class ListAppend : public MetaFuncGraph { public: - explicit ListAppend(const std::string& name) : MetaFuncGraph(name) {} + explicit ListAppend(const std::string &name) : MetaFuncGraph(name) {} ~ListAppend() override = default; MS_DECLARE_PARENT(ListAppend, MetaFuncGraph) - FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList& a_list) override; - friend std::ostream& operator<<(std::ostream& os, const ListAppend& list_append) { + FuncGraphPtr GenerateFuncGraph(const abstract::AbstractBasePtrList &a_list) override; + friend std::ostream &operator<<(std::ostream &os, const ListAppend &list_append) { os << list_append.name_; return os; } - friend bool operator==(const ListAppend& lhs, const ListAppend& rhs) { return lhs.name_ == rhs.name_; } + friend bool operator==(const ListAppend &lhs, const ListAppend &rhs) { return lhs.name_ == rhs.name_; } }; using ListAppendPtr = std::shared_ptr; } // namespace prim diff --git a/mindspore/ccsrc/operator/composite/unpack_call.cc b/mindspore/ccsrc/operator/composite/unpack_call.cc index 64d6b3433b..122f276657 100644 --- a/mindspore/ccsrc/operator/composite/unpack_call.cc +++ b/mindspore/ccsrc/operator/composite/unpack_call.cc @@ -40,7 +40,7 @@ using mindspore::abstract::AbstractKeywordArg; using mindspore::abstract::AbstractTuple; using mindspore::abstract::AbstractTuplePtr; -FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) { +FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { // slice a tensor // args: tensor, slice or slice tuple const std::string op_name = std::string("UnpackCall"); @@ -70,7 +70,7 @@ FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList& args_spec_ AnfNodePtr para_dict = ret_graph->add_parameter(); auto dict_elems = arg_dict->elements(); (void)std::transform(dict_elems.begin(), dict_elems.end(), std::back_inserter(elems), - [ret_graph, para_dict](const AbstractAttribute& item) { + [ret_graph, para_dict](const AbstractAttribute &item) { auto dict_get_item = ret_graph->NewCNode( {NewValueNode(prim::kPrimDictGetItem), para_dict, NewValueNode(item.first)}); return ret_graph->NewCNode( @@ -85,9 +85,9 @@ FuncGraphPtr UnpackCall::GenerateFuncGraph(const AbstractBasePtrList& args_spec_ return ret_graph; } -REGISTER_PYBIND_DEFINE(UnpackCall_, ([](const py::module* m) { +REGISTER_PYBIND_DEFINE(UnpackCall_, ([](const py::module *m) { (void)py::class_>(*m, "UnpackCall_") - .def(py::init()); + .def(py::init()); })); } // namespace prim diff --git a/mindspore/ccsrc/operator/composite/unpack_call.h b/mindspore/ccsrc/operator/composite/unpack_call.h index 7ec5f9ad33..2f39615c1a 100644 --- a/mindspore/ccsrc/operator/composite/unpack_call.h +++ b/mindspore/ccsrc/operator/composite/unpack_call.h @@ -40,11 +40,11 @@ namespace prim { // and generate positional parameters and key-value pairs for function. class UnpackCall : public MetaFuncGraph { public: - explicit UnpackCall(const std::string& name) : MetaFuncGraph(name) {} + explicit UnpackCall(const std::string &name) : MetaFuncGraph(name) {} ~UnpackCall() override = default; MS_DECLARE_PARENT(UnpackCall, MetaFuncGraph) - FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override; - friend bool operator==(const UnpackCall& lhs, const UnpackCall& rhs) { return lhs.name_ == rhs.name_; } + FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; + friend bool operator==(const UnpackCall &lhs, const UnpackCall &rhs) { return lhs.name_ == rhs.name_; } }; using UnpackCallPtr = std::shared_ptr; diff --git a/mindspore/ccsrc/operator/composite/zip_operation.cc b/mindspore/ccsrc/operator/composite/zip_operation.cc index b87e19b009..4d34163f28 100644 --- a/mindspore/ccsrc/operator/composite/zip_operation.cc +++ b/mindspore/ccsrc/operator/composite/zip_operation.cc @@ -36,7 +36,7 @@ namespace prim { using mindspore::abstract::AbstractBase; using mindspore::abstract::AbstractTuple; -FuncGraphPtr ZipOperation::GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) { +FuncGraphPtr ZipOperation::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) { // zip operation: // input: tuple arguments // output: tuple of items of input iterated on every input @@ -44,7 +44,7 @@ FuncGraphPtr ZipOperation::GenerateFuncGraph(const AbstractBasePtrList& args_spe MS_LOG(EXCEPTION) << "zip arguments input should not be empty"; } - auto is_all_tuple = std::all_of(args_spec_list.begin(), args_spec_list.end(), [](const AbstractBasePtr& abs) -> bool { + auto is_all_tuple = std::all_of(args_spec_list.begin(), args_spec_list.end(), [](const AbstractBasePtr &abs) -> bool { MS_EXCEPTION_IF_NULL(abs); return abs->isa(); }); @@ -53,7 +53,7 @@ FuncGraphPtr ZipOperation::GenerateFuncGraph(const AbstractBasePtrList& args_spe } auto min_abs = std::min_element(args_spec_list.begin(), args_spec_list.end(), - [](const AbstractBasePtr& x, const AbstractBasePtr& y) { + [](const AbstractBasePtr &x, const AbstractBasePtr &y) { return (x->cast()->size() < y->cast()->size()); }); FuncGraphPtr ret_graph = std::make_shared(); @@ -81,10 +81,10 @@ FuncGraphPtr ZipOperation::GenerateFuncGraph(const AbstractBasePtrList& args_spe return ret_graph; } -REGISTER_PYBIND_DEFINE(ZipOperation_, ([](const py::module* m) { +REGISTER_PYBIND_DEFINE(ZipOperation_, ([](const py::module *m) { (void)py::class_>(*m, "ZipOperation_") - .def(py::init()); + .def(py::init()); })); } // namespace prim } // namespace mindspore diff --git a/mindspore/ccsrc/operator/composite/zip_operation.h b/mindspore/ccsrc/operator/composite/zip_operation.h index e1fb8d60cf..1a3fa1f5fe 100644 --- a/mindspore/ccsrc/operator/composite/zip_operation.h +++ b/mindspore/ccsrc/operator/composite/zip_operation.h @@ -42,15 +42,15 @@ using AbstractTuplePtr = abstract::AbstractTuplePtr; class ZipOperation : public MetaFuncGraph { public: - explicit ZipOperation(const std::string& name) : MetaFuncGraph(name) {} + explicit ZipOperation(const std::string &name) : MetaFuncGraph(name) {} ~ZipOperation() override = default; MS_DECLARE_PARENT(ZipOperation, MetaFuncGraph) - FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList& args_spec_list) override; - friend std::ostream& operator<<(std::ostream& os, const ZipOperation& op) { + FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; + friend std::ostream &operator<<(std::ostream &os, const ZipOperation &op) { os << op.name_; return os; } - friend bool operator==(const ZipOperation& lhs, const ZipOperation& rhs) { return lhs.name_ == rhs.name_; } + friend bool operator==(const ZipOperation &lhs, const ZipOperation &rhs) { return lhs.name_ == rhs.name_; } }; using ZipOperationPtr = std::shared_ptr; } // namespace prim diff --git a/mindspore/ccsrc/operator/ops.cc b/mindspore/ccsrc/operator/ops.cc index ffd331c6c3..9d5777641b 100755 --- a/mindspore/ccsrc/operator/ops.cc +++ b/mindspore/ccsrc/operator/ops.cc @@ -238,7 +238,7 @@ const PrimitivePtr kPrimImageSummary = std::make_shared("ImageSummary const PrimitivePtr kPrimTensorSummary = std::make_shared("TensorSummary"); const PrimitivePtr kPrimHistogramSummary = std::make_shared("HistogramSummary"); -ValuePtr GetPythonOps(const std::string& op_name, const std::string& module_name) { +ValuePtr GetPythonOps(const std::string &op_name, const std::string &module_name) { py::object obj = parse::python_adapter::GetPyFn(module_name, op_name); ValuePtr node = nullptr; bool succ = parse::ConvertData(obj, &node); diff --git a/mindspore/ccsrc/operator/ops.h b/mindspore/ccsrc/operator/ops.h index a6c614b494..4852e2345e 100755 --- a/mindspore/ccsrc/operator/ops.h +++ b/mindspore/ccsrc/operator/ops.h @@ -26,8 +26,8 @@ namespace mindspore { // namespace to support primitive operators namespace prim { -ValuePtr GetPythonOps(const std::string& op_name, - const std::string& module_name = "mindspore._extends.parse.standard_method"); +ValuePtr GetPythonOps(const std::string &op_name, + const std::string &module_name = "mindspore._extends.parse.standard_method"); // Arithmetic extern const PrimitivePtr kPrimScalarAdd; @@ -241,7 +241,7 @@ extern const PrimitivePtr kPrimVirtualDataset; class DoSignaturePrimitive : public Primitive { public: - explicit DoSignaturePrimitive(const std::string& name, const ValuePtr& function) + explicit DoSignaturePrimitive(const std::string &name, const ValuePtr &function) : Primitive("S-Prim-" + name), function_(function) {} ~DoSignaturePrimitive() override = default; @@ -257,7 +257,7 @@ using DoSignaturePrimitivePtr = std::shared_ptr; class UnpackGraphPrimitive : public Primitive { public: - explicit UnpackGraphPrimitive(const std::string& name, const bool& with_sens, const bool& need_unpack_args) + explicit UnpackGraphPrimitive(const std::string &name, const bool &with_sens, const bool &need_unpack_args) : Primitive("UnpackGraph"), with_sens_in_args_(with_sens), need_unpack_args_(need_unpack_args) {} ~UnpackGraphPrimitive() override = default; MS_DECLARE_PARENT(UnpackGraphPrimitive, Primitive) diff --git a/mindspore/ccsrc/operator/prim_to_function.cc b/mindspore/ccsrc/operator/prim_to_function.cc index bdfe48157c..733cdbdb73 100644 --- a/mindspore/ccsrc/operator/prim_to_function.cc +++ b/mindspore/ccsrc/operator/prim_to_function.cc @@ -54,7 +54,7 @@ PrimToFunction::PrimToFunction() {"scalar_sub", kPrimTypeTwoArgs}, {"scalar_floordiv", kPrimTypeTwoArgs}}) {} -bool PrimToFunction::GetFunction(const PrimitivePtr& prim, FunctionPtr* const func) const { +bool PrimToFunction::GetFunction(const PrimitivePtr &prim, FunctionPtr *const func) const { bool result = false; if (func != nullptr) { @@ -79,7 +79,7 @@ bool PrimToFunction::GetFunction(const PrimitivePtr& prim, FunctionPtr* const fu return result; } -int PrimToFunction::GetPrimType(const PrimitivePtr& prim) const { +int PrimToFunction::GetPrimType(const PrimitivePtr &prim) const { MS_EXCEPTION_IF_NULL(prim); int prim_type = static_cast(kPrimTypeUnknown); diff --git a/mindspore/ccsrc/operator/prim_to_function.h b/mindspore/ccsrc/operator/prim_to_function.h index 71518e4057..285ab8d3ab 100644 --- a/mindspore/ccsrc/operator/prim_to_function.h +++ b/mindspore/ccsrc/operator/prim_to_function.h @@ -41,21 +41,21 @@ class PrimToFunction; class PrimToFunction { public: // Return a thread-safe singleton instance - static PrimToFunction& GetInstance() { + static PrimToFunction &GetInstance() { static PrimToFunction instance; return instance; } - PrimToFunction(const PrimToFunction&) = delete; - PrimToFunction& operator=(const PrimToFunction&) = delete; + PrimToFunction(const PrimToFunction &) = delete; + PrimToFunction &operator=(const PrimToFunction &) = delete; ~PrimToFunction() = default; // Get the args and return value for a primitive instance. - bool GetFunction(const PrimitivePtr& prim, FunctionPtr* func) const; + bool GetFunction(const PrimitivePtr &prim, FunctionPtr *func) const; private: PrimToFunction(); // Get the number of primitive arguments - int GetPrimType(const PrimitivePtr& prim) const; + int GetPrimType(const PrimitivePtr &prim) const; const std::unordered_map prim_func_type_map_; }; } // namespace prim diff --git a/mindspore/ccsrc/optimizer/ad/adjoint.cc b/mindspore/ccsrc/optimizer/ad/adjoint.cc index 46746b3f44..ed89aba20e 100644 --- a/mindspore/ccsrc/optimizer/ad/adjoint.cc +++ b/mindspore/ccsrc/optimizer/ad/adjoint.cc @@ -24,7 +24,7 @@ namespace mindspore { namespace ad { -Adjoint::Adjoint(const AnfNodePtr& primal, const AnfNodePtr& k, const FuncGraphPtr& caller) +Adjoint::Adjoint(const AnfNodePtr &primal, const AnfNodePtr &k, const FuncGraphPtr &caller) : primal_(primal), caller_(caller), dout_(nullptr) { if (k != nullptr) { k_ = k; @@ -43,13 +43,13 @@ Adjoint::Adjoint(const AnfNodePtr& primal, const AnfNodePtr& k, const FuncGraphP AnfNodePtr Adjoint::k() { return k_; } -void Adjoint::RegisterKUser(const CNodePtr& user, size_t index) { k_user_.emplace_back(std::make_pair(user, index)); } +void Adjoint::RegisterKUser(const CNodePtr &user, size_t index) { k_user_.emplace_back(std::make_pair(user, index)); } -void Adjoint::UpdateK(const AnfNodePtr& new_k) { +void Adjoint::UpdateK(const AnfNodePtr &new_k) { MS_EXCEPTION_IF_NULL(new_k); MS_LOG(DEBUG) << "Replace k " << k_->ToString() << " with " << new_k->ToString(); // In recursive case, it needs update. - for (auto& user : k_user_) { + for (auto &user : k_user_) { MS_LOG(DEBUG) << "Update k user " << user.first->ToString() << " " << user.second << " input with new_k" << new_k->ToString(); if (user.first->input(user.second) != k_) { @@ -65,11 +65,11 @@ AnfNodePtr Adjoint::primal() { return primal_; } AnfNodePtr Adjoint::dout() { return dout_hole_; } -void Adjoint::RegisterDoutUser(const CNodePtr& user, size_t index) { +void Adjoint::RegisterDoutUser(const CNodePtr &user, size_t index) { dout_user_.emplace_back(std::make_pair(user, index)); } -void Adjoint::AccumulateDout(const AnfNodePtr& dout_factor) { +void Adjoint::AccumulateDout(const AnfNodePtr &dout_factor) { if (dout_ != nullptr) { MS_LOG(DEBUG) << "Update dout " << dout_->ToString() << " with dout_factor " << dout_factor->ToString(); auto add = prim::GetPythonOps("hyper_add"); @@ -81,7 +81,7 @@ void Adjoint::AccumulateDout(const AnfNodePtr& dout_factor) { void Adjoint::CallDoutHole() { if (dout_ != nullptr) { - for (auto& user : dout_user_) { + for (auto &user : dout_user_) { MS_LOG(DEBUG) << "Update dout user " << user.first->ToString() << " " << user.second << " input with dout " << dout_->ToString(); if (user.first->input(user.second) != dout_hole_) { diff --git a/mindspore/ccsrc/optimizer/ad/adjoint.h b/mindspore/ccsrc/optimizer/ad/adjoint.h index 673928129b..b2dae8e66f 100644 --- a/mindspore/ccsrc/optimizer/ad/adjoint.h +++ b/mindspore/ccsrc/optimizer/ad/adjoint.h @@ -28,15 +28,15 @@ namespace mindspore { namespace ad { class Adjoint { public: - Adjoint(const AnfNodePtr& primal, const AnfNodePtr& k, const FuncGraphPtr& caller); + Adjoint(const AnfNodePtr &primal, const AnfNodePtr &k, const FuncGraphPtr &caller); ~Adjoint() = default; AnfNodePtr primal(); AnfNodePtr k(); - void UpdateK(const AnfNodePtr& k); - void RegisterKUser(const CNodePtr& user, size_t index); + void UpdateK(const AnfNodePtr &k); + void RegisterKUser(const CNodePtr &user, size_t index); AnfNodePtr dout(); - void AccumulateDout(const AnfNodePtr& dout_factor); - void RegisterDoutUser(const CNodePtr& user, size_t index); + void AccumulateDout(const AnfNodePtr &dout_factor); + void RegisterDoutUser(const CNodePtr &user, size_t index); void CallDoutHole(); private: diff --git a/mindspore/ccsrc/optimizer/clean.cc b/mindspore/ccsrc/optimizer/clean.cc index 9e713d3425..fe11191546 100644 --- a/mindspore/ccsrc/optimizer/clean.cc +++ b/mindspore/ccsrc/optimizer/clean.cc @@ -36,7 +36,7 @@ using mindspore::abstract::AbstractList; using mindspore::abstract::AbstractScalar; using mindspore::abstract::AbstractTuple; -static AbstractBasePtr Reabs(const AbstractBasePtr& t) { +static AbstractBasePtr Reabs(const AbstractBasePtr &t) { if (t == nullptr) { return nullptr; } @@ -47,14 +47,14 @@ static AbstractBasePtr Reabs(const AbstractBasePtr& t) { AbstractBasePtrList baselist; auto attributes = abs_class->attributes(); (void)std::transform(attributes.begin(), attributes.end(), std::back_inserter(baselist), - [](const AbstractAttribute& item) { return item.second; }); + [](const AbstractAttribute &item) { return item.second; }); res = std::make_shared(baselist); } else if (t->isa()) { auto abs_dict = dyn_cast(t); AbstractBasePtrList baselist; auto elements = abs_dict->elements(); (void)std::transform(elements.begin(), elements.end(), std::back_inserter(baselist), - [](const AbstractAttribute& item) { return item.second; }); + [](const AbstractAttribute &item) { return item.second; }); res = std::make_shared(baselist); } else if (t->isa()) { auto abs_dict = dyn_cast(t); @@ -63,11 +63,11 @@ static AbstractBasePtr Reabs(const AbstractBasePtr& t) { return res; } -AnfNodePtr ConvertGetAttrToTupleGetItem(const CNodePtr& node) { +AnfNodePtr ConvertGetAttrToTupleGetItem(const CNodePtr &node) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node->func_graph()); - const auto& inputs = node->inputs(); + const auto &inputs = node->inputs(); // Inputs should be [getattr, data, attribute] MS_ASSERT(inputs.size() == 3 && "GetAttr should have three inputs."); @@ -86,9 +86,9 @@ AnfNodePtr ConvertGetAttrToTupleGetItem(const CNodePtr& node) { auto cons_str = cons_is_str ? GetValue(GetValueNode(cons)) : ""; auto ct = dyn_cast(dt); - const auto& cmap = ct->attributes(); + const auto &cmap = ct->attributes(); int count = 0; - for (auto& item : cmap) { + for (auto &item : cmap) { if (cons_is_str && item.first == cons_str) { break; } @@ -102,12 +102,12 @@ AnfNodePtr ConvertGetAttrToTupleGetItem(const CNodePtr& node) { return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, idx_c}); } -AnfNodePtr ConvertDictGetItemToTupleGetItem(const CNodePtr& node) { +AnfNodePtr ConvertDictGetItemToTupleGetItem(const CNodePtr &node) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node->func_graph()); // Inputs should be [dict_getitem, dict, item] - const auto& inputs = node->inputs(); + const auto &inputs = node->inputs(); MS_ASSERT(inputs.size() == 3 && "DictGetItem should have three inputs."); AnfNodePtr data = inputs[1]; @@ -124,9 +124,9 @@ AnfNodePtr ConvertDictGetItemToTupleGetItem(const CNodePtr& node) { auto cons_str = cons_is_str ? GetValue(GetValueNode(cons)) : ""; auto ct = dyn_cast(dt); - const auto& cmap = ct->elements(); + const auto &cmap = ct->elements(); int count = 0; - for (auto& item : cmap) { + for (auto &item : cmap) { if (cons_is_str && item.first == cons_str) { break; } @@ -139,7 +139,7 @@ AnfNodePtr ConvertDictGetItemToTupleGetItem(const CNodePtr& node) { return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, idx_c}); } -AnfNodePtr ConvertMakeRecordToMakeTuple(const CNodePtr& node) { +AnfNodePtr ConvertMakeRecordToMakeTuple(const CNodePtr &node) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node->func_graph()); @@ -150,11 +150,11 @@ AnfNodePtr ConvertMakeRecordToMakeTuple(const CNodePtr& node) { return node->func_graph()->NewCNode(inputs); } -AnfNodePtr ErasePartialNode(const CNodePtr& node) { +AnfNodePtr ErasePartialNode(const CNodePtr &node) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node->func_graph()); - const auto& inputs = node->inputs(); + const auto &inputs = node->inputs(); // Inputs should be [partial, fn, arg1, ...], so offset by 2 to get arg; MS_ASSERT(inputs.size() >= 2 && "Partial should have more than two inputs."); @@ -178,7 +178,7 @@ AnfNodePtr ErasePartialNode(const CNodePtr& node) { return nullptr; } -AnfNodePtr ConvertMakeListToMakeTuple(const CNodePtr& node) { +AnfNodePtr ConvertMakeListToMakeTuple(const CNodePtr &node) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node->func_graph()); @@ -189,11 +189,11 @@ AnfNodePtr ConvertMakeListToMakeTuple(const CNodePtr& node) { return node->func_graph()->NewCNode(inputs); } -AnfNodePtr ConvertListGetItemToTupleGetItem(const CNodePtr& node) { +AnfNodePtr ConvertListGetItemToTupleGetItem(const CNodePtr &node) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node->func_graph()); - const auto& inputs = node->inputs(); + const auto &inputs = node->inputs(); // Inputs should be [list_getitem, list, item] if (inputs.size() < 3) { MS_LOG(EXCEPTION) << "Node's input number < 3."; @@ -208,11 +208,11 @@ AnfNodePtr ConvertListGetItemToTupleGetItem(const CNodePtr& node) { return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleGetItem), data, cons_node}); } -AnfNodePtr ConvertListSetItemToTupleSetItem(const CNodePtr& node) { +AnfNodePtr ConvertListSetItemToTupleSetItem(const CNodePtr &node) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(node->func_graph()); - const auto& inputs = node->inputs(); + const auto &inputs = node->inputs(); // Inputs should be [list_setitem, list, index, item] if (inputs.size() < 4) { MS_LOG(EXCEPTION) << "Node's input number < 4."; @@ -225,36 +225,36 @@ AnfNodePtr ConvertListSetItemToTupleSetItem(const CNodePtr& node) { return node->func_graph()->NewCNode({NewValueNode(prim::kPrimTupleSetItem), data, cons, value}); } -AnfNodePtr EraseMakeDictNode(const CNodePtr& node) { +AnfNodePtr EraseMakeDictNode(const CNodePtr &node) { MS_EXCEPTION_IF_NULL(node); - const auto& inputs = node->inputs(); + const auto &inputs = node->inputs(); MS_ASSERT(inputs.size() >= 3 && "MakeDict should have three inputs"); return inputs[2]; } -AnfNodePtr EraseMakeKeywordArgNode(const CNodePtr& node) { +AnfNodePtr EraseMakeKeywordArgNode(const CNodePtr &node) { MS_EXCEPTION_IF_NULL(node); - const auto& inputs = node->inputs(); + const auto &inputs = node->inputs(); // Inputs should be [make_keyword_arg, key, value] MS_ASSERT(inputs.size() == 3 && "MakeKeyword should have three inputs"); return inputs[2]; } -AnfNodePtr EraseExtractKeywordArg(const CNodePtr& node) { +AnfNodePtr EraseExtractKeywordArg(const CNodePtr &node) { MS_EXCEPTION_IF_NULL(node); - const auto& inputs = node->inputs(); + const auto &inputs = node->inputs(); // Inputs should be [extract_keyword_arg, arg, key] MS_ASSERT(inputs.size() == 3 && "ExtractKeyword should have three inputs"); return inputs[2]; } -ValueTuplePtr ConvertValueListToValueTuple(const ValueListPtr& value_list, int depth) { +ValueTuplePtr ConvertValueListToValueTuple(const ValueListPtr &value_list, int depth) { const int DEPTH_MAX = 5; if (depth > DEPTH_MAX) { MS_LOG(EXCEPTION) << "List nesting is not allowed more than 5 levels."; } std::vector elements; - for (const auto& it : value_list->value()) { + for (const auto &it : value_list->value()) { ValuePtr value = nullptr; if (it->isa()) { value = ConvertValueListToValueTuple(it->cast(), depth + 1); @@ -266,7 +266,7 @@ ValueTuplePtr ConvertValueListToValueTuple(const ValueListPtr& value_list, int d return std::make_shared(elements); } -AnfNodePtr ConvertValueListNodeToValueTupleNode(const ValueNodePtr& node) { +AnfNodePtr ConvertValueListNodeToValueTupleNode(const ValueNodePtr &node) { MS_EXCEPTION_IF_NULL(node); ValuePtr value = node->value(); auto value_list = value->cast(); @@ -278,13 +278,13 @@ AnfNodePtr ConvertValueListNodeToValueTupleNode(const ValueNodePtr& node) { // Convert class to Tuple // Convert getattr to getitem // Convert make_record to make_tuple -void SimplifyDataStructures(const FuncGraphPtr& root, const FuncGraphManagerPtr& manager) { +void SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) { MS_EXCEPTION_IF_NULL(manager); manager->AddFuncGraph(root); // Since `manager->Replace(...);` will modify member `all_nodes_`, so `all_node` can't be a ref var AnfNodeSet all_node = manager->all_nodes(); - for (auto& node : all_node) { + for (auto &node : all_node) { MS_EXCEPTION_IF_NULL(node); auto cnode = node->cast(); AnfNodePtr new_node = nullptr; @@ -320,20 +320,20 @@ void SimplifyDataStructures(const FuncGraphPtr& root, const FuncGraphManagerPtr& } } - for (auto& node : manager->all_nodes()) { + for (auto &node : manager->all_nodes()) { auto ret = Reabs(node->abstract()); node->set_abstract(ret); } } // expand tuples in graph parameters -static std::vector ExpandTuplesP(const FuncGraphManagerPtr& mng, const FuncGraphPtr& func_graph, - const std::vector& params) { +static std::vector ExpandTuplesP(const FuncGraphManagerPtr &mng, const FuncGraphPtr &func_graph, + const std::vector ¶ms) { MS_EXCEPTION_IF_NULL(mng); MS_EXCEPTION_IF_NULL(func_graph); std::vector new_params; - for (const auto& param : params) { + for (const auto ¶m : params) { MS_EXCEPTION_IF_NULL(param); auto param_abs = param->abstract(); MS_EXCEPTION_IF_NULL(param_abs); @@ -350,7 +350,7 @@ static std::vector ExpandTuplesP(const FuncGraphManagerPtr& mng, con std::vector new_param; std::vector inputs{NewValueNode(prim::kPrimMakeTuple)}; auto abs_tuple = dyn_cast(param_abs); - for (auto& elem : abs_tuple->elements()) { + for (auto &elem : abs_tuple->elements()) { auto np = std::make_shared(func_graph); np->set_abstract(elem); new_param.emplace_back(np); @@ -366,11 +366,11 @@ static std::vector ExpandTuplesP(const FuncGraphManagerPtr& mng, con } // expand tuples in graph applies -static std::vector ExpandTuplesC(const FuncGraphPtr& graph, const std::vector& inputs) { +static std::vector ExpandTuplesC(const FuncGraphPtr &graph, const std::vector &inputs) { MS_EXCEPTION_IF_NULL(graph); std::vector new_inputs; - for (const auto& input : inputs) { + for (const auto &input : inputs) { MS_EXCEPTION_IF_NULL(input); auto input_abs = input->abstract(); @@ -391,7 +391,7 @@ static std::vector ExpandTuplesC(const FuncGraphPtr& graph, const st int idx = 0; std::vector new_input; auto abs_tuple = dyn_cast(input_abs); - for (auto& elem : abs_tuple->elements()) { + for (auto &elem : abs_tuple->elements()) { auto c_node = graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), input, NewValueNode(idx)}); AbstractBasePtr aptr = std::make_shared(std::make_shared(idx)); c_node->input(2)->set_abstract(aptr); @@ -416,19 +416,19 @@ static std::vector ExpandTuplesC(const FuncGraphPtr& graph, const st // tuples in Graph's parameters: AbstractTuple (a, b, c) --> // CNode("make_tuple", Parameter(a), Parameter(b), Parameter(c)) // cppcheck-suppress unusedFunction -void EraseTuple(const FuncGraphPtr& root, const FuncGraphManagerPtr& manager) { +void EraseTuple(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) { MS_EXCEPTION_IF_NULL(manager); manager->AddFuncGraph(root); // NOTICE: since `manager->Replace(...);` will modify member `all_nodes_`, so `all_node` can't be a ref var AnfNodeSet all_node = manager->all_nodes(); - for (auto& node : all_node) { + for (auto &node : all_node) { auto cnode = node->cast(); if (cnode == nullptr) { continue; } - const auto& inputs = cnode->inputs(); + const auto &inputs = cnode->inputs(); // Bypass the first input in inputs as it's fn. if (!IsValueNode(inputs[0])) { @@ -466,7 +466,7 @@ void EraseTuple(const FuncGraphPtr& root, const FuncGraphManagerPtr& manager) { } FuncGraphSet all_graph = manager->func_graphs(); - for (auto& func_graph : all_graph) { + for (auto &func_graph : all_graph) { MS_EXCEPTION_IF_NULL(func_graph); auto expand_p = ExpandTuplesP(manager, func_graph, func_graph->parameters()); manager->SetParameters(func_graph, expand_p); diff --git a/mindspore/ccsrc/optimizer/control_depend.h b/mindspore/ccsrc/optimizer/control_depend.h index 2a51a24718..076e2c0229 100644 --- a/mindspore/ccsrc/optimizer/control_depend.h +++ b/mindspore/ccsrc/optimizer/control_depend.h @@ -22,7 +22,7 @@ namespace mindspore { namespace opt { // Automatically adding control depend based on effect order and side effect analysis. -void AddControlDepend(const FuncGraphPtr& graph); +void AddControlDepend(const FuncGraphPtr &graph); } // namespace opt } // namespace mindspore #endif // MINDSPORE_CCSRC_OPTIMIZER_CONTROL_DEPEND_H_ diff --git a/mindspore/ccsrc/optimizer/irpass/grad_var_prepare.cc b/mindspore/ccsrc/optimizer/irpass/grad_var_prepare.cc index 5daeced3a5..32a42bc16b 100644 --- a/mindspore/ccsrc/optimizer/irpass/grad_var_prepare.cc +++ b/mindspore/ccsrc/optimizer/irpass/grad_var_prepare.cc @@ -44,7 +44,7 @@ static AnfNodePtr GenerateUnpackGraphNode(std::vector inputs_y, Func nodes.push_back(func_node); // {unpackcall, {GradOperation, ...}, args...} std::transform(inputs_y.begin() + 2, inputs_y.end(), std::back_inserter(nodes), - [](const AnfNodePtr& node) { return node; }); + [](const AnfNodePtr &node) { return node; }); unpack_graph_node = func_graph->NewCNode(nodes); } else { auto unpack_graph = std::make_shared("unpack_graph", sens_param, false); @@ -52,14 +52,14 @@ static AnfNodePtr GenerateUnpackGraphNode(std::vector inputs_y, Func nodes.push_back(func_node); // {{GradOperation, ...}, args...} std::transform(inputs_y.begin() + 1, inputs_y.end(), std::back_inserter(nodes), - [](const AnfNodePtr& node) { return node; }); + [](const AnfNodePtr &node) { return node; }); unpack_graph_node = func_graph->NewCNode(nodes); } return unpack_graph_node; } // get metagraph of value node -MetaFuncGraphPtr GetMetaFuncGraphOfValueNode(const AnfNodePtr& node) { +MetaFuncGraphPtr GetMetaFuncGraphOfValueNode(const AnfNodePtr &node) { ValuePtr value; if (IsValueNode(node)) { value = GetValueNode(node)->cast()->function(); @@ -73,7 +73,7 @@ MetaFuncGraphPtr GetMetaFuncGraphOfValueNode(const AnfNodePtr& node) { } // check if node is a specific metafuncgraph op -bool IsMetaFuncGraph(const AnfNodePtr& node, const MetaFuncGraphPtr meta_func_graph) { +bool IsMetaFuncGraph(const AnfNodePtr &node, const MetaFuncGraphPtr meta_func_graph) { if (node != nullptr) { auto meta_func_graph_ptr = GetMetaFuncGraphOfValueNode(node); if (meta_func_graph_ptr == nullptr) { @@ -89,7 +89,7 @@ bool IsMetaFuncGraph(const AnfNodePtr& node, const MetaFuncGraphPtr meta_func_gr // {{GradOperation, g, w}, Ys} // {UnPackCall, {GradOperation, g, w}, Ys} -AnfNodePtr GradVarPrepare::operator()(const OptimizerPtr&, const AnfNodePtr& node) { +AnfNodePtr GradVarPrepare::operator()(const OptimizerPtr &, const AnfNodePtr &node) { if (!node->isa() || node->func_graph() == nullptr) { return nullptr; } diff --git a/mindspore/ccsrc/optimizer/opt.cc b/mindspore/ccsrc/optimizer/opt.cc index 24339ddb84..0dbaf1107f 100644 --- a/mindspore/ccsrc/optimizer/opt.cc +++ b/mindspore/ccsrc/optimizer/opt.cc @@ -31,20 +31,20 @@ namespace mindspore { /* namespace to support opt */ namespace opt { -SubstitutionPtr MakeSubstitution(const TransformFuncType& transform, const std::string& name, const PrimitivePtr& prim, - const RenormAction& renorm_action) { - auto fn = [prim](const AnfNodePtr& node) -> bool { return IsPrimitiveCNode(node, prim); }; +SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, const PrimitivePtr &prim, + const RenormAction &renorm_action) { + auto fn = [prim](const AnfNodePtr &node) -> bool { return IsPrimitiveCNode(node, prim); }; return std::make_shared(transform, name, fn, renorm_action); } -SubstitutionPtr MakeSubstitution(const TransformFuncType& transform, const std::string& name, - const std::vector& prims, const RenormAction& renorm_action) { - auto fn = [prims](const AnfNodePtr& node) -> bool { +SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, + const std::vector &prims, const RenormAction &renorm_action) { + auto fn = [prims](const AnfNodePtr &node) -> bool { if (!node->isa()) { return false; } - for (auto& prim : prims) { + for (auto &prim : prims) { if (IsPrimitiveCNode(node, prim)) { return true; } @@ -55,12 +55,12 @@ SubstitutionPtr MakeSubstitution(const TransformFuncType& transform, const std:: return std::make_shared(transform, name, fn, renorm_action); } -SubstitutionPtr MakeSubstitution(const TransformFuncType& transform, const std::string& name, - const PredicateFuncType& predicate, const RenormAction& renorm_action) { +SubstitutionPtr MakeSubstitution(const TransformFuncType &transform, const std::string &name, + const PredicateFuncType &predicate, const RenormAction &renorm_action) { return std::make_shared(transform, name, predicate, renorm_action); } -AnfNodePtr Substitution::operator()(const OptimizerPtr& optimizer, const AnfNodePtr& node) const { +AnfNodePtr Substitution::operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) const { #ifdef ENABLE_PROFILE double t = GetTime(); #endif @@ -88,8 +88,8 @@ AnfNodePtr Substitution::operator()(const OptimizerPtr& optimizer, const AnfNode return result; } -bool SubstitutionList::ApplyTransform(const OptimizerPtr& optimizer, const AnfNodePtr& root_node, - const SubstitutionPtr& transform) const { +bool SubstitutionList::ApplyTransform(const OptimizerPtr &optimizer, const AnfNodePtr &root_node, + const SubstitutionPtr &transform) const { FuncGraphManagerPtr manager = optimizer->manager(); std::unordered_set seen_node; std::deque todo{root_node}; @@ -131,13 +131,13 @@ bool SubstitutionList::ApplyTransform(const OptimizerPtr& optimizer, const AnfNo } if (node->isa()) { - auto& inputs = node->cast()->inputs(); + auto &inputs = node->cast()->inputs(); (void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(todo)); } - auto& node_users = manager->node_users(); + auto &node_users = manager->node_users(); if (change && node_users.find(node) != node_users.end()) { - for (auto& use : node_users[node]) { + for (auto &use : node_users[node]) { auto use_node = use.first; todo.push_back(use_node); if (seen_node.find(use_node) != seen_node.end()) { @@ -152,7 +152,7 @@ bool SubstitutionList::ApplyTransform(const OptimizerPtr& optimizer, const AnfNo return changes; } -bool SubstitutionList::operator()(const FuncGraphPtr& func_graph, const OptimizerPtr& optimizer) const { +bool SubstitutionList::operator()(const FuncGraphPtr &func_graph, const OptimizerPtr &optimizer) const { MS_EXCEPTION_IF_NULL(optimizer); MS_EXCEPTION_IF_NULL(func_graph); FuncGraphManagerPtr manager = optimizer->manager(); @@ -163,7 +163,7 @@ bool SubstitutionList::operator()(const FuncGraphPtr& func_graph, const Optimize do { loop = false; - for (auto const& transform : list_) { + for (auto const &transform : list_) { auto change = ApplyTransform(optimizer, func_graph->output(), transform); changes = changes || change; loop = loop || change; diff --git a/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_fusion.cc b/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_fusion.cc index 03f7d054e0..b4f4cb5b22 100644 --- a/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_fusion.cc +++ b/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_fusion.cc @@ -28,7 +28,7 @@ namespace mindspore { namespace parallel { -std::unordered_set FindCNodesWithPara(const AnfNodePtr& para, uint32_t recursive_times = 0) { +std::unordered_set FindCNodesWithPara(const AnfNodePtr ¶, uint32_t recursive_times = 0) { if (recursive_times > MAX_RECURSIVE_CALL_TIMES) { MS_LOG(EXCEPTION) << "FindCNodesWithPara exceeds max recursive call times! Max recursive call times is " << MAX_RECURSIVE_CALL_TIMES; @@ -39,7 +39,7 @@ std::unordered_set FindCNodesWithPara(const AnfNodePtr& para, uint32_t MS_EXCEPTION_IF_NULL(manager); auto node_set = manager->node_users()[para]; std::unordered_set cnode_set; - for (auto& node_pair : node_set) { + for (auto &node_pair : node_set) { auto cnode = node_pair.first->cast(); MS_EXCEPTION_IF_NULL(cnode); if (!IsValueNode(cnode->input(0))) { @@ -54,7 +54,7 @@ std::unordered_set FindCNodesWithPara(const AnfNodePtr& para, uint32_t (void)cnode_set.emplace(cnode); } else { auto cnode_set_sub = FindCNodesWithPara(node_pair.first, recursive_times + 1); - for (auto& cnode_sub : cnode_set_sub) { + for (auto &cnode_sub : cnode_set_sub) { (void)cnode_set.emplace(cnode_sub); } } @@ -63,8 +63,8 @@ std::unordered_set FindCNodesWithPara(const AnfNodePtr& para, uint32_t } Status AllreduceFusion::AddNodeToGraph() { - const auto& parameters = root_graph_->parameters(); - for (auto& parameter : parameters) { + const auto ¶meters = root_graph_->parameters(); + for (auto ¶meter : parameters) { if (!ParameterRequireGrad(parameter)) { continue; } @@ -72,7 +72,7 @@ Status AllreduceFusion::AddNodeToGraph() { if (cnode_set.empty()) { continue; } - for (auto& cnode : cnode_set) { + for (auto &cnode : cnode_set) { MS_LOG(DEBUG) << "AddNode " << cnode->DebugString(); if (allreduce_graph_.AddNode(cnode, parameter) != SUCCESS) { MS_LOG(ERROR) << "AddNode failed! cnode: " << cnode->DebugString(); @@ -83,7 +83,7 @@ Status AllreduceFusion::AddNodeToGraph() { return SUCCESS; } -CNodeCostMap AllreduceFusion::FindCNode(const AnfNodePtr& from, uint32_t recursive_times) const { +CNodeCostMap AllreduceFusion::FindCNode(const AnfNodePtr &from, uint32_t recursive_times) const { if (recursive_times > MAX_RECURSIVE_CALL_TIMES) { MS_LOG(EXCEPTION) << "FindCNode exceeds max recursive call times! Max recursive call times is " << MAX_RECURSIVE_CALL_TIMES; @@ -110,30 +110,30 @@ CNodeCostMap AllreduceFusion::FindCNode(const AnfNodePtr& from, uint32_t recursi return cnode_dist; } else { auto cnode_dist_next = FindNextCNodes(cnode, recursive_times + 1); - for (auto& ele : cnode_dist_next) { + for (auto &ele : cnode_dist_next) { cnode_dist[ele.first] = cost + ele.second; } } } else { auto cnode_dist_next = FindNextCNodes(cnode); - for (auto& ele : cnode_dist_next) { + for (auto &ele : cnode_dist_next) { cnode_dist[ele.first] = ele.second; } } return cnode_dist; } -CNodeCostMap AllreduceFusion::FindNextCNodes(const CNodePtr& from, uint32_t recursive_times) const { +CNodeCostMap AllreduceFusion::FindNextCNodes(const CNodePtr &from, uint32_t recursive_times) const { if (recursive_times > MAX_RECURSIVE_CALL_TIMES) { MS_LOG(EXCEPTION) << "FindNextCNodes exceeds max recursive call times! Max recursive call times is " << MAX_RECURSIVE_CALL_TIMES; } - const auto& from_inputs = from->inputs(); + const auto &from_inputs = from->inputs(); std::unordered_map dist_map; MS_LOG(DEBUG) << "from cnode " << from->DebugString() << " has " << from_inputs.size() << " inputs"; - for (auto& input_node : from_inputs) { + for (auto &input_node : from_inputs) { auto cnode_dist = FindCNode(input_node, recursive_times + 1); - for (auto& ele : cnode_dist) { + for (auto &ele : cnode_dist) { (void)dist_map.emplace(ele); } } @@ -142,11 +142,11 @@ CNodeCostMap AllreduceFusion::FindNextCNodes(const CNodePtr& from, uint32_t recu Status AllreduceFusion::AddEdgeToGraph() { std::unordered_map cnode_state_map; - const auto& cnodes = allreduce_graph_.cnode_set(); - for (auto& cnode : cnodes) { + const auto &cnodes = allreduce_graph_.cnode_set(); + for (auto &cnode : cnodes) { cnode_state_map[cnode] = 0; } - const auto& head_cnode = allreduce_graph_.head_cnode(); + const auto &head_cnode = allreduce_graph_.head_cnode(); std::queue cnode_queue; cnode_queue.emplace(head_cnode); cnode_state_map[head_cnode] = 1; @@ -156,9 +156,9 @@ Status AllreduceFusion::AddEdgeToGraph() { cnode_queue.pop(); cnode_state_map[cur_cnode] = 2; auto next = FindNextCNodes(cur_cnode); - for (auto& ele : next) { - auto& cnode = ele.first; - auto& dist = ele.second; + for (auto &ele : next) { + auto &cnode = ele.first; + auto &dist = ele.second; if (cnode_state_map[cnode] == 0) { cnode_queue.emplace(cnode); cnode_state_map[cnode] = 1; @@ -173,7 +173,7 @@ Status AllreduceFusion::AddEdgeToGraph() { return SUCCESS; } -std::vector FindMirror(const AnfNodePtr& para, uint32_t recursive_times = 0) { +std::vector FindMirror(const AnfNodePtr ¶, uint32_t recursive_times = 0) { if (recursive_times > MAX_RECURSIVE_CALL_TIMES) { MS_LOG(EXCEPTION) << "FindMirror exceeds max recursive call times! Max recursive call times is " << MAX_RECURSIVE_CALL_TIMES; @@ -184,7 +184,7 @@ std::vector FindMirror(const AnfNodePtr& para, uint32_t recursive_time MS_EXCEPTION_IF_NULL(manager); AnfNodeIndexSet node_set = manager->node_users()[para]; std::vector cnode_list; - for (auto& node_pair : node_set) { + for (auto &node_pair : node_set) { auto cnode = node_pair.first->cast(); MS_EXCEPTION_IF_NULL(cnode); if (!IsValueNode(cnode->input(0))) { @@ -210,7 +210,7 @@ std::vector FindMirror(const AnfNodePtr& para, uint32_t recursive_time return cnode_list; } -void SetMirrorFusion(const CNodePtr& mirror_cnode, int32_t fusion, const std::string& parameter_name) { +void SetMirrorFusion(const CNodePtr &mirror_cnode, int32_t fusion, const std::string ¶meter_name) { MS_EXCEPTION_IF_NULL(mirror_cnode); MS_LOG(DEBUG) << "Set Mirror " << mirror_cnode->DebugString() << " fusion " << fusion; auto node_prim = GetValueNode(mirror_cnode->input(0)); @@ -227,14 +227,14 @@ void SetMirrorFusion(const CNodePtr& mirror_cnode, int32_t fusion, const std::st (void)node_prim->AddAttr(PARAMETER, MakeValue(std::make_shared(parameter_name))); } -Status FindMirrorAndSetFusion(const AnfNodePtr& para, int32_t fusion) { +Status FindMirrorAndSetFusion(const AnfNodePtr ¶, int32_t fusion) { auto mirror_cnodes = FindMirror(para); if (mirror_cnodes.empty()) { MS_LOG(WARNING) << para->ToString() << " 0 Mirror CNode found."; return SUCCESS; } if (mirror_cnodes.size() > 2) { - for (auto& mirror_cnode : mirror_cnodes) { + for (auto &mirror_cnode : mirror_cnodes) { MS_EXCEPTION_IF_NULL(mirror_cnode); MS_LOG(INFO) << mirror_cnode->DebugString(); } @@ -243,15 +243,15 @@ Status FindMirrorAndSetFusion(const AnfNodePtr& para, int32_t fusion) { << "Mirror CNode found."; return FAILED; } - for (auto& mirror_cnode : mirror_cnodes) { + for (auto &mirror_cnode : mirror_cnodes) { auto parameter_name = ParameterName(para); SetMirrorFusion(mirror_cnode, fusion, parameter_name); } return SUCCESS; } -Status FindMirrorAndSetFusion(const std::vector& paras, int32_t fusion) { - for (auto& param_node : paras) { +Status FindMirrorAndSetFusion(const std::vector ¶s, int32_t fusion) { + for (auto ¶m_node : paras) { if (FindMirrorAndSetFusion(param_node, fusion) != SUCCESS) { MS_LOG(ERROR) << "FindMirrorAndSetFusion failed"; return FAILED; @@ -260,7 +260,7 @@ Status FindMirrorAndSetFusion(const std::vector& paras, int32_t fusi return SUCCESS; } -Status AllreduceFusion::SetFusion(const std::vector& cost_map) { +Status AllreduceFusion::SetFusion(const std::vector &cost_map) { if (cost_map.size() < 2) { MS_LOG(ERROR) << "cost_map must has at least 2 items, cost_map size is " << cost_map.size(); return FAILED; @@ -386,7 +386,7 @@ Status AllreduceFusion::SetFusionByAlgorithm(int32_t algorithm) { return SetFusionByBackwardCompAndAllreduceTime(); } -Status AllreduceFusion::ProcessAllreduceFusion(const CNodePtr& ret) { +Status AllreduceFusion::ProcessAllreduceFusion(const CNodePtr &ret) { if (ret == nullptr) { MS_LOG(ERROR) << "ret is nullptr."; return FAILED; diff --git a/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_fusion.h b/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_fusion.h index 67dc55836a..43a9935095 100644 --- a/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_fusion.h +++ b/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_fusion.h @@ -50,15 +50,15 @@ class AllreduceFusion { allreduce_bandwidth_(0), computation_time_parameter_(0) {} virtual ~AllreduceFusion() = default; - Status ProcessAllreduceFusion(const CNodePtr& ret); + Status ProcessAllreduceFusion(const CNodePtr &ret); private: Status AddNodeToGraph(); - CNodeCostMap FindCNode(const AnfNodePtr& from, uint32_t recursive_times = 0) const; - CNodeCostMap FindNextCNodes(const CNodePtr& from, uint32_t recursive_times = 0) const; + CNodeCostMap FindCNode(const AnfNodePtr &from, uint32_t recursive_times = 0) const; + CNodeCostMap FindNextCNodes(const CNodePtr &from, uint32_t recursive_times = 0) const; Status AddEdgeToGraph(); std::vector GenerateCostMap(int32_t fusion_times, double tail_percent) const; - Status SetFusion(const std::vector& cost_map); + Status SetFusion(const std::vector &cost_map); Status SetFusionByAlgorithm(int32_t algorithm); Status SetFusionByBackwardCompTime(); Status SetFusionByBackwardCompAndAllreduceTime(); diff --git a/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_graph.cc b/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_graph.cc index 9e04593c83..2a98a38add 100644 --- a/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_graph.cc +++ b/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_graph.cc @@ -23,7 +23,7 @@ namespace mindspore { namespace parallel { -Status AllreduceGraph::AddNode(const CNodePtr& node, const AnfNodePtr& para) { +Status AllreduceGraph::AddNode(const CNodePtr &node, const AnfNodePtr ¶) { AllreduceNodePtr arnode; auto cnode_emplace_return = cnode_set_.emplace(node); if (!cnode_emplace_return.second) { @@ -64,7 +64,7 @@ Status AllreduceGraph::AddNode(const CNodePtr& node, const AnfNodePtr& para) { return SUCCESS; } -Status AllreduceGraph::AddEdge(const CNodePtr& from, const CNodePtr& to, double dist) { +Status AllreduceGraph::AddEdge(const CNodePtr &from, const CNodePtr &to, double dist) { auto from_arnode_iter = cnode_arnode_map_.find(from); if (from_arnode_iter == cnode_arnode_map_.end()) { MS_LOG(ERROR) << "cnode from: " << from->DebugString() << "has not been added"; @@ -94,14 +94,14 @@ Status AllreduceGraph::AddEdge(const CNodePtr& from, const CNodePtr& to, double return SUCCESS; } -bool AllreduceGraph::NodeInGraph(const CNodePtr& node) const { +bool AllreduceGraph::NodeInGraph(const CNodePtr &node) const { auto cnode_iter = cnode_set_.find(node); return !(cnode_iter == cnode_set_.end()); } std::vector AllreduceGraph::GetParaByCost(double from, double to) { std::vector nodes; - for (auto& cnode_arnode : cnode_arnode_map_) { + for (auto &cnode_arnode : cnode_arnode_map_) { MS_LOG(DEBUG) << "cnode: " << cnode_arnode.first->DebugString() << ", depend_feat_size: " << cnode_arnode.second->depend_feat_size() << " curr_para_size: " << cnode_arnode.second->curr_para_size(); @@ -117,7 +117,7 @@ std::pair, double> AllreduceGraph::GetParaByParaSize(dou std::vector nodes; double cur_para_size = 0; double from = to; - for (auto& arnode : arnode_vec_) { + for (auto &arnode : arnode_vec_) { if (arnode.depend_feat_size() != max_ && arnode.depend_feat_size() >= to) { continue; } @@ -135,14 +135,14 @@ std::pair, double> AllreduceGraph::GetParaByParaSize(dou void AllreduceGraph::PrintCNodeSet() const { MS_LOG(INFO) << "CNodeSet:"; - for (auto& cnode : cnode_set_) { + for (auto &cnode : cnode_set_) { MS_LOG(INFO) << cnode->DebugString(); } } void AllreduceGraph::PrintAllredueGraphInfo() const { MS_LOG(INFO) << "max: " << max_; - for (auto& cnode_arnode : cnode_arnode_map_) { + for (auto &cnode_arnode : cnode_arnode_map_) { MS_LOG(INFO) << "cnode: " << cnode_arnode.first->DebugString(); MS_LOG(INFO) << "arnode info: "; cnode_arnode.second->ToString(); @@ -151,21 +151,21 @@ void AllreduceGraph::PrintAllredueGraphInfo() const { void AllreduceGraph::PrintArnodeVec() const { MS_LOG(INFO) << "ArnodeVec:"; - for (auto& arnode : arnode_vec_) { + for (auto &arnode : arnode_vec_) { arnode.ToString(); } } void AllreduceGraph::PrintArnodeSet() const { MS_LOG(INFO) << "ArnodeSet:"; - for (auto& arnode : arnode_set_) { + for (auto &arnode : arnode_set_) { arnode->ToString(); } } void AllreduceGraph::SortArnode() { arnode_vec_.clear(); - for (auto& node : arnode_set_) { + for (auto &node : arnode_set_) { arnode_vec_.emplace_back(*node); } std::sort(arnode_vec_.begin(), arnode_vec_.end(), std::greater<>()); @@ -173,8 +173,8 @@ void AllreduceGraph::SortArnode() { Status AllreduceGraph::RemoveExtraParas() { std::unordered_set para_map; - for (auto& node : arnode_vec_) { - for (auto& para : node.paras()) { + for (auto &node : arnode_vec_) { + for (auto ¶ : node.paras()) { auto emplac_result = para_map.emplace(para); if (!emplac_result.second) { MS_LOG(DEBUG) << "parameter: " << para->fullname_with_scope() << "in arnode"; @@ -188,7 +188,7 @@ Status AllreduceGraph::RemoveExtraParas() { return SUCCESS; } -Status AllreduceGraph::set_head_cnode(const CNodePtr& node) { +Status AllreduceGraph::set_head_cnode(const CNodePtr &node) { auto arnode = std::make_shared(AllreduceNode()); if (arnode->Init(node) != SUCCESS) { MS_LOG(ERROR) << "AllreduceNode Init failed"; diff --git a/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_graph.h b/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_graph.h index f0db78a130..b2084b735c 100644 --- a/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_graph.h +++ b/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_graph.h @@ -42,9 +42,9 @@ class AllreduceGraph { cnode_arnode_map_(), max_(0) {} virtual ~AllreduceGraph() = default; - Status AddNode(const CNodePtr& node, const AnfNodePtr& para); - Status AddEdge(const CNodePtr& from, const CNodePtr& to, double dist); - bool NodeInGraph(const CNodePtr& node) const; + Status AddNode(const CNodePtr &node, const AnfNodePtr ¶); + Status AddEdge(const CNodePtr &from, const CNodePtr &to, double dist); + bool NodeInGraph(const CNodePtr &node) const; std::vector GetParaByCost(double from, double to); // Find the first several AllreduceNode whose depend_feat_size is less than to, the sum of whose parameter size is // over para_size. @@ -60,9 +60,9 @@ class AllreduceGraph { void PrintAllredueGraphInfo() const; void PrintArnodeVec() const; void PrintArnodeSet() const; - const std::unordered_set& cnode_set() const { return cnode_set_; } + const std::unordered_set &cnode_set() const { return cnode_set_; } CNodePtr head_cnode() const { return head_cnode_; } - Status set_head_cnode(const CNodePtr& node); + Status set_head_cnode(const CNodePtr &node); double max() const { return max_; } private: diff --git a/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_node.cc b/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_node.cc index 6be588928a..113d4ec59b 100644 --- a/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_node.cc +++ b/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_node.cc @@ -21,7 +21,7 @@ namespace mindspore { namespace parallel { -Status AllreduceNode::AddNext(const AllreduceNodePtr& next_node) { +Status AllreduceNode::AddNext(const AllreduceNodePtr &next_node) { if (next_node == nullptr) { MS_LOG(ERROR) << "next_node is nullptr!"; return FAILED; @@ -30,7 +30,7 @@ Status AllreduceNode::AddNext(const AllreduceNodePtr& next_node) { return SUCCESS; } -Status AllreduceNode::AddPrev(const AllreduceNodePtr& prev_node, double dist, double* max) { +Status AllreduceNode::AddPrev(const AllreduceNodePtr &prev_node, double dist, double *max) { if (prev_node == nullptr) { MS_LOG(ERROR) << "next_node is nullptr!"; return FAILED; @@ -46,7 +46,7 @@ Status AllreduceNode::AddPrev(const AllreduceNodePtr& prev_node, double dist, do *max = depend_feat_size_; } std::queue next_queue; - for (auto& next : next_) { + for (auto &next : next_) { next_queue.push(next); } while (!next_queue.empty()) { @@ -55,7 +55,7 @@ Status AllreduceNode::AddPrev(const AllreduceNodePtr& prev_node, double dist, do if (ele->depend_feat_size() > *max) { *max = ele->depend_feat_size(); } - for (auto& next : ele->next()) { + for (auto &next : ele->next()) { next_queue.push(next); } next_queue.pop(); @@ -63,7 +63,7 @@ Status AllreduceNode::AddPrev(const AllreduceNodePtr& prev_node, double dist, do return SUCCESS; } -Status AllreduceNode::Init(const CNodePtr& cnode_ptr) { +Status AllreduceNode::Init(const CNodePtr &cnode_ptr) { if (cnode_ptr == nullptr) { MS_LOG(ERROR) << "cnode_ptr is nullptr!"; return FAILED; @@ -72,7 +72,7 @@ Status AllreduceNode::Init(const CNodePtr& cnode_ptr) { return SUCCESS; } -Status AllreduceNode::AddPara(const AnfNodePtr& node_ptr) { +Status AllreduceNode::AddPara(const AnfNodePtr &node_ptr) { if (node_ptr == nullptr) { MS_LOG(ERROR) << "node_ptr is nullptr!"; return FAILED; @@ -99,7 +99,7 @@ Status AllreduceNode::AddPara(const AnfNodePtr& node_ptr) { return SUCCESS; } -Status AllreduceNode::RemovePara(const AnfNodePtr& node_ptr) { +Status AllreduceNode::RemovePara(const AnfNodePtr &node_ptr) { if (node_ptr == nullptr) { MS_LOG(ERROR) << "node_ptr is nullptr!"; return FAILED; @@ -115,7 +115,7 @@ Status AllreduceNode::RemovePara(const AnfNodePtr& node_ptr) { void AllreduceNode::ToString() const { MS_LOG(INFO) << "cnode: " << cnode_ptr_->DebugString() << "para size: " << paras_.size(); - for (auto& para : paras_) { + for (auto ¶ : paras_) { MS_LOG(INFO) << "para name: " << para->fullname_with_scope() << " size: " << para_size_map_.at(para); } MS_LOG(INFO) << "depend_feat_size: " << depend_feat_size_ << " curr_para_size: " << curr_para_size_; diff --git a/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_node.h b/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_node.h index d9ba98c3a2..db1c4e3f2e 100644 --- a/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_node.h +++ b/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_node.h @@ -33,23 +33,23 @@ class AllreduceNode { public: AllreduceNode() : cnode_ptr_(nullptr), prev_(), next_(), paras_(), para_size_map_(), curr_para_size_(0), depend_feat_size_(0) {} - Status Init(const CNodePtr& cnode_ptr); - Status AddPara(const AnfNodePtr& node_ptr); - Status RemovePara(const AnfNodePtr& node_ptr); - const std::unordered_set& paras() const { return paras_; } + Status Init(const CNodePtr &cnode_ptr); + Status AddPara(const AnfNodePtr &node_ptr); + Status RemovePara(const AnfNodePtr &node_ptr); + const std::unordered_set ¶s() const { return paras_; } double curr_para_size() const { return curr_para_size_; } virtual ~AllreduceNode() = default; // Add previous node // prev_node is the previous to be added // max is the current max depend_feat_size of the AllreduceGraph - Status AddPrev(const AllreduceNodePtr& prev_node, double dist, double* max); - Status AddNext(const AllreduceNodePtr& next_node); + Status AddPrev(const AllreduceNodePtr &prev_node, double dist, double *max); + Status AddNext(const AllreduceNodePtr &next_node); double depend_feat_size() const { return depend_feat_size_; } void AddDependFeatSize(double add_dist) { depend_feat_size_ += add_dist; } - const std::vector& next() const { return next_; } + const std::vector &next() const { return next_; } void ToString() const; - bool operator<(const AllreduceNode& node) const { return depend_feat_size_ < node.depend_feat_size(); } - bool operator>(const AllreduceNode& node) const { return depend_feat_size_ > node.depend_feat_size(); } + bool operator<(const AllreduceNode &node) const { return depend_feat_size_ < node.depend_feat_size(); } + bool operator>(const AllreduceNode &node) const { return depend_feat_size_ > node.depend_feat_size(); } private: CNodePtr cnode_ptr_; diff --git a/mindspore/ccsrc/parallel/auto_parallel/costmodel.cc b/mindspore/ccsrc/parallel/auto_parallel/costmodel.cc index 190f589bb5..ad3a3a1298 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/costmodel.cc +++ b/mindspore/ccsrc/parallel/auto_parallel/costmodel.cc @@ -22,7 +22,7 @@ namespace mindspore { namespace parallel { -void Simplify(CostPtrList* clist_ptrs) { +void Simplify(CostPtrList *clist_ptrs) { // Sort the cost_list with the computation_cost_ increasing, and communication_cost decreasing order. This method // excludes the cost with greater computation_cost_ and greater communication_cost. // E.g. clist_ptrs = {<100, 20>, <200, 10>, <300, 50>}. After this method, clist_ptrs = {<200, 10>, <100, 20>} @@ -44,7 +44,7 @@ void Simplify(CostPtrList* clist_ptrs) { *clist_ptrs = std::move(ret); } -void SimplifyForDreasingCommunicationWithPartialPara(CostPtrList* clist_ptrs) { +void SimplifyForDreasingCommunicationWithPartialPara(CostPtrList *clist_ptrs) { // Sort the cost_list with the computation_cost_ increasing, and communication_with_partial_para_cost decreasing // order. This method excludes the cost with greater computation_cost_ and greater communication_without_para_cost. if (!COST_MODEL_SIMPLIFY_CALCULATION) { @@ -66,7 +66,7 @@ void SimplifyForDreasingCommunicationWithPartialPara(CostPtrList* clist_ptrs) { *clist_ptrs = std::move(ret); } -void RefineForPracticalCost(const CostPtr& origin_cost, bool is_redistribution) { +void RefineForPracticalCost(const CostPtr &origin_cost, bool is_redistribution) { MS_EXCEPTION_IF_NULL(origin_cost); if (is_redistribution) { // Redistribution cost diff --git a/mindspore/ccsrc/parallel/auto_parallel/costmodel.h b/mindspore/ccsrc/parallel/auto_parallel/costmodel.h index 9e9003848b..2cb24dd7f3 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/costmodel.h +++ b/mindspore/ccsrc/parallel/auto_parallel/costmodel.h @@ -44,7 +44,7 @@ using RedistributionOpListPtr = std::shared_ptr& decision_ = nullptr) + Cost(double computation, double commuication, const std::shared_ptr &decision_ = nullptr) : computation_cost_(computation), communication_cost_(commuication), decision_ptr_(std::move(decision_)) { memory_with_reuse_ = 0.0; communication_without_parameter_ = 0.0; @@ -76,8 +76,8 @@ class StrategyWithCost { StrategyWithCost(StrategyPtr strategy, std::vector inputs_, std::vector outputs_) : strategy_ptr(std::move(strategy)), inputs_ptr(std::move(inputs_)), outputs_ptr(std::move(outputs_)) {} - StrategyWithCost(const StrategyWithCost& swc) = delete; - StrategyWithCost(StrategyWithCost&& swc) + StrategyWithCost(const StrategyWithCost &swc) = delete; + StrategyWithCost(StrategyWithCost &&swc) : strategy_ptr(swc.strategy_ptr), inputs_ptr(swc.inputs_ptr), outputs_ptr(swc.outputs_ptr), @@ -295,9 +295,9 @@ using StarEliminationDecisionPtr = std::shared_ptr; using FinalDecisionPtr = std::shared_ptr; using FinalSingleDecisionPtr = std::shared_ptr; -void Simplify(CostPtrList* clist); -void SimplifyForDreasingCommunicationWithPartialPara(CostPtrList* clist); -void RefineForPracticalCost(const CostPtr&, bool is_redistribution); +void Simplify(CostPtrList *clist); +void SimplifyForDreasingCommunicationWithPartialPara(CostPtrList *clist); +void RefineForPracticalCost(const CostPtr &, bool is_redistribution); } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/parallel/auto_parallel/dp_algo_costmodel.cc b/mindspore/ccsrc/parallel/auto_parallel/dp_algo_costmodel.cc index dd21096fcc..8d439f1522 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/dp_algo_costmodel.cc +++ b/mindspore/ccsrc/parallel/auto_parallel/dp_algo_costmodel.cc @@ -22,7 +22,7 @@ namespace mindspore { namespace parallel { -Status GetStrategy(const CostGraphPtr& graph) { +Status GetStrategy(const CostGraphPtr &graph) { MS_LOG(INFO) << "Searching strategies begins."; MS_EXCEPTION_IF_NULL(graph); std::vector eliminations; @@ -141,7 +141,7 @@ Status RecoverStrategy(std::vector eliminations) { auto elimination = (*rit)->cast(); auto new_edge = elimination->new_edge_; MS_EXCEPTION_IF_NULL(new_edge); - auto& edges = elimination->edges_; + auto &edges = elimination->edges_; auto decision = new_edge->selected_cost()->decision_ptr_->cast(); for (size_t j = 0; j < edges.size(); ++j) { MS_EXCEPTION_IF_NULL(edges[j]); diff --git a/mindspore/ccsrc/parallel/auto_parallel/dp_algo_costmodel.h b/mindspore/ccsrc/parallel/auto_parallel/dp_algo_costmodel.h index 6d43218e19..efedba7d10 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/dp_algo_costmodel.h +++ b/mindspore/ccsrc/parallel/auto_parallel/dp_algo_costmodel.h @@ -65,7 +65,7 @@ struct OpElimination : public Elimination { // Edge Elimination struct EdgeElimination : public Elimination { - EdgeElimination(const EdgePtr& n_edge, std::vector eds) + EdgeElimination(const EdgePtr &n_edge, std::vector eds) : Elimination(n_edge, Elimination::EliminationType::EDGE), edges_(std::move(eds)) {} std::vector edges_; @@ -139,7 +139,7 @@ using TriangleEliminationPtr = std::shared_ptr; using StarEliminationPtr = std::shared_ptr; // Phase 1 and Phase 2 -Status GetStrategy(const CostGraphPtr& graph); +Status GetStrategy(const CostGraphPtr &graph); // Phase 3 Status RecoverStrategy(std::vector eliminations); diff --git a/mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.cc b/mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.cc index 21e67f9f7b..6973830779 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.cc +++ b/mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.cc @@ -28,19 +28,19 @@ namespace mindspore { namespace parallel { Status Edge::InitEdgeCost() { bool has_available_cost = false; - for (auto& swc : prev_op_->GetStrategyCost()) { + for (auto &swc : prev_op_->GetStrategyCost()) { MS_EXCEPTION_IF_NULL(swc); pre_op_output_.emplace_back(std::make_pair(swc->strategy_ptr, swc->outputs_ptr)); } - for (auto& swc : next_op_->GetStrategyCost()) { + for (auto &swc : next_op_->GetStrategyCost()) { MS_EXCEPTION_IF_NULL(swc); next_op_input_.emplace_back(std::make_pair(swc->strategy_ptr, swc->inputs_ptr)); } if (is_identity_edge) { - for (auto& target_output : pre_op_output_) { + for (auto &target_output : pre_op_output_) { auto target_output_lyt = target_output.second[prev_op_output_index_].tensor_layout(); auto target_output_str = target_output.first; - for (auto& target_input : next_op_input_) { + for (auto &target_input : next_op_input_) { auto target_input_lyt = target_input.second[next_op_input_index_].tensor_layout(); auto target_input_str = target_input.first; if (target_output_lyt == target_input_lyt) { @@ -57,12 +57,12 @@ Status Edge::InitEdgeCost() { } } } else { - for (auto& target_output : pre_op_output_) { + for (auto &target_output : pre_op_output_) { auto target_output_lyt = target_output.second[prev_op_output_index_].tensor_layout(); auto target_output_str = target_output.first; auto type_length = prev_op_->GetOutputTypeLengths()[prev_op_output_index_]; auto type = prev_op_->outputs_type()[prev_op_output_index_]; - for (auto& target_input : next_op_input_) { + for (auto &target_input : next_op_input_) { auto target_input_lyt = target_input.second[next_op_input_index_].tensor_layout(); auto target_input_str = target_input.first; CostPtr cost; @@ -99,8 +99,8 @@ Status Edge::InitEdgeCost() { return Status::SUCCESS; } -Status Edge::GetRedistributionCost(const TensorLayout& prev_op_output_layout, const TensorLayout& next_op_input_layout, - size_t type_length, TypePtr type, CostPtr* cost) { +Status Edge::GetRedistributionCost(const TensorLayout &prev_op_output_layout, const TensorLayout &next_op_input_layout, + size_t type_length, TypePtr type, CostPtr *cost) { MS_EXCEPTION_IF_NULL(prev_op_); MS_EXCEPTION_IF_NULL(cost); RankList dev_list = prev_op_->global_device_list(); @@ -148,9 +148,9 @@ CostPtrList Edge::GetCostList(StrategyPtr output_str, StrategyPtr input_str) { return result; } -CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr& output_st_ptr, const std::vector& edges, - const StrategyPtr& input_st_ptr) { - std::function LocalGetCostList = [&](const EdgePtr& edge) { +CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr &output_st_ptr, const std::vector &edges, + const StrategyPtr &input_st_ptr) { + std::function LocalGetCostList = [&](const EdgePtr &edge) { MS_EXCEPTION_IF_NULL(edge); return edge->GetCostList(output_st_ptr, input_st_ptr); }; @@ -174,7 +174,7 @@ CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr& output_st_ptr result.push_back(new_cost); return; } - for (auto& c : all_cost_list[k]) { + for (auto &c : all_cost_list[k]) { MS_EXCEPTION_IF_NULL(c); selected_cost_list[k] = c; recursive(k + 1, computation + c->computation_cost_, memory + c->memory_with_reuse_, @@ -187,11 +187,11 @@ CostPtrList Edge::CreateEdgeEliminationCostList(const StrategyPtr& output_st_ptr return result; } -void Edge::EdgeEliminationSetNewCost(OperatorInfoPtr, const std::vector& edges, OperatorInfoPtr) { +void Edge::EdgeEliminationSetNewCost(OperatorInfoPtr, const std::vector &edges, OperatorInfoPtr) { bool valid = false; - for (const auto& output_pair : pre_op_output_) { + for (const auto &output_pair : pre_op_output_) { StrategyPtr output_st_ptr = output_pair.first; - for (const auto& input_pair : next_op_input_) { + for (const auto &input_pair : next_op_input_) { StrategyPtr input_st_ptr = input_pair.first; CostPtrList clist = CreateEdgeEliminationCostList(output_st_ptr, edges, input_st_ptr); CostPtrKey key = {output_st_ptr, input_st_ptr}; @@ -206,14 +206,14 @@ void Edge::EdgeEliminationSetNewCost(OperatorInfoPtr, const std::vector } } -void Edge::CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtrList& left_cost_list, - const CostPtrList& middle_cost_list, const CostPtrList& right_cost_list, - CostPtrList* ret_cost_list) { - for (auto& left_cost : left_cost_list) { +void Edge::CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtrList &left_cost_list, + const CostPtrList &middle_cost_list, const CostPtrList &right_cost_list, + CostPtrList *ret_cost_list) { + for (auto &left_cost : left_cost_list) { MS_EXCEPTION_IF_NULL(left_cost); - for (auto& middle_cost : middle_cost_list) { + for (auto &middle_cost : middle_cost_list) { MS_EXCEPTION_IF_NULL(middle_cost); - for (auto& right_cost : right_cost_list) { + for (auto &right_cost : right_cost_list) { MS_EXCEPTION_IF_NULL(right_cost); double computation = left_cost->computation_cost_ + middle_cost->computation_cost_ + right_cost->computation_cost_; @@ -238,14 +238,14 @@ void Edge::CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtr } } -CostPtrList Edge::CreateOpEliminationCostList(const EdgePtr& e1, const StrategyPtr& output_st_ptr, - const OperatorInfoPtr& op, const EdgePtr& e2, - const StrategyPtr& input_st_ptr) { +CostPtrList Edge::CreateOpEliminationCostList(const EdgePtr &e1, const StrategyPtr &output_st_ptr, + const OperatorInfoPtr &op, const EdgePtr &e2, + const StrategyPtr &input_st_ptr) { MS_EXCEPTION_IF_NULL(op); MS_EXCEPTION_IF_NULL(e1); MS_EXCEPTION_IF_NULL(e2); CostPtrList result; - for (const auto& op_strategy : op->GetStrategyCost()) { + for (const auto &op_strategy : op->GetStrategyCost()) { MS_EXCEPTION_IF_NULL(op_strategy); auto middle_strategy = op_strategy->strategy_ptr; CreateOpEliminationSubCostList(middle_strategy, e1->GetCostList(output_st_ptr, middle_strategy), @@ -255,11 +255,11 @@ CostPtrList Edge::CreateOpEliminationCostList(const EdgePtr& e1, const StrategyP return result; } -void Edge::OpEliminationSetNewCost(const EdgePtr& e1, const OperatorInfoPtr& op, const EdgePtr& e2) { +void Edge::OpEliminationSetNewCost(const EdgePtr &e1, const OperatorInfoPtr &op, const EdgePtr &e2) { bool valid = false; - for (const auto& output_pair : pre_op_output_) { + for (const auto &output_pair : pre_op_output_) { StrategyPtr output_st_ptr = output_pair.first; - for (const auto& input_pair : next_op_input_) { + for (const auto &input_pair : next_op_input_) { StrategyPtr input_st_ptr = input_pair.first; CostPtrList clist = CreateOpEliminationCostList(e1, output_st_ptr, op, e2, input_st_ptr); @@ -283,8 +283,8 @@ Status Edge::CalculateMemoryCost() { if (is_output_parameter_involve_ == 0) { // In this case, it is sure that the tensor redistribution along this edge is NOT parameter-involved, thus it is // unnecessary to keep them in memory. - for (auto& cost_kv : cost_map_) { - auto& cost_v = cost_kv.second; + for (auto &cost_kv : cost_map_) { + auto &cost_v = cost_kv.second; if (!cost_v.empty()) { cost_v[0]->memory_with_reuse_ = 0; } diff --git a/mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.h b/mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.h index f974125749..e760c24c34 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.h +++ b/mindspore/ccsrc/parallel/auto_parallel/edge_costmodel.h @@ -37,9 +37,9 @@ using EdgePtr = std::shared_ptr; class Edge { // An 'Edge' connects two Operators in the CostGraph. public: - Edge(const std::string& edge_name, const std::shared_ptr& prev_op, - const std::shared_ptr& next_op, const size_t& output_index_, const size_t& input_index_, - const bool& is_com) + Edge(const std::string &edge_name, const std::shared_ptr &prev_op, + const std::shared_ptr &next_op, const size_t &output_index_, const size_t &input_index_, + const bool &is_com) : edge_name_(edge_name), prev_op_(prev_op), next_op_(next_op), @@ -49,9 +49,9 @@ class Edge { is_identity_edge = false; } - Edge(const std::string& edge_name, const std::shared_ptr& prev_op, - const std::shared_ptr& next_op, const size_t& output_index_, const size_t& input_index_, - const bool& is_com, const bool& is_iden) + Edge(const std::string &edge_name, const std::shared_ptr &prev_op, + const std::shared_ptr &next_op, const size_t &output_index_, const size_t &input_index_, + const bool &is_com, const bool &is_iden) : edge_name_(edge_name), prev_op_(prev_op), next_op_(next_op), @@ -60,9 +60,9 @@ class Edge { is_combined_(is_com), is_identity_edge(is_iden) {} - Edge(const std::string& edge_name, const std::shared_ptr& prev_op, - const std::shared_ptr& next_op, const std::vector& output_indexs_, - const std::vector& input_indexs_, const bool& is_com) + Edge(const std::string &edge_name, const std::shared_ptr &prev_op, + const std::shared_ptr &next_op, const std::vector &output_indexs_, + const std::vector &input_indexs_, const bool &is_com) : edge_name_(edge_name), prev_op_(prev_op), next_op_(next_op), @@ -83,13 +83,13 @@ class Edge { // For two operators u--->v, given the output tensor layout of u, // and the input tensor layout of v, return the redistribution cost, // and the op_list to carry out the redistribution. - Status GetRedistributionCost(const TensorLayout& prev_op_output_layout, const TensorLayout& next_op_input_layout, - size_t, TypePtr type, CostPtr* cost); + Status GetRedistributionCost(const TensorLayout &prev_op_output_layout, const TensorLayout &next_op_input_layout, + size_t, TypePtr type, CostPtr *cost); - void set_pre_op_output(const std::vector, std::vector>>& output_set) { + void set_pre_op_output(const std::vector, std::vector>> &output_set) { pre_op_output_ = output_set; } - void set_next_op_input(const std::vector, std::vector>>& input_set) { + void set_next_op_input(const std::vector, std::vector>> &input_set) { next_op_input_ = input_set; } @@ -109,27 +109,27 @@ class Edge { std::vector prev_op_output_indexs() const { return pre_op_output_indexs_; } std::vector next_op_input_indexs() const { return next_op_input_indexs_; } - CostPtrList CreateEdgeEliminationCostList(const StrategyPtr& output_st_ptr, - const std::vector>& edges, - const StrategyPtr& input_st_ptr); + CostPtrList CreateEdgeEliminationCostList(const StrategyPtr &output_st_ptr, + const std::vector> &edges, + const StrategyPtr &input_st_ptr); // In the Edge Elimination operation in DP algorithm, 'edges' is replaced by a new edge. This method is used to // set cost for this new edge - void EdgeEliminationSetNewCost(std::shared_ptr u, const std::vector>& edges, + void EdgeEliminationSetNewCost(std::shared_ptr u, const std::vector> &edges, std::shared_ptr v); - void CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtrList& left_cost_list, - const CostPtrList& middle_cost_list, const CostPtrList& right_cost_list, - CostPtrList* ret_cost_list); + void CreateOpEliminationSubCostList(StrategyPtr op_strategy, const CostPtrList &left_cost_list, + const CostPtrList &middle_cost_list, const CostPtrList &right_cost_list, + CostPtrList *ret_cost_list); - CostPtrList CreateOpEliminationCostList(const std::shared_ptr& e1, const StrategyPtr& output_st_ptr, - const std::shared_ptr& op, const std::shared_ptr& e2, - const StrategyPtr& input_st_ptr); + CostPtrList CreateOpEliminationCostList(const std::shared_ptr &e1, const StrategyPtr &output_st_ptr, + const std::shared_ptr &op, const std::shared_ptr &e2, + const StrategyPtr &input_st_ptr); // In the Operation Elimination operation in DP algorithm, 'op', 'e1' and 'e2' are replaced by a new edge. // This method is used to set cost for this new edge - void OpEliminationSetNewCost(const std::shared_ptr& e1, const std::shared_ptr& op, - const std::shared_ptr& e2); + void OpEliminationSetNewCost(const std::shared_ptr &e1, const std::shared_ptr &op, + const std::shared_ptr &e2); - void set_selected_cost(const CostPtr& cost) { selected_cost_ = cost; } - const CostPtr& selected_cost() const { return selected_cost_; } + void set_selected_cost(const CostPtr &cost) { selected_cost_ = cost; } + const CostPtr &selected_cost() const { return selected_cost_; } void set_parameter_involve(int para_invol) { is_output_parameter_involve_ = para_invol; } // When the input of a operator contains WEIGHT or a output from other operators involving WEIGHT, then these input // should stay in memory until it is used in the backward phase, which is kept in memory at the end of forward phase. diff --git a/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc b/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc index c56d3a6fbd..501a983a95 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc +++ b/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.cc @@ -144,7 +144,7 @@ void CostGraph::SetDeviceMemoryAndCostParameter() { } } -void CostGraph::RemoveOperator(const OperatorInfoPtr& op) { +void CostGraph::RemoveOperator(const OperatorInfoPtr &op) { for (auto it = ops_.begin(); it != ops_.end();) { if ((*it) == op) { it = ops_.erase(it); @@ -154,19 +154,19 @@ void CostGraph::RemoveOperator(const OperatorInfoPtr& op) { } } -bool CostGraph::IsOperatorInCostGraph(const OperatorInfoPtr& op_test) { +bool CostGraph::IsOperatorInCostGraph(const OperatorInfoPtr &op_test) { struct IsInGraph { const OperatorInfoPtr test_; - explicit IsInGraph(const OperatorInfoPtr& n) : test_(n) {} - bool operator()(const OperatorInfoPtr& in) const { return (test_ == in); } + explicit IsInGraph(const OperatorInfoPtr &n) : test_(n) {} + bool operator()(const OperatorInfoPtr &in) const { return (test_ == in); } }; return std::any_of(ops_.begin(), ops_.end(), IsInGraph(op_test)); } -bool CostGraph::IsEdgeInCostGraph(const std::string& test_edge_name, size_t output_index, size_t input_index) { - for (auto& edge_pair : edges_) { +bool CostGraph::IsEdgeInCostGraph(const std::string &test_edge_name, size_t output_index, size_t input_index) { + for (auto &edge_pair : edges_) { auto edges = edge_pair.second; - for (auto& edge : edges) { + for (auto &edge : edges) { MS_EXCEPTION_IF_NULL(edge); bool bool_result = (edge->edge_name() == test_edge_name) && (edge->prev_op_output_index() == output_index) && (edge->next_op_input_index() == input_index); @@ -182,12 +182,12 @@ std::vector> CostGraph::ConstructConnectedComponents( std::vector alive_ops) { std::map visited; - for (auto& op : alive_ops) { + for (auto &op : alive_ops) { visited[op] = false; } MS_LOG(INFO) << "visited: " << visited.size() << "."; - for (auto& op : alive_ops) { + for (auto &op : alive_ops) { if ((!visited[op]) && op->is_alive()) { std::shared_ptr new_component = std::make_shared(); MS_EXCEPTION_IF_NULL(new_component); @@ -199,14 +199,14 @@ std::vector> CostGraph::ConstructConnectedComponents( return connected_compoents_; } -void CostGraph::DFS(const OperatorInfoPtr& current_op, std::map* visited, - const std::shared_ptr& component) { +void CostGraph::DFS(const OperatorInfoPtr ¤t_op, std::map *visited, + const std::shared_ptr &component) { MS_EXCEPTION_IF_NULL(visited); MS_EXCEPTION_IF_NULL(component); visited->at(current_op) = true; component->AddOperator(current_op); - for (auto& edge : current_op->succ_edges()) { + for (auto &edge : current_op->succ_edges()) { bool bool_test = (visited->find(edge->next_operator()) != visited->end()) && (!visited->at(edge->next_operator())) && edge->next_operator()->is_alive(); if (bool_test) { @@ -215,7 +215,7 @@ void CostGraph::DFS(const OperatorInfoPtr& current_op, std::mapprev_edges()) { + for (auto &edge : current_op->prev_edges()) { bool bool_test = (visited->find(edge->prev_operator()) != visited->end()) && (!visited->at(edge->prev_operator())) && edge->prev_operator()->is_alive(); if (bool_test) { @@ -226,14 +226,14 @@ void CostGraph::DFS(const OperatorInfoPtr& current_op, std::map v -CostPtrList CostGraph::CreateFinalCostList(const OperatorInfoPtr& u, const std::shared_ptr& e, - const OperatorInfoPtr& v) { +CostPtrList CostGraph::CreateFinalCostList(const OperatorInfoPtr &u, const std::shared_ptr &e, + const OperatorInfoPtr &v) { MS_EXCEPTION_IF_NULL(u); MS_EXCEPTION_IF_NULL(v); MS_EXCEPTION_IF_NULL(e); CostPtrList ret; - for (const auto& u_strategy : u->GetStrategyCost()) { - for (const auto& v_strategy : v->GetStrategyCost()) { + for (const auto &u_strategy : u->GetStrategyCost()) { + for (const auto &v_strategy : v->GetStrategyCost()) { MS_EXCEPTION_IF_NULL(u_strategy); MS_EXCEPTION_IF_NULL(v_strategy); auto u_strategy_ptr = u_strategy->strategy_ptr; @@ -241,9 +241,9 @@ CostPtrList CostGraph::CreateFinalCostList(const OperatorInfoPtr& u, const std:: CostPtrList clist1 = u_strategy->cost_list; CostPtrList clist2 = e->GetCostList(u_strategy_ptr, v_strategy_ptr); CostPtrList clist3 = v_strategy->cost_list; - for (const auto& cost1 : clist1) { - for (const auto& cost2 : clist2) { - for (const auto& cost3 : clist3) { + for (const auto &cost1 : clist1) { + for (const auto &cost2 : clist2) { + for (const auto &cost3 : clist3) { MS_EXCEPTION_IF_NULL(cost1); MS_EXCEPTION_IF_NULL(cost2); MS_EXCEPTION_IF_NULL(cost3); @@ -274,14 +274,14 @@ CostPtrList CostGraph::CreateFinalCostList(const OperatorInfoPtr& u, const std:: } // Create final cost list for the graph containing a signle node: u -CostPtrList CostGraph::CreateFinalSingleCostList(const OperatorInfoPtr& u) { +CostPtrList CostGraph::CreateFinalSingleCostList(const OperatorInfoPtr &u) { MS_EXCEPTION_IF_NULL(u); CostPtrList ret; - for (const auto& u_strategy : u->GetStrategyCost()) { + for (const auto &u_strategy : u->GetStrategyCost()) { MS_EXCEPTION_IF_NULL(u_strategy); auto u_strategy_ptr = u_strategy->strategy_ptr; CostPtrList clist1 = u_strategy->cost_list; - for (const auto& cost1 : clist1) { + for (const auto &cost1 : clist1) { MS_EXCEPTION_IF_NULL(cost1); auto decision = std::make_shared(u_strategy_ptr, cost1); auto new_cost = std::make_shared(cost1->computation_cost_, cost1->communication_cost_, decision); @@ -299,16 +299,16 @@ CostPtrList CostGraph::CreateFinalSingleCostList(const OperatorInfoPtr& u) { return ret; } -CostPtr CostGraph::SelectCostWithMemoryConstraint(const CostPtrList& cost_list, double memory) { +CostPtr CostGraph::SelectCostWithMemoryConstraint(const CostPtrList &cost_list, double memory) { CostPtrList after_mem_filter; // Filter out the valid costs - for (auto& a_cost : cost_list) { + for (auto &a_cost : cost_list) { if (a_cost->memory_with_reuse_ <= memory) { after_mem_filter.emplace_back(std::move(a_cost)); } } - std::function LocalCompare = [&](CostPtr init, const CostPtr& cost_x) { + std::function LocalCompare = [&](CostPtr init, const CostPtr &cost_x) { MS_EXCEPTION_IF_NULL(cost_x); if (init == nullptr || cost_x->computation_cost_ < memory) { init = cost_x; @@ -319,7 +319,7 @@ CostPtr CostGraph::SelectCostWithMemoryConstraint(const CostPtrList& cost_list, return std::accumulate(after_mem_filter.begin(), after_mem_filter.end(), ret, LocalCompare); } -CostPtr CostGraph::SelectCostWithMinTrainingTime(const CostPtrList& cost_list, double memory) { +CostPtr CostGraph::SelectCostWithMinTrainingTime(const CostPtrList &cost_list, double memory) { // Select the cost with minimum training time. Currently, the training time is modeled as = // costmodel_alpha_ * computation_cost + costmodel_beta_ * communication_with_partial_para_ if (cost_list.empty()) { @@ -329,7 +329,7 @@ CostPtr CostGraph::SelectCostWithMinTrainingTime(const CostPtrList& cost_list, d CostPtrList after_mem_filter; double minimum_memory = DBL_MAX; // Filter out the valid costs. - for (auto& a_cost : cost_list) { + for (auto &a_cost : cost_list) { if (a_cost->memory_with_reuse_ <= memory) { after_mem_filter.emplace_back(std::move(a_cost)); } else if (a_cost->memory_with_reuse_ < minimum_memory) { @@ -371,7 +371,7 @@ CostPtr CostGraph::SelectCostWithMinTrainingTime(const CostPtrList& cost_list, d return ret; } -CostPtrList CostGraph::SelectCostListWithMinTrainingTimeMultiple(const std::vector& all_cost_list, +CostPtrList CostGraph::SelectCostListWithMinTrainingTimeMultiple(const std::vector &all_cost_list, double available_memory) { CostPtrList selected_cost_list(all_cost_list.size(), nullptr); double minimum = DBL_MAX, total_memory = 0.0; @@ -418,7 +418,7 @@ CostPtrList CostGraph::SelectCostListWithMinTrainingTimeMultiple(const std::vect } MS_LOG(DEBUG) << "The value minimum: " << minimum << ", available_memory: " << available_memory << "."; - for (auto& c : all_cost_list[k]) { + for (auto &c : all_cost_list[k]) { selected_cost_list[k] = c; recursive(k + 1); } @@ -427,7 +427,7 @@ CostPtrList CostGraph::SelectCostListWithMinTrainingTimeMultiple(const std::vect return ret; } -Status CostGraph::SearchStrategyForMultiNodeFinalGraph(const std::vector& alive_ops) { +Status CostGraph::SearchStrategyForMultiNodeFinalGraph(const std::vector &alive_ops) { MS_LOG(INFO) << "There are " << alive_ops.size() << " nodes in the final graph."; auto connected_components = ConstructConnectedComponents(alive_ops); MS_LOG(INFO) << "There are " << connected_components.size() << " components in the final graph."; @@ -516,7 +516,7 @@ Status CostGraph::SearchStrategyForMultiNodeFinalGraph(const std::vector alive_ops; - (void)std::for_each(ops_.begin(), ops_.end(), [&alive_ops](const OperatorInfoPtr& op) { + (void)std::for_each(ops_.begin(), ops_.end(), [&alive_ops](const OperatorInfoPtr &op) { MS_EXCEPTION_IF_NULL(op); if (op->is_alive()) { alive_ops.push_back(op); @@ -620,7 +620,7 @@ Status CostGraph::SearchStrategy() { // Given a graph which contains the following subgraph: u --> v --> w, the node v can be eliminated // return the v and the edge u --> v OperatorInfoPtr CostGraph::CheckOpElimination() const { - for (auto& op : ops_) { + for (auto &op : ops_) { bool bool_test = op->is_alive() && op->GetAliveSuccEdges().size() == 1 && op->GetAlivePrevEdges().size() == 1; if (bool_test) { if ((op->GetAliveSuccEdges()[0]->next_operator() != op) && (op->GetAlivePrevEdges()[0]->prev_operator() != op)) { @@ -633,21 +633,21 @@ OperatorInfoPtr CostGraph::CheckOpElimination() const { // Check the graph whether an EdgeElimination can be performed std::vector> CostGraph::CheckEdgeElimination() const { - for (auto& op : ops_) { + for (auto &op : ops_) { MS_EXCEPTION_IF_NULL(op); if (!op->is_alive()) continue; - std::map count; - for (auto& edge : op->GetAliveSuccEdges()) { + std::map count; + for (auto &edge : op->GetAliveSuccEdges()) { MS_EXCEPTION_IF_NULL(edge); auto v = edge->next_operator(); count[v.get()]++; } - for (auto& pair : count) { - auto* op_ptr = pair.first; + for (auto &pair : count) { + auto *op_ptr = pair.first; int op_count = pair.second; if (op_count > 1) { std::vector> ret; - for (auto& edge : op->GetAliveSuccEdges()) { + for (auto &edge : op->GetAliveSuccEdges()) { MS_EXCEPTION_IF_NULL(edge); if (edge->next_operator().get() == op_ptr) { ret.push_back(edge); @@ -662,7 +662,7 @@ std::vector> CostGraph::CheckEdgeElimination() const { // Check the graph whether a MergeElimination can be performed OperatorInfoPtr CostGraph::CheckMergeElimination() const { - for (auto& op : ops_) { + for (auto &op : ops_) { MS_EXCEPTION_IF_NULL(op); bool bool_test = op->is_alive() && op->GetAlivePrevEdges().empty() && op->GetAliveSuccEdges().size() == 1; if (bool_test) { @@ -678,7 +678,7 @@ OperatorInfoPtr CostGraph::CheckMergeElimination() const { // Check the graph whether a ContractElimination can be performed OperatorInfoPtr CostGraph::CheckContractElimination() const { - for (auto& op : ops_) { + for (auto &op : ops_) { MS_EXCEPTION_IF_NULL(op); bool bool_test = op->is_alive() && op->GetAlivePrevEdges().size() == 1 && op->GetAliveSuccEdges().empty(); if (bool_test) { @@ -696,7 +696,7 @@ OperatorInfoPtr CostGraph::CheckContractElimination() const { // Check the graph whether a TriangleElimination can be performed std::pair> CostGraph::CheckTriangleElimination() const { - for (auto& op : ops_) { + for (auto &op : ops_) { MS_EXCEPTION_IF_NULL(op); bool bool_test = (op->is_alive()) && (op->GetAlivePrevEdges().empty()) && (op->GetAliveSuccEdges().size() == 2); if (bool_test) { @@ -707,13 +707,13 @@ std::pair> CostGraph::CheckTriangleElimin auto first_op = edge1->next_operator(); auto second_op = edge2->next_operator(); MS_EXCEPTION_IF_NULL(first_op); - for (auto& first_op_succ_edge : first_op->GetAliveSuccEdges()) { + for (auto &first_op_succ_edge : first_op->GetAliveSuccEdges()) { if (first_op_succ_edge->next_operator() == second_op) { return {op, first_op_succ_edge}; } } MS_EXCEPTION_IF_NULL(second_op); - for (auto& second_op_succ_edge : second_op->GetAliveSuccEdges()) { + for (auto &second_op_succ_edge : second_op->GetAliveSuccEdges()) { if (second_op_succ_edge->next_operator() == first_op) { return {op, second_op_succ_edge}; } @@ -726,7 +726,7 @@ std::pair> CostGraph::CheckTriangleElimin // Check the graph whether a StarElimination can be performed. // NOTE: this elimination MUST be performed only when the above 5 operation cannot be applied. OperatorInfoPtr CostGraph::CheckStarElimination() const { - for (auto& op : ops_) { + for (auto &op : ops_) { MS_EXCEPTION_IF_NULL(op); bool bool_test = (op->is_alive()) && (op->GetAlivePrevEdges().empty()) && (op->GetAliveSuccEdges().size() > 1); if (bool_test) { @@ -738,7 +738,7 @@ OperatorInfoPtr CostGraph::CheckStarElimination() const { // This method is for 'eliminating operator' operation in the DP algorithm. It creates a new edge to replace // 'lefe_edge', 'op' and 'right_edge'. As a consequence, it creates new costlist for the new edge. -std::shared_ptr CostGraph::EliminationOp(const OperatorInfoPtr& op) { +std::shared_ptr CostGraph::EliminationOp(const OperatorInfoPtr &op) { // in this case, the operators are organised in the form of u-->op-->v, and the goal // is to eliminate 'op'. MS_EXCEPTION_IF_NULL(op); @@ -786,7 +786,7 @@ std::shared_ptr CostGraph::EliminationOp(const OperatorInfoPtr& op) { // This method is for 'eliminating edges' operation in the DP algorithm. It creates a new edge to replace the 'edges', // and sets new costlist for the new edge. -std::shared_ptr CostGraph::EliminationEdges(const std::vector>& edges) { +std::shared_ptr CostGraph::EliminationEdges(const std::vector> &edges) { MS_LOG(INFO) << "Now eliminating " << edges.size() << " edges."; MS_EXCEPTION_IF_NULL(edges[0]); auto u = edges[0]->prev_operator(); @@ -796,7 +796,7 @@ std::shared_ptr CostGraph::EliminationEdges(const std::vectorname() + OPERATOR_TO_OPERATOR_CONNECTOR + v->name(); std::vector output_indexs, input_indexs; - for (auto& edge : edges) { + for (auto &edge : edges) { MS_EXCEPTION_IF_NULL(edge); if (edge->is_combined()) { auto from_output_indexs = edge->prev_op_output_indexs(); @@ -824,18 +824,18 @@ std::shared_ptr CostGraph::EliminationEdges(const std::vectorcomputation_cost_ + edge_cost->computation_cost_ + tar_cost->computation_cost_; double memory = op_cost->memory_with_reuse_ + edge_cost->memory_with_reuse_ + tar_cost->memory_with_reuse_; @@ -862,7 +862,7 @@ void CostGraph::CreateMergeEliminationSubCostList(StrategyPtr op_strategy, const // This method is for the 'Merge' operation in DP algorithm. It creates new costlist for each strategy in the // target_op -OperatorInfoPtr CostGraph::EliminationMerge(const OperatorInfoPtr& op) { +OperatorInfoPtr CostGraph::EliminationMerge(const OperatorInfoPtr &op) { MS_EXCEPTION_IF_NULL(op); auto target_op = op->GetAliveSuccEdges()[0]->next_operator(); auto edge_ptr = op->GetAliveSuccEdges()[0]; @@ -871,13 +871,13 @@ OperatorInfoPtr CostGraph::EliminationMerge(const OperatorInfoPtr& op) { MS_LOG(INFO) << "Now merging " << op->name() << " into " << target_op->name() << "."; bool valid = false; - for (auto& tar_stra_cost : target_op->GetStrategyCost()) { + for (auto &tar_stra_cost : target_op->GetStrategyCost()) { MS_EXCEPTION_IF_NULL(tar_stra_cost); auto tar_stra = tar_stra_cost->strategy_ptr; auto tar_clist_origin = tar_stra_cost->cost_list; CostPtrList tar_clist_new; - for (auto& op_stra_cost : op->GetStrategyCost()) { + for (auto &op_stra_cost : op->GetStrategyCost()) { MS_EXCEPTION_IF_NULL(op_stra_cost); auto op_stra = op_stra_cost->strategy_ptr; auto op_clist = op_stra_cost->cost_list; @@ -904,17 +904,17 @@ OperatorInfoPtr CostGraph::EliminationMerge(const OperatorInfoPtr& op) { // Given 'contract_op_cost_list', 'edge_cost_list', and 'tar_cost_list', this method is to create 'tar_cost_list_new' // for this contract under the strategy 'contract_op_stra' void CostGraph::CreateContractEliminationSubCostList(StrategyPtr contract_op_stra, - const CostPtrList& contract_op_cost_list, - const CostPtrList& edge_cost_list, StrategyPtr target_op_stra, - const CostPtrList& tar_cost_list, CostPtrList* tar_cost_list_new) { + const CostPtrList &contract_op_cost_list, + const CostPtrList &edge_cost_list, StrategyPtr target_op_stra, + const CostPtrList &tar_cost_list, CostPtrList *tar_cost_list_new) { for (size_t i = 0; i < contract_op_cost_list.size(); ++i) { - auto& contract_op_cost = contract_op_cost_list[i]; + auto &contract_op_cost = contract_op_cost_list[i]; MS_EXCEPTION_IF_NULL(contract_op_cost); for (size_t j = 0; j < edge_cost_list.size(); ++j) { - auto& edge_cost = edge_cost_list[j]; + auto &edge_cost = edge_cost_list[j]; MS_EXCEPTION_IF_NULL(edge_cost); for (size_t k = 0; k < tar_cost_list.size(); ++k) { - auto& tar_cost = tar_cost_list[k]; + auto &tar_cost = tar_cost_list[k]; MS_EXCEPTION_IF_NULL(tar_cost); double computation = contract_op_cost->computation_cost_ + edge_cost->computation_cost_ + tar_cost->computation_cost_; @@ -941,20 +941,20 @@ void CostGraph::CreateContractEliminationSubCostList(StrategyPtr contract_op_str // This method is for the 'Contract' operation in DP algorithm. It creates new costlist for each strategy in the // target_op -OperatorInfoPtr CostGraph::EliminationContract(const OperatorInfoPtr& op) { +OperatorInfoPtr CostGraph::EliminationContract(const OperatorInfoPtr &op) { MS_EXCEPTION_IF_NULL(op); auto target_op = op->GetAlivePrevEdges()[0]->prev_operator(); auto edge_ptr = op->GetAlivePrevEdges()[0]; MS_LOG(INFO) << "Now contracting " << op->name() << " into " << target_op->name() << "."; bool valid = false; - for (auto& tar_stra_cost : target_op->GetStrategyCost()) { + for (auto &tar_stra_cost : target_op->GetStrategyCost()) { MS_EXCEPTION_IF_NULL(tar_stra_cost); auto tar_stra = tar_stra_cost->strategy_ptr; auto tar_clist_origin = tar_stra_cost->cost_list; CostPtrList tar_clist_new; - for (auto& op_stra_cost : op->GetStrategyCost()) { + for (auto &op_stra_cost : op->GetStrategyCost()) { MS_EXCEPTION_IF_NULL(op_stra_cost); auto op_stra = op_stra_cost->strategy_ptr; auto op_clist = op_stra_cost->cost_list; @@ -978,19 +978,19 @@ OperatorInfoPtr CostGraph::EliminationContract(const OperatorInfoPtr& op) { } void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra, StrategyPtr left_op_stra, - StrategyPtr right_op_stra, const CostPtr& right_op_cost, - const CostPtrList& elimi_op_clist, - const CostPtrList& left_edge_clist, const CostPtr& right_edge_cost, - const CostPtrList& left_node_clist_origin, - CostPtrList* left_node_clist_new) { + StrategyPtr right_op_stra, const CostPtr &right_op_cost, + const CostPtrList &elimi_op_clist, + const CostPtrList &left_edge_clist, const CostPtr &right_edge_cost, + const CostPtrList &left_node_clist_origin, + CostPtrList *left_node_clist_new) { MS_EXCEPTION_IF_NULL(right_edge_cost); MS_EXCEPTION_IF_NULL(right_op_cost); MS_EXCEPTION_IF_NULL(left_node_clist_new); - for (auto& elimi_op_cost : elimi_op_clist) { + for (auto &elimi_op_cost : elimi_op_clist) { MS_EXCEPTION_IF_NULL(elimi_op_cost); - for (auto& left_edge_cost : left_edge_clist) { + for (auto &left_edge_cost : left_edge_clist) { MS_EXCEPTION_IF_NULL(left_edge_cost); - for (auto& left_node_cost : left_node_clist_origin) { + for (auto &left_node_cost : left_node_clist_origin) { MS_EXCEPTION_IF_NULL(left_node_cost); double new_computation = elimi_op_cost->computation_cost_ + left_edge_cost->computation_cost_ + left_node_cost->computation_cost_ + right_edge_cost->computation_cost_; @@ -1015,16 +1015,16 @@ void CostGraph::CreateTriangleEliminationSubCostList(StrategyPtr elimi_op_stra, } } -void CostGraph::CreateTriangleEliminationCostList(const OperatorInfoPtr& elimi_op, const CostPtrList& right_node_clist, - const CostPtrList& right_edge_clist, const StrategyPtr& elimi_op_stra, - const StrategyPtr& left_node_stra, const StrategyPtr& right_node_stra, - const CostPtrList& elimi_op_clist, const CostPtrList& left_edge_clist, - const CostPtrList& left_node_clist_origin, - CostPtrList* left_node_clist_new) { +void CostGraph::CreateTriangleEliminationCostList(const OperatorInfoPtr &elimi_op, const CostPtrList &right_node_clist, + const CostPtrList &right_edge_clist, const StrategyPtr &elimi_op_stra, + const StrategyPtr &left_node_stra, const StrategyPtr &right_node_stra, + const CostPtrList &elimi_op_clist, const CostPtrList &left_edge_clist, + const CostPtrList &left_node_clist_origin, + CostPtrList *left_node_clist_new) { MS_EXCEPTION_IF_NULL(elimi_op); - for (auto& right_node_cost : right_node_clist) { + for (auto &right_node_cost : right_node_clist) { MS_EXCEPTION_IF_NULL(right_node_cost); - for (auto& right_edge_cost : right_edge_clist) { + for (auto &right_edge_cost : right_edge_clist) { MS_EXCEPTION_IF_NULL(right_edge_cost); CreateTriangleEliminationSubCostList(elimi_op_stra, left_node_stra, right_node_stra, right_node_cost, elimi_op_clist, left_edge_clist, right_edge_cost, left_node_clist_origin, @@ -1033,8 +1033,8 @@ void CostGraph::CreateTriangleEliminationCostList(const OperatorInfoPtr& elimi_o } } -OperatorInfoPtr CostGraph::EliminationTriangle(const OperatorInfoPtr& elimi_op, - const std::shared_ptr& edge_left_right) { +OperatorInfoPtr CostGraph::EliminationTriangle(const OperatorInfoPtr &elimi_op, + const std::shared_ptr &edge_left_right) { MS_EXCEPTION_IF_NULL(edge_left_right); MS_EXCEPTION_IF_NULL(elimi_op); MS_LOG(INFO) << "Now eliminating triangle: " << elimi_op->name() << "."; @@ -1056,19 +1056,19 @@ OperatorInfoPtr CostGraph::EliminationTriangle(const OperatorInfoPtr& elimi_op, } bool valid = false; - for (auto& left_node_stra_cost : left_node->GetStrategyCost()) { + for (auto &left_node_stra_cost : left_node->GetStrategyCost()) { MS_EXCEPTION_IF_NULL(left_node_stra_cost); auto left_node_stra = left_node_stra_cost->strategy_ptr; auto left_node_clist_origin = left_node_stra_cost->cost_list; CostPtrList left_node_clist_new; - for (auto& elimi_op_stra_cost : elimi_op->GetStrategyCost()) { + for (auto &elimi_op_stra_cost : elimi_op->GetStrategyCost()) { MS_EXCEPTION_IF_NULL(elimi_op_stra_cost); auto elimi_op_stra = elimi_op_stra_cost->strategy_ptr; auto elimi_op_clist = elimi_op_stra_cost->cost_list; auto left_edge_clist = left_edge->GetCostList(elimi_op_stra, left_node_stra); - for (auto& right_node_stra_cost : right_node->GetStrategyCost()) { + for (auto &right_node_stra_cost : right_node->GetStrategyCost()) { MS_EXCEPTION_IF_NULL(right_node_stra_cost); auto right_node_stra = right_node_stra_cost->strategy_ptr; auto right_node_clist = right_node_stra_cost->cost_list; @@ -1095,16 +1095,16 @@ OperatorInfoPtr CostGraph::EliminationTriangle(const OperatorInfoPtr& elimi_op, return left_node; } -void CostGraph::CreateStarEliminationSubCostList(const StrategyPtr& first_succ_node_stra, - const CostPtrList& first_succ_node_clist, - const CostPtrList& first_succ_edge_clist, - const StrategyPtr& merged_op_stra, const CostPtrList& merged_op_clist, +void CostGraph::CreateStarEliminationSubCostList(const StrategyPtr &first_succ_node_stra, + const CostPtrList &first_succ_node_clist, + const CostPtrList &first_succ_edge_clist, + const StrategyPtr &merged_op_stra, const CostPtrList &merged_op_clist, std::vector succ_nodes_stras, - CostPtrList& succ_edges_costs, CostPtrList& succ_nodes_costs, - CostPtrList* first_succ_node_clist_new) { - for (auto& first_succ_node_cost : first_succ_node_clist) { - for (auto& first_succ_edge_cost : first_succ_edge_clist) { - for (auto& merged_node_cost : merged_op_clist) { + CostPtrList &succ_edges_costs, CostPtrList &succ_nodes_costs, + CostPtrList *first_succ_node_clist_new) { + for (auto &first_succ_node_cost : first_succ_node_clist) { + for (auto &first_succ_edge_cost : first_succ_edge_clist) { + for (auto &merged_node_cost : merged_op_clist) { MS_EXCEPTION_IF_NULL(merged_node_cost); succ_nodes_stras[0] = first_succ_node_stra; succ_edges_costs[0] = first_succ_edge_cost; @@ -1141,12 +1141,12 @@ void CostGraph::CreateStarEliminationSubCostList(const StrategyPtr& first_succ_n } } -void CostGraph::CreateStarEliminationCostList(std::vector>& succ_edges, - const StrategyPtr& first_succ_node_stra, - const CostPtrList& first_succ_node_clist, - const CostPtrList& first_succ_edge_clist, - const StrategyPtr& merged_op_stra, const CostPtrList& merged_op_clist, - CostPtrList* first_succ_node_clist_new) { +void CostGraph::CreateStarEliminationCostList(std::vector> &succ_edges, + const StrategyPtr &first_succ_node_stra, + const CostPtrList &first_succ_node_clist, + const CostPtrList &first_succ_edge_clist, + const StrategyPtr &merged_op_stra, const CostPtrList &merged_op_clist, + CostPtrList *first_succ_node_clist_new) { std::vector succ_nodes_stras(succ_edges.size(), nullptr); CostPtrList succ_edges_costs(succ_edges.size(), nullptr), succ_nodes_costs(succ_edges.size(), nullptr); std::function recursive = [&first_succ_node_stra, &first_succ_node_clist, &first_succ_edge_clist, @@ -1167,15 +1167,15 @@ void CostGraph::CreateStarEliminationCostList(std::vector> MS_EXCEPTION_IF_NULL(succ_edge); auto succ_node = succ_edge->next_operator(); MS_EXCEPTION_IF_NULL(succ_node); - for (auto& succ_node_stra_cost : succ_node->GetStrategyCost()) { + for (auto &succ_node_stra_cost : succ_node->GetStrategyCost()) { MS_EXCEPTION_IF_NULL(succ_node_stra_cost); auto succ_node_stra = succ_node_stra_cost->strategy_ptr; auto succ_node_clist = succ_node_stra_cost->cost_list; auto succ_edge_clist = succ_edge->GetCostList(merged_op_stra, succ_node_stra); - for (auto& succ_node_cost : succ_node_clist) { + for (auto &succ_node_cost : succ_node_clist) { MS_EXCEPTION_IF_NULL(succ_node_cost); - for (auto& succ_edge_cost : succ_edge_clist) { + for (auto &succ_edge_cost : succ_edge_clist) { MS_EXCEPTION_IF_NULL(succ_edge_cost); succ_nodes_stras[k] = succ_node_stra; succ_edges_costs[k] = succ_edge_cost; @@ -1189,11 +1189,11 @@ void CostGraph::CreateStarEliminationCostList(std::vector> recursive(1); } -std::vector> CostGraph::EliminationStar(const OperatorInfoPtr& merged_op) { +std::vector> CostGraph::EliminationStar(const OperatorInfoPtr &merged_op) { MS_EXCEPTION_IF_NULL(merged_op); auto succ_edges = merged_op->GetAliveSuccEdges(); MS_LOG(INFO) << "Now eliminating star centered at: " << merged_op->name() << "."; - for (auto& succ_edge : succ_edges) { + for (auto &succ_edge : succ_edges) { MS_EXCEPTION_IF_NULL(succ_edge->next_operator()); MS_LOG(INFO) << "The successive operator is: " << succ_edge->next_operator()->name() << "."; } @@ -1205,13 +1205,13 @@ std::vector> CostGraph::EliminationStar(const OperatorInfo // 'merged_op' is merged into first_node MS_EXCEPTION_IF_NULL(first_succ_node); - for (auto& first_succ_node_stra_cost : first_succ_node->GetStrategyCost()) { + for (auto &first_succ_node_stra_cost : first_succ_node->GetStrategyCost()) { MS_EXCEPTION_IF_NULL(first_succ_node_stra_cost); auto first_succ_node_stra = first_succ_node_stra_cost->strategy_ptr; auto first_succ_node_clist = first_succ_node_stra_cost->cost_list; CostPtrList first_succ_node_clist_new; - for (auto& merged_op_stra_cost : merged_op->GetStrategyCost()) { + for (auto &merged_op_stra_cost : merged_op->GetStrategyCost()) { MS_EXCEPTION_IF_NULL(merged_op_stra_cost); auto merged_op_stra = merged_op_stra_cost->strategy_ptr; auto merged_op_clist = merged_op_stra_cost->cost_list; @@ -1238,7 +1238,7 @@ std::vector> CostGraph::EliminationStar(const OperatorInfo } Status CostGraph::InitSelectedStrategy() { - for (auto& op : ops_) { + for (auto &op : ops_) { MS_EXCEPTION_IF_NULL(op); auto result = op->InitSelectedStrategy(op->selected_strategy()); if (result != SUCCESS) { @@ -1249,9 +1249,9 @@ Status CostGraph::InitSelectedStrategy() { } Status CostGraph::ComputeOpsAndEdgesParameterInvolved() { - for (auto& op : ops_) { + for (auto &op : ops_) { MS_EXCEPTION_IF_NULL(op); - const auto& output_parameter = op->ComputeOpAndPrevEdgeParameterInvolved(); + const auto &output_parameter = op->ComputeOpAndPrevEdgeParameterInvolved(); if ((output_parameter != 0) && (output_parameter != 1)) { MS_LOG(ERROR) << "Computing parameter_involved for " << op->name() << " failed."; return FAILED; @@ -1261,7 +1261,7 @@ Status CostGraph::ComputeOpsAndEdgesParameterInvolved() { } Status CostGraph::CalculateOpsMemoryCost() { - for (auto& op : ops_) { + for (auto &op : ops_) { MS_EXCEPTION_IF_NULL(op); if (op->CalculateMemoryCost() != SUCCESS) { MS_LOG(ERROR) << "Calculate Operator: " << op->name() << " cost for memory usage failed."; @@ -1272,9 +1272,9 @@ Status CostGraph::CalculateOpsMemoryCost() { } Status CostGraph::CalculateEdgesMemoryCost() { - for (auto& edge_pair : edges_) { - const auto& edges = edge_pair.second; - for (auto& one_edge : edges) { + for (auto &edge_pair : edges_) { + const auto &edges = edge_pair.second; + for (auto &one_edge : edges) { if (one_edge->CalculateMemoryCost() != SUCCESS) { MS_LOG(ERROR) << "Calculate Edge: " << one_edge->edge_name() << " cost for memory usage failed."; return FAILED; @@ -1284,7 +1284,7 @@ Status CostGraph::CalculateEdgesMemoryCost() { return SUCCESS; } -OperatorInfoPtr CostGraph::FindTmpIdentityByParameterName(std::string& p_name) const { +OperatorInfoPtr CostGraph::FindTmpIdentityByParameterName(std::string &p_name) const { for (auto one_op : ops_) { if (one_op->name().find(IDENTITY_INFO) != std::string::npos) { if (one_op->refkey_parameter_name() == p_name) { @@ -1295,7 +1295,7 @@ OperatorInfoPtr CostGraph::FindTmpIdentityByParameterName(std::string& p_name) c return nullptr; } Status CostGraph::CorrectOpsMemoryCost() { - for (auto& one_op : ops_) { + for (auto &one_op : ops_) { if ((one_op->name().find(IDENTITY_INFO) != std::string::npos) && (one_op->is_output_parameter_involve() == 1)) { if (one_op->GetAliveSuccEdges().size() > 1) { // Filter out the case when the TmpIdentity being used by multiple operators diff --git a/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h b/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h index e701a377b9..530f67ba45 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h +++ b/mindspore/ccsrc/parallel/auto_parallel/graph_costmodel.h @@ -70,7 +70,7 @@ class CostGraph { costmodel_beta_ = DEFAULT_COST_MODEL_BETA; } ~CostGraph() = default; - void AddOperator(const OperatorInfoPtr& op) { ops_.push_back(op); } + void AddOperator(const OperatorInfoPtr &op) { ops_.push_back(op); } OperatorInfoPtr FindOperatorByIndex(size_t index) { if (index >= ops_.size()) { MS_LOG(ERROR) << "The index: " << index << " is out of the range of ops_: " << ops_.size() << "."; @@ -78,29 +78,29 @@ class CostGraph { } return ops_[index]; } - void RemoveOperator(const OperatorInfoPtr& op); - bool IsOperatorInCostGraph(const OperatorInfoPtr& op); + void RemoveOperator(const OperatorInfoPtr &op); + bool IsOperatorInCostGraph(const OperatorInfoPtr &op); // the edge is in the form: u --> v - void AddEdge(OperatorInfoPtr u_node, OperatorInfoPtr v_node, const EdgePtr& edge) { + void AddEdge(OperatorInfoPtr u_node, OperatorInfoPtr v_node, const EdgePtr &edge) { std::vector curr_edges(edges_[{u_node, v_node}]); curr_edges.push_back(edge); edges_[{u_node, v_node}] = curr_edges; } // An edge is uniquely identified by its name, and its output index and input index. - bool IsEdgeInCostGraph(const std::string&, size_t, size_t); + bool IsEdgeInCostGraph(const std::string &, size_t, size_t); void SetDeviceMemoryAndCostParameter(); std::vector> ConstructConnectedComponents(std::vector); - void DFS(const OperatorInfoPtr& current_op, std::map* visited, - const std::shared_ptr& component); + void DFS(const OperatorInfoPtr ¤t_op, std::map *visited, + const std::shared_ptr &component); - CostPtrList CreateFinalCostList(const OperatorInfoPtr& u, const EdgePtr& e, const OperatorInfoPtr& v); - CostPtrList CreateFinalSingleCostList(const OperatorInfoPtr& u); - CostPtr SelectCostWithMemoryConstraint(const CostPtrList& cost_list, double memory); - CostPtr SelectCostWithMinTrainingTime(const CostPtrList& cost_list, double memory); - CostPtrList SelectCostListWithMinTrainingTimeMultiple(const std::vector& all_costlist, double memory); - Status SearchStrategyForMultiNodeFinalGraph(const std::vector&); + CostPtrList CreateFinalCostList(const OperatorInfoPtr &u, const EdgePtr &e, const OperatorInfoPtr &v); + CostPtrList CreateFinalSingleCostList(const OperatorInfoPtr &u); + CostPtr SelectCostWithMemoryConstraint(const CostPtrList &cost_list, double memory); + CostPtr SelectCostWithMinTrainingTime(const CostPtrList &cost_list, double memory); + CostPtrList SelectCostListWithMinTrainingTimeMultiple(const std::vector &all_costlist, double memory); + Status SearchStrategyForMultiNodeFinalGraph(const std::vector &); std::vector> GetOriginalEdgeBetweenOperators(OperatorInfoPtr u_node, OperatorInfoPtr v_node) { return edges_[{u_node, v_node}]; } @@ -145,36 +145,36 @@ class CostGraph { */ OperatorInfoPtr CheckStarElimination() const; // Applying Operator Elimination in DP algorithm - EdgePtr EliminationOp(const OperatorInfoPtr& op); + EdgePtr EliminationOp(const OperatorInfoPtr &op); // Applying Edge Elimination in DP algorithm - EdgePtr EliminationEdges(const std::vector& edges); + EdgePtr EliminationEdges(const std::vector &edges); // Applying Merge Elimination in DP algorithm - OperatorInfoPtr EliminationMerge(const OperatorInfoPtr& op); - void CreateMergeEliminationSubCostList(StrategyPtr op_strategy, const CostPtrList& op_cost_list, - const CostPtrList& edge_cost_list, StrategyPtr tar_op_strategy, - const CostPtrList& tar_cost_list, CostPtrList* tar_cost_list_new); + OperatorInfoPtr EliminationMerge(const OperatorInfoPtr &op); + void CreateMergeEliminationSubCostList(StrategyPtr op_strategy, const CostPtrList &op_cost_list, + const CostPtrList &edge_cost_list, StrategyPtr tar_op_strategy, + const CostPtrList &tar_cost_list, CostPtrList *tar_cost_list_new); // Applying Contract Elimination in DP algorithm - OperatorInfoPtr EliminationContract(const OperatorInfoPtr& op); - void CreateContractEliminationSubCostList(StrategyPtr, const CostPtrList&, const CostPtrList&, StrategyPtr, - const CostPtrList&, CostPtrList*); + OperatorInfoPtr EliminationContract(const OperatorInfoPtr &op); + void CreateContractEliminationSubCostList(StrategyPtr, const CostPtrList &, const CostPtrList &, StrategyPtr, + const CostPtrList &, CostPtrList *); // Applying Triangle Elimination in DP algorithm. return the left_node - OperatorInfoPtr EliminationTriangle(const OperatorInfoPtr& elimi_op, const EdgePtr& edge_left_right); - void CreateTriangleEliminationCostList(const OperatorInfoPtr&, const CostPtrList&, const CostPtrList&, - const StrategyPtr&, const StrategyPtr&, const StrategyPtr&, const CostPtrList&, - const CostPtrList&, const CostPtrList&, CostPtrList*); + OperatorInfoPtr EliminationTriangle(const OperatorInfoPtr &elimi_op, const EdgePtr &edge_left_right); + void CreateTriangleEliminationCostList(const OperatorInfoPtr &, const CostPtrList &, const CostPtrList &, + const StrategyPtr &, const StrategyPtr &, const StrategyPtr &, + const CostPtrList &, const CostPtrList &, const CostPtrList &, CostPtrList *); // Given the relevant costlist, create the TriangleElimination cost - void CreateTriangleEliminationSubCostList(StrategyPtr, StrategyPtr, StrategyPtr, const CostPtr&, const CostPtrList&, - const CostPtrList&, const CostPtr&, const CostPtrList&, CostPtrList*); + void CreateTriangleEliminationSubCostList(StrategyPtr, StrategyPtr, StrategyPtr, const CostPtr &, const CostPtrList &, + const CostPtrList &, const CostPtr &, const CostPtrList &, CostPtrList *); // Applying the Star Elimination in DP algorithm. Return the successive edges of this merged_op // NOTE: this elimination MUST be performed only when the above 5 operation cannot be applied. - std::vector EliminationStar(const OperatorInfoPtr& op); - void CreateStarEliminationCostList(std::vector&, const StrategyPtr&, const CostPtrList&, const CostPtrList&, - const StrategyPtr&, const CostPtrList&, CostPtrList*); - void CreateStarEliminationSubCostList(const StrategyPtr&, const CostPtrList&, const CostPtrList&, const StrategyPtr&, - const CostPtrList&, std::vector, CostPtrList&, CostPtrList&, - CostPtrList*); + std::vector EliminationStar(const OperatorInfoPtr &op); + void CreateStarEliminationCostList(std::vector &, const StrategyPtr &, const CostPtrList &, + const CostPtrList &, const StrategyPtr &, const CostPtrList &, CostPtrList *); + void CreateStarEliminationSubCostList(const StrategyPtr &, const CostPtrList &, const CostPtrList &, + const StrategyPtr &, const CostPtrList &, std::vector, + CostPtrList &, CostPtrList &, CostPtrList *); // When the input of a operator is neither a WEIGHT, nor a output of a subsequent operator involving WEIGHT, then // the memory cost can be resused. Status CalculateOpsMemoryCost(); @@ -186,16 +186,16 @@ class CostGraph { std::vector GetOperators() const { return ops_; } size_t GetNumPairs() const { return edges_.size(); } Status InitSelectedStrategy(); - OperatorInfoPtr FindTmpIdentityByParameterName(std::string&) const; + OperatorInfoPtr FindTmpIdentityByParameterName(std::string &) const; // When TmpIdentity is used by mulitple operators, the corresponding parameter's memory cost should be calculated only // once (instead of multiple times), this method is used to correct this. Status CorrectOpsMemoryCost(); // Needed by rec_parser - void add_inputs_tensor_name(const std::vector& inputs_tensor_name) { + void add_inputs_tensor_name(const std::vector &inputs_tensor_name) { inputs_tensor_name_list_.push_back(inputs_tensor_name); } const std::vector> get_inputs_tensor_name_list() const { return inputs_tensor_name_list_; } - void add_tuple_getitem(const std::pair& tuple_getitem) { + void add_tuple_getitem(const std::pair &tuple_getitem) { auto ret = tuple_getitem_list_.insert(tuple_getitem); if (ret.second == false) { MS_LOG(EXCEPTION) << "The insert item is already exist."; diff --git a/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.cc b/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.cc index 0192dce8b8..8ad8b46f32 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.cc +++ b/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.cc @@ -23,22 +23,22 @@ namespace mindspore { namespace parallel { -void OperatorCost::set_is_parameter(const std::vector& is_parameter) { is_parameter_ = is_parameter; } +void OperatorCost::set_is_parameter(const std::vector &is_parameter) { is_parameter_ = is_parameter; } -void OperatorCost::set_is_parameter_involve(const std::vector& is_parameter_inv) { +void OperatorCost::set_is_parameter_involve(const std::vector &is_parameter_inv) { is_parameter_involve_ = is_parameter_inv; } void OperatorCost::set_output_parameter_involve(int output_para) { output_parameter_involve_ = output_para; } -void OperatorCost::SetInputAndOutputTypeLength(const std::vector& input_lengths, - const std::vector& output_lengths) { +void OperatorCost::SetInputAndOutputTypeLength(const std::vector &input_lengths, + const std::vector &output_lengths) { inputs_type_lengths_ = input_lengths; outputs_type_lengths_ = output_lengths; } -double OperatorCost::GetMemoryCost(const std::vector& inputs, - const std::vector& outputs) const { +double OperatorCost::GetMemoryCost(const std::vector &inputs, + const std::vector &outputs) const { double result = 0.0; if (output_parameter_involve_ == 1) { // When this operator has multiple outputs, they all contributes to the memory. @@ -64,7 +64,7 @@ double OperatorCost::GetMemoryCost(const std::vector& inputs, } // return the per device communication cost in the forward phase. -double MatMulCost::GetForwardCommCost(const std::vector& inputs, const std::vector& outputs, +double MatMulCost::GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t) const { TensorInfo input0 = inputs[0]; TensorInfo output0 = outputs[0]; @@ -80,7 +80,7 @@ double MatMulCost::GetForwardCommCost(const std::vector& inputs, con } // return the per device communication cost in the forward phase. -double MatMulCost::GetBackwardCommCost(const std::vector& inputs, const std::vector&, +double MatMulCost::GetBackwardCommCost(const std::vector &inputs, const std::vector &, int32_t stage_id) const { // In backward phase, the communication cost is incurred only when tensor B is a Parameter and tensor B does not // fully utilize all devices @@ -107,8 +107,8 @@ double MatMulCost::GetBackwardCommCost(const std::vector& inputs, co // Return the per device computation cost in the forward phase. The cost is calculated according to the bytes // this operator uses -double MatMulCost::GetForwardComputationCost(const std::vector& inputs, - const std::vector& outputs, int32_t) const { +double MatMulCost::GetForwardComputationCost(const std::vector &inputs, + const std::vector &outputs, int32_t) const { // In forward phase, the compuatation cost = slice(A) + slice(B) + (0 or 1) allreduce(slice(C)) double result = 0.0; TensorInfo output0 = outputs[0]; @@ -126,7 +126,7 @@ double MatMulCost::GetForwardComputationCost(const std::vector& inpu // Return the per device computation cost in the forward phase. The cost is calculated according to the bytes // this operator uses -double MatMulCost::GetBackwardComputationCost(const std::vector& inputs, const std::vector&, +double MatMulCost::GetBackwardComputationCost(const std::vector &inputs, const std::vector &, int32_t stage_id) const { // In backward phase, the computation cost = (0 or 1) allreduce(slice(B)) double result = 0.0; @@ -151,14 +151,14 @@ double MatMulCost::GetBackwardComputationCost(const std::vector& inp } // Return the per device communication cost in the forward phase. -double ActivationCost::GetForwardCommCost(const std::vector&, const std::vector&, +double ActivationCost::GetForwardCommCost(const std::vector &, const std::vector &, int32_t) const { // ReLU is the element-wise operator, thus it does not need communication in the forward phase return 0.0; } // Return the per device communication cost in the backward phase. -double ActivationCost::GetBackwardCommCost(const std::vector& inputs, const std::vector&, +double ActivationCost::GetBackwardCommCost(const std::vector &inputs, const std::vector &, int32_t stage_id) const { double result = 0.0; if (is_parameter_[0]) { @@ -180,7 +180,7 @@ double ActivationCost::GetBackwardCommCost(const std::vector& inputs // Return the per device computation cost in the forward phase. The cost is calculated according to the bytes // this operator uses -double ActivationCost::GetForwardComputationCost(const std::vector& inputs, const std::vector&, +double ActivationCost::GetForwardComputationCost(const std::vector &inputs, const std::vector &, int32_t) const { TensorInfo input0_info = inputs[0]; Shape input0_slice_shape = input0_info.slice_shape(); @@ -189,19 +189,20 @@ double ActivationCost::GetForwardComputationCost(const std::vector& // Return the per device computation cost in the forward phase. The cost is calculated according to the bytes // this operator uses -double ActivationCost::GetBackwardComputationCost(const std::vector&, const std::vector&, +double ActivationCost::GetBackwardComputationCost(const std::vector &, const std::vector &, int32_t) const { return 0.0; } // Return the per device communication cost in the forward phase. -double SoftmaxCost::GetForwardCommCost(const std::vector&, const std::vector&, int32_t) const { +double SoftmaxCost::GetForwardCommCost(const std::vector &, const std::vector &, + int32_t) const { // In the forward phase, the communication cost = 0 return 0.0; } // Return the per device communication cost in the backward phase. -double SoftmaxCost::GetBackwardCommCost(const std::vector& inputs, const std::vector&, +double SoftmaxCost::GetBackwardCommCost(const std::vector &inputs, const std::vector &, int32_t stage_id) const { double result = 0.0; if (is_parameter_[0]) { @@ -223,7 +224,7 @@ double SoftmaxCost::GetBackwardCommCost(const std::vector& inputs, c // Return the per device computation cost in the forward phase. The cost is calculated according to the bytes // this operator uses -double SoftmaxCost::GetForwardComputationCost(const std::vector& inputs, const std::vector&, +double SoftmaxCost::GetForwardComputationCost(const std::vector &inputs, const std::vector &, int32_t) const { // In the forward phase, the computation cost = slice(A) TensorInfo input0 = inputs[0]; @@ -233,46 +234,47 @@ double SoftmaxCost::GetForwardComputationCost(const std::vector& inp // Return the per device computation cost in the forward phase. The cost is calculated according to the bytes // this operator uses -double SoftmaxCost::GetBackwardComputationCost(const std::vector&, - const std::vector&, int32_t) const { +double SoftmaxCost::GetBackwardComputationCost(const std::vector &, + const std::vector &, int32_t) const { return 0.0; } // return the per device communication cost in the forward phase. -double TmpIdentityCost::GetForwardCommCost(const std::vector&, - const std::vector&, int32_t) const { +double TmpIdentityCost::GetForwardCommCost(const std::vector &, + const std::vector &, int32_t) const { // Identity is the element-wise operator, thus it does not need communication in the forward phase return 0.0; } // return the per device communication cost in the backward phase. -double TmpIdentityCost::GetBackwardCommCost(const std::vector&, - const std::vector&, int32_t) const { +double TmpIdentityCost::GetBackwardCommCost(const std::vector &, + const std::vector &, int32_t) const { // Identity is the element-wise operator, thus it does not need communication in the backward phase return 0.0; } // Return the per device computation cost in the forward phase. The cost is calculated according to the bytes // this operator uses -double TmpIdentityCost::GetForwardComputationCost(const std::vector&, - const std::vector&, int32_t) const { +double TmpIdentityCost::GetForwardComputationCost(const std::vector &, + const std::vector &, int32_t) const { return 0.0; } // Return the per device computation cost in the backward phase. The cost is calculated according to the bytes // this operator uses -double TmpIdentityCost::GetBackwardComputationCost(const std::vector&, - const std::vector&, int32_t) const { +double TmpIdentityCost::GetBackwardComputationCost(const std::vector &, + const std::vector &, + int32_t) const { return 0.0; } // Return the per device PEAK memory cost contributed by this operator in a training iteration. -double TmpIdentityCost::GetMemoryCost(const std::vector&, const std::vector&) const { +double TmpIdentityCost::GetMemoryCost(const std::vector &, const std::vector &) const { return 0.0; } -double BatchParallelCost::GetForwardComputationCost(const std::vector& inputs, - const std::vector&, +double BatchParallelCost::GetForwardComputationCost(const std::vector &inputs, + const std::vector &, int32_t) const { double cost = 0.0; for (size_t i = 0; i < inputs.size(); ++i) { @@ -281,13 +283,13 @@ double BatchParallelCost::GetForwardComputationCost(const std::vector&, - const std::vector&, +double BatchParallelCost::GetBackwardComputationCost(const std::vector &, + const std::vector &, int32_t) const { return 0.0; } -double BatchParallelCost::GetBackwardCommCost(const std::vector& inputs, const std::vector&, +double BatchParallelCost::GetBackwardCommCost(const std::vector &inputs, const std::vector &, int32_t stage_id) const { double result = 0.0; CheckGlobalDeviceManager(); @@ -313,13 +315,13 @@ double BatchParallelCost::GetBackwardCommCost(const std::vector& inp return result; } // return the per device communication cost in the forward phase. -double PReLUCost::GetForwardCommCost(const std::vector&, const std::vector&, int32_t) const { +double PReLUCost::GetForwardCommCost(const std::vector &, const std::vector &, int32_t) const { // prelu does not need communication in the forward phase return 0.0; } // return the per device communication cost in the backward phase. -double PReLUCost::GetBackwardCommCost(const std::vector& inputs, const std::vector&, +double PReLUCost::GetBackwardCommCost(const std::vector &inputs, const std::vector &, int32_t stage_id) const { double result = 0.0; if (is_parameter_[1]) { @@ -342,7 +344,7 @@ double PReLUCost::GetBackwardCommCost(const std::vector& inputs, con // Return the per device computation cost in the forward phase. The cost is calculated according to the bytes // this operator uses -double PReLUCost::GetForwardComputationCost(const std::vector& inputs, const std::vector&, +double PReLUCost::GetForwardComputationCost(const std::vector &inputs, const std::vector &, int32_t) const { // In forward phase, the computation cost = slice(A) + slice(B) Shape input0_slice_shape = inputs[0].slice_shape(); @@ -354,8 +356,8 @@ double PReLUCost::GetForwardComputationCost(const std::vector& input // Return the per device computation cost in the backward phase. The cost is calculated according to the bytes // this operator uses -double PReLUCost::GetBackwardComputationCost(const std::vector& inputs, - const std::vector&, +double PReLUCost::GetBackwardComputationCost(const std::vector &inputs, + const std::vector &, int32_t stage_id) const { // In backward phase, the computation cost = (0 or 1) allreduce(slice(B)) double result = 0.0; @@ -380,20 +382,21 @@ double PReLUCost::GetBackwardComputationCost(const std::vector&, const std::vector&, int32_t) const { +double OneHotCost::GetForwardCommCost(const std::vector &, const std::vector &, int32_t) const { // onehot does not need communication in the forward phase return 0.0; } // return the per device communication cost in the backward phase. -double OneHotCost::GetBackwardCommCost(const std::vector&, const std::vector&, int32_t) const { +double OneHotCost::GetBackwardCommCost(const std::vector &, const std::vector &, + int32_t) const { // onehot does not need communication in the backward phase return 0.0; } // Return the per device computation cost in the forward phase. The cost is calculated according to the bytes // this operator uses -double OneHotCost::GetForwardComputationCost(const std::vector& inputs, const std::vector&, +double OneHotCost::GetForwardComputationCost(const std::vector &inputs, const std::vector &, int32_t) const { // In onehot's forward phase, the computation cost = slice(A) Shape input0_slice_shape = inputs[0].slice_shape(); @@ -402,29 +405,29 @@ double OneHotCost::GetForwardComputationCost(const std::vector& inpu // Return the per device computation cost in the backward phase. The cost is calculated according to the bytes // this operator uses -double OneHotCost::GetBackwardComputationCost(const std::vector&, const std::vector&, +double OneHotCost::GetBackwardComputationCost(const std::vector &, const std::vector &, int32_t) const { return 0.0; } // return the per device communication cost in the forward phase. -double SoftmaxCrossEntropyWithLogitsCost::GetForwardCommCost(const std::vector&, - const std::vector&, int32_t) const { +double SoftmaxCrossEntropyWithLogitsCost::GetForwardCommCost(const std::vector &, + const std::vector &, int32_t) const { // SoftmaxCrossEntropyWithLogitsCost does not need communication in the forward phase return 0.0; } // return the per device communication cost in the backward phase. -double SoftmaxCrossEntropyWithLogitsCost::GetBackwardCommCost(const std::vector&, - const std::vector&, int32_t) const { +double SoftmaxCrossEntropyWithLogitsCost::GetBackwardCommCost(const std::vector &, + const std::vector &, int32_t) const { // SoftmaxCrossEntropyWithLogitsCost does not need communication in the backward phase return 0.0; } // Return the per device computation cost in the forward phase. The cost is calculated according to the bytes // this operator uses -double SoftmaxCrossEntropyWithLogitsCost::GetForwardComputationCost(const std::vector& inputs, - const std::vector&, int32_t) const { +double SoftmaxCrossEntropyWithLogitsCost::GetForwardComputationCost(const std::vector &inputs, + const std::vector &, int32_t) const { // In forward phase, the computation cost = slice(A) + slice(B) Shape input0_slice_shape = inputs[0].slice_shape(); Shape input1_slice_shape = inputs[1].slice_shape(); @@ -435,13 +438,13 @@ double SoftmaxCrossEntropyWithLogitsCost::GetForwardComputationCost(const std::v // Return the per device computation cost in the backward phase. The cost is calculated according to the bytes // this operator uses -double SoftmaxCrossEntropyWithLogitsCost::GetBackwardComputationCost(const std::vector&, - const std::vector&, int32_t) const { +double SoftmaxCrossEntropyWithLogitsCost::GetBackwardComputationCost(const std::vector &, + const std::vector &, int32_t) const { return 0.0; } // return the per device communication cost in the forward phase. -double ReshapeCost::GetForwardCommCost(const std::vector& inputs, const std::vector& outputs, +double ReshapeCost::GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const { CheckGlobalDeviceManager(); MS_EXCEPTION_IF_NULL(g_device_manager); @@ -457,7 +460,7 @@ double ReshapeCost::GetForwardCommCost(const std::vector& inputs, co } // return the per device communication cost in the backward phase. -double ReshapeCost::GetBackwardCommCost(const std::vector& inputs, const std::vector&, +double ReshapeCost::GetBackwardCommCost(const std::vector &inputs, const std::vector &, int32_t stage_id) const { double result = 0.0; if (is_parameter_[0]) { @@ -479,8 +482,8 @@ double ReshapeCost::GetBackwardCommCost(const std::vector& inputs, c // Return the per device computation cost in the forward phase. The cost is calculated according to the bytes // this operator uses -double ReshapeCost::GetForwardComputationCost(const std::vector& inputs, - const std::vector& outputs, int32_t stage_id) const { +double ReshapeCost::GetForwardComputationCost(const std::vector &inputs, + const std::vector &outputs, int32_t stage_id) const { CheckGlobalDeviceManager(); MS_EXCEPTION_IF_NULL(g_device_manager); RankList dev_list = g_device_manager->GetDeviceListByStageId(stage_id); @@ -496,12 +499,12 @@ double ReshapeCost::GetForwardComputationCost(const std::vector& inp // Return the per device computation cost in the backward phase. The cost is calculated according to the bytes // this operator uses -double ReshapeCost::GetBackwardComputationCost(const std::vector&, - const std::vector&, int32_t) const { +double ReshapeCost::GetBackwardComputationCost(const std::vector &, + const std::vector &, int32_t) const { return 0.0; } -double ArithmeticCost::GetForwardComputationCost(const std::vector& inputs, const std::vector&, +double ArithmeticCost::GetForwardComputationCost(const std::vector &inputs, const std::vector &, int32_t) const { double result; result = ListProduct(inputs[0].slice_shape()) * static_cast(inputs_type_lengths_[0]) + @@ -509,8 +512,8 @@ double ArithmeticCost::GetForwardComputationCost(const std::vector& return result; } -double ArithmeticCost::GetBackwardComputationCost(const std::vector& inputs, const std::vector&, - int32_t stage_id) const { +double ArithmeticCost::GetBackwardComputationCost(const std::vector &inputs, + const std::vector &, int32_t stage_id) const { double result = 0.0; CheckGlobalDeviceManager(); MS_EXCEPTION_IF_NULL(g_device_manager); @@ -544,7 +547,7 @@ double ArithmeticCost::GetBackwardComputationCost(const std::vector& return result; } -double ArithmeticCost::GetBackwardCommCost(const std::vector& inputs, const std::vector&, +double ArithmeticCost::GetBackwardCommCost(const std::vector &inputs, const std::vector &, int32_t stage_id) const { double result = 0.0; CheckGlobalDeviceManager(); @@ -580,7 +583,7 @@ double ArithmeticCost::GetBackwardCommCost(const std::vector& inputs return result; } -bool IsDataParallel(const Shape& shape, const Shape& slice_shape, int32_t stage_id) { +bool IsDataParallel(const Shape &shape, const Shape &slice_shape, int32_t stage_id) { CheckGlobalDeviceManager(); MS_EXCEPTION_IF_NULL(g_device_manager); auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); @@ -589,8 +592,8 @@ bool IsDataParallel(const Shape& shape, const Shape& slice_shape, int32_t stage_ return (total_device_num == IntToSize(strategy0)); } -double ReduceMethodCost::GetForwardCommCost(const std::vector& inputs, - const std::vector& outputs, int32_t stage_id) const { +double ReduceMethodCost::GetForwardCommCost(const std::vector &inputs, + const std::vector &outputs, int32_t stage_id) const { double result = 0.0; TensorInfo input0 = inputs[0]; TensorInfo output0 = outputs[0]; @@ -611,7 +614,7 @@ double ReduceMethodCost::GetForwardCommCost(const std::vector& input return result; } -double ReduceMethodCost::GetBackwardCommCost(const std::vector& inputs, const std::vector&, +double ReduceMethodCost::GetBackwardCommCost(const std::vector &inputs, const std::vector &, int32_t stage_id) const { double result = 0.0; if (is_parameter_[0]) { @@ -634,8 +637,8 @@ double ReduceMethodCost::GetBackwardCommCost(const std::vector& inpu return result; } -double ReduceMethodCost::GetForwardComputationCost(const std::vector& inputs, - const std::vector& outputs, int32_t stage_id) const { +double ReduceMethodCost::GetForwardComputationCost(const std::vector &inputs, + const std::vector &outputs, int32_t stage_id) const { double result = 0.0; TensorInfo input0 = inputs[0]; TensorInfo output0 = outputs[0]; @@ -656,8 +659,8 @@ double ReduceMethodCost::GetForwardComputationCost(const std::vector return result; } -double ReduceMeanCost::GetForwardComputationCost(const std::vector& inputs, - const std::vector& outputs, int32_t stage_id) const { +double ReduceMeanCost::GetForwardComputationCost(const std::vector &inputs, + const std::vector &outputs, int32_t stage_id) const { double result = 0.0; TensorInfo input0 = inputs[0]; TensorInfo output0 = outputs[0]; @@ -678,7 +681,7 @@ double ReduceMeanCost::GetForwardComputationCost(const std::vector& return result; } -double DropOutCost::GetForwardComputationCost(const std::vector& inputs, const std::vector&, +double DropOutCost::GetForwardComputationCost(const std::vector &inputs, const std::vector &, int32_t) const { if (inputs.empty()) { return 0.0; @@ -689,13 +692,14 @@ double DropOutCost::GetForwardComputationCost(const std::vector& inp } // return the per device communication cost in the forward phase. -double GatherV2Cost::GetForwardCommCost(const std::vector&, const std::vector&, int32_t) const { +double GatherV2Cost::GetForwardCommCost(const std::vector &, const std::vector &, + int32_t) const { // GatherV2Cost does not need communication in the forward phase return 0.0; } // return the per device communication cost in the backward phase. -double GatherV2Cost::GetBackwardCommCost(const std::vector& inputs, const std::vector&, +double GatherV2Cost::GetBackwardCommCost(const std::vector &inputs, const std::vector &, int32_t stage_id) const { double result = 0.0; CheckGlobalDeviceManager(); @@ -721,7 +725,7 @@ double GatherV2Cost::GetBackwardCommCost(const std::vector& inputs, return result; } -double GatherV2Cost::GetForwardComputationCost(const std::vector& inputs, const std::vector&, +double GatherV2Cost::GetForwardComputationCost(const std::vector &inputs, const std::vector &, int32_t) const { // In forward phase, the computation cost = slice(A) + slice(B) Shape input0_slice_shape = inputs[0].slice_shape(); @@ -731,12 +735,12 @@ double GatherV2Cost::GetForwardComputationCost(const std::vector& in return result; } -double GatherV2Cost::GetBackwardComputationCost(const std::vector&, const std::vector&, +double GatherV2Cost::GetBackwardComputationCost(const std::vector &, const std::vector &, int32_t) const { return 0.0; } -double LayerNormCost::GetBackwardCommCost(const std::vector& inputs, const std::vector&, +double LayerNormCost::GetBackwardCommCost(const std::vector &inputs, const std::vector &, int32_t stage_id) const { double result = 0.0; if (is_parameter_.size() != inputs.size()) { @@ -769,7 +773,7 @@ double LayerNormCost::GetBackwardCommCost(const std::vector& inputs, return result; } -double LayerNormCost::GetForwardComputationCost(const std::vector& inputs, const std::vector&, +double LayerNormCost::GetForwardComputationCost(const std::vector &inputs, const std::vector &, int32_t) const { double result = 0.0; if (inputs_type_lengths_.size() != inputs.size()) { diff --git a/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.h b/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.h index 37b054aa98..a243f8adfa 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.h +++ b/mindspore/ccsrc/parallel/auto_parallel/operator_costmodel.h @@ -63,31 +63,31 @@ class OperatorCost { } virtual ~OperatorCost() = default; - void set_is_parameter(const std::vector& is_parameter); - void set_is_parameter_involve(const std::vector&); + void set_is_parameter(const std::vector &is_parameter); + void set_is_parameter_involve(const std::vector &); void set_output_parameter_involve(int); - void SetInputAndOutputTypeLength(const std::vector& input_lengths, const std::vector& output_lengths); + void SetInputAndOutputTypeLength(const std::vector &input_lengths, const std::vector &output_lengths); std::vector inputs_type_lengths() const { return inputs_type_lengths_; } std::vector outputs_type_lengths() const { return outputs_type_lengths_; } // per device communication cost - virtual double GetCommCost(const std::vector& inputs, const std::vector& outputs, + virtual double GetCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const = 0; - virtual double GetForwardCommCost(const std::vector& inputs, const std::vector& outputs, + virtual double GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const = 0; - virtual double GetBackwardCommCost(const std::vector& inputs, const std::vector& outputs, + virtual double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const = 0; // per device computation cost - virtual double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + virtual double GetComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const = 0; - virtual double GetForwardComputationCost(const std::vector& inputs, - const std::vector& outputs, int32_t stage_id) const = 0; - virtual double GetBackwardComputationCost(const std::vector& inputs, - const std::vector& outputs, int32_t stage_id) const = 0; + virtual double GetForwardComputationCost(const std::vector &inputs, + const std::vector &outputs, int32_t stage_id) const = 0; + virtual double GetBackwardComputationCost(const std::vector &inputs, + const std::vector &outputs, int32_t stage_id) const = 0; // per device PEAK memory cost in a training iteration // Typically, the PEAK memory cost contributed by an operator is its output (if the output is parameter-invovled), // plus necessary inputs. - virtual double GetMemoryCost(const std::vector& inputs, const std::vector& outputs) const; + virtual double GetMemoryCost(const std::vector &inputs, const std::vector &outputs) const; protected: // For each input in 'inputs_', a bool variable is true if the corresponding one is a parameter or a output of @@ -113,23 +113,23 @@ class MatMulCost : public OperatorCost { ~MatMulCost() override = default; // per device communication cost - double GetCommCost(const std::vector& inputs, const std::vector& outputs, + double GetCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } - double GetForwardCommCost(const std::vector& inputs, const std::vector& outputs, + double GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetBackwardCommCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; // per device computation cost - double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } - double GetForwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetBackwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; }; using MatMulCostPtr = std::shared_ptr; @@ -140,21 +140,21 @@ class ActivationCost : public OperatorCost { ActivationCost() : OperatorCost(false) {} ~ActivationCost() override = default; - double GetCommCost(const std::vector& inputs, const std::vector& outputs, + double GetCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } - double GetForwardCommCost(const std::vector& inputs, const std::vector& outputs, + double GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetBackwardCommCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } - double GetForwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetBackwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; }; using ActivationCostPtr = std::shared_ptr; @@ -167,21 +167,21 @@ class SoftmaxCost : public OperatorCost { SoftmaxCost() : OperatorCost(false) {} ~SoftmaxCost() override = default; - double GetCommCost(const std::vector& inputs, const std::vector& outputs, + double GetCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } - double GetForwardCommCost(const std::vector& inputs, const std::vector& outputs, + double GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetBackwardCommCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } - double GetForwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetBackwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t) const override; }; using SoftmaxCostPtr = std::shared_ptr; @@ -192,24 +192,24 @@ class TmpIdentityCost : public OperatorCost { TmpIdentityCost() : OperatorCost(false) {} ~TmpIdentityCost() override = default; - double GetCommCost(const std::vector& inputs, const std::vector& outputs, + double GetCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } - double GetForwardCommCost(const std::vector& inputs, const std::vector& outputs, + double GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetBackwardCommCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } - double GetForwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetBackwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; // per device PEAK memory cost in a training iteration - double GetMemoryCost(const std::vector& inputs, const std::vector& outputs) const override; + double GetMemoryCost(const std::vector &inputs, const std::vector &outputs) const override; }; using TmpIdentityCostPtr = std::shared_ptr; @@ -219,21 +219,21 @@ class BatchParallelCost : public OperatorCost { BatchParallelCost() : OperatorCost(false) {} ~BatchParallelCost() override = default; - double GetCommCost(const std::vector& inputs, const std::vector& outputs, + double GetCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } - double GetForwardCommCost(const std::vector&, const std::vector&, int32_t) const override { + double GetForwardCommCost(const std::vector &, const std::vector &, int32_t) const override { return 0.0; } - double GetBackwardCommCost(const std::vector&, const std::vector&, int32_t) const override; - double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardCommCost(const std::vector &, const std::vector &, int32_t) const override; + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } - double GetForwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetBackwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; }; using BatchParallelCostPtr = std::shared_ptr; @@ -244,30 +244,30 @@ class VirtualDatasetCost : public OperatorCost { VirtualDatasetCost() : OperatorCost(false) {} ~VirtualDatasetCost() override = default; - double GetCommCost(const std::vector& inputs, const std::vector& outputs, + double GetCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } - double GetForwardCommCost(const std::vector&, const std::vector&, int32_t) const override { + double GetForwardCommCost(const std::vector &, const std::vector &, int32_t) const override { return 0.0; } - double GetBackwardCommCost(const std::vector&, const std::vector&, int32_t) const override { + double GetBackwardCommCost(const std::vector &, const std::vector &, int32_t) const override { return 0.0; } - double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } - double GetForwardComputationCost(const std::vector&, const std::vector&, + double GetForwardComputationCost(const std::vector &, const std::vector &, int32_t) const override { return 0.0; } - double GetBackwardComputationCost(const std::vector&, const std::vector&, + double GetBackwardComputationCost(const std::vector &, const std::vector &, int32_t) const override { return 0.0; } // per device PEAK memory cost in a training iteration - double GetMemoryCost(const std::vector& inputs, const std::vector& outputs) const override { + double GetMemoryCost(const std::vector &inputs, const std::vector &outputs) const override { return 0.0; } }; @@ -279,27 +279,27 @@ class GeneratorBaseCost : public OperatorCost { GeneratorBaseCost() : OperatorCost(false) {} ~GeneratorBaseCost() override = default; - double GetCommCost(const std::vector& inputs, const std::vector& outputs, + double GetCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } - double GetForwardCommCost(const std::vector&, const std::vector&, int32_t) const override { + double GetForwardCommCost(const std::vector &, const std::vector &, int32_t) const override { return 0.0; } - double GetBackwardCommCost(const std::vector&, const std::vector&, int32_t) const override { + double GetBackwardCommCost(const std::vector &, const std::vector &, int32_t) const override { return 0.0; } - double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } // Inputs vector is empty for generator ops. - double GetForwardComputationCost(const std::vector&, const std::vector&, + double GetForwardComputationCost(const std::vector &, const std::vector &, int32_t) const override { return 0.0; } // Generator ops don't have backward steps. - double GetBackwardComputationCost(const std::vector&, const std::vector&, + double GetBackwardComputationCost(const std::vector &, const std::vector &, int32_t) const override { return 0.0; } @@ -313,23 +313,23 @@ class PReLUCost : public OperatorCost { ~PReLUCost() override = default; // per device communication cost - double GetCommCost(const std::vector& inputs, const std::vector& outputs, + double GetCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } - double GetForwardCommCost(const std::vector& inputs, const std::vector& outputs, + double GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetBackwardCommCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; // per device computation cost - double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } - double GetForwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetBackwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; }; using PReLUCostPtr = std::shared_ptr; @@ -341,23 +341,23 @@ class OneHotCost : public OperatorCost { ~OneHotCost() override = default; // per device communication cost - double GetCommCost(const std::vector& inputs, const std::vector& outputs, + double GetCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } - double GetForwardCommCost(const std::vector& inputs, const std::vector& outputs, + double GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetBackwardCommCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; // per device computation cost - double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } - double GetForwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetBackwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; }; using OneHotCostPtr = std::shared_ptr; @@ -369,23 +369,23 @@ class SoftmaxCrossEntropyWithLogitsCost : public OperatorCost { ~SoftmaxCrossEntropyWithLogitsCost() override = default; // per device communication cost - double GetCommCost(const std::vector& inputs, const std::vector& outputs, + double GetCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } - double GetForwardCommCost(const std::vector& inputs, const std::vector& outputs, + double GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetBackwardCommCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; // per device computation cost - double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } - double GetForwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetBackwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; }; using SoftmaxCrossEntropyWithLogitsCostPtr = std::shared_ptr; @@ -398,27 +398,27 @@ class ReshapeCost : public OperatorCost { ~ReshapeCost() override = default; // per device communication cost - double GetCommCost(const std::vector& inputs, const std::vector& outputs, + double GetCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } - double GetForwardCommCost(const std::vector& inputs, const std::vector& outputs, + double GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetBackwardCommCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; // per device computation cost - double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } - double GetForwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetBackwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; }; using ReshapeCostPtr = std::shared_ptr; @@ -429,22 +429,22 @@ class ArithmeticCost : public OperatorCost { ArithmeticCost() : OperatorCost(false) {} ~ArithmeticCost() override = default; - double GetCommCost(const std::vector& inputs, const std::vector& outputs, + double GetCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } - double GetForwardCommCost(const std::vector&, const std::vector&, int32_t) const override { + double GetForwardCommCost(const std::vector &, const std::vector &, int32_t) const override { return 0.0; } - double GetBackwardCommCost(const std::vector&, const std::vector&, int32_t) const override; + double GetBackwardCommCost(const std::vector &, const std::vector &, int32_t) const override; - double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } - double GetForwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetBackwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; }; using ArithmeticCostPtr = std::shared_ptr; @@ -457,21 +457,21 @@ class ReduceMethodCost : public OperatorCost { ReduceMethodCost() : OperatorCost(true) {} ~ReduceMethodCost() override = default; - double GetCommCost(const std::vector& inputs, const std::vector& outputs, + double GetCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } - double GetForwardCommCost(const std::vector&, const std::vector&, + double GetForwardCommCost(const std::vector &, const std::vector &, int32_t stage_id) const override; - double GetBackwardCommCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } - double GetForwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetBackwardComputationCost(const std::vector&, const std::vector&, + double GetBackwardComputationCost(const std::vector &, const std::vector &, int32_t) const override { return 0.0; } @@ -488,7 +488,7 @@ class ReduceMeanCost : public ReduceMethodCost { ReduceMeanCost() : ReduceMethodCost(true) {} ~ReduceMeanCost() override = default; - double GetForwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; }; using ReduceMeanCostPtr = std::shared_ptr; @@ -499,27 +499,27 @@ class GetNextCost : public OperatorCost { GetNextCost() : OperatorCost(false) {} ~GetNextCost() override = default; - double GetCommCost(const std::vector& inputs, const std::vector& outputs, + double GetCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } - double GetForwardCommCost(const std::vector&, const std::vector&, int32_t) const override { + double GetForwardCommCost(const std::vector &, const std::vector &, int32_t) const override { return 0.0; } - double GetBackwardCommCost(const std::vector&, const std::vector&, int32_t) const override { + double GetBackwardCommCost(const std::vector &, const std::vector &, int32_t) const override { return 0.0; } - double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } // Inputs vector is empty for generator ops. - double GetForwardComputationCost(const std::vector&, const std::vector&, + double GetForwardComputationCost(const std::vector &, const std::vector &, int32_t) const override { return 0.0; } // Generator ops don't have backward steps. - double GetBackwardComputationCost(const std::vector&, const std::vector&, + double GetBackwardComputationCost(const std::vector &, const std::vector &, int32_t) const override { return 0.0; } @@ -532,23 +532,23 @@ class DropOutCost : public OperatorCost { DropOutCost() : OperatorCost(true) {} ~DropOutCost() override = default; - double GetCommCost(const std::vector& inputs, const std::vector& outputs, + double GetCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } - double GetForwardCommCost(const std::vector&, const std::vector&, int32_t) const override { + double GetForwardCommCost(const std::vector &, const std::vector &, int32_t) const override { return 0.0; } - double GetBackwardCommCost(const std::vector&, const std::vector&, int32_t) const override { + double GetBackwardCommCost(const std::vector &, const std::vector &, int32_t) const override { return 0.0; } - double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } - double GetForwardComputationCost(const std::vector&, const std::vector&, + double GetForwardComputationCost(const std::vector &, const std::vector &, int32_t) const override; - double GetBackwardComputationCost(const std::vector&, const std::vector&, + double GetBackwardComputationCost(const std::vector &, const std::vector &, int32_t) const override { return 0.0; } @@ -562,21 +562,21 @@ class LayerNormCost : public OperatorCost { LayerNormCost() : OperatorCost(true) {} ~LayerNormCost() override = default; - double GetCommCost(const std::vector& inputs, const std::vector& outputs, + double GetCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } - double GetForwardCommCost(const std::vector&, const std::vector&, int32_t) const override { + double GetForwardCommCost(const std::vector &, const std::vector &, int32_t) const override { return 0.0; } - double GetBackwardCommCost(const std::vector&, const std::vector&, int32_t) const override; - double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardCommCost(const std::vector &, const std::vector &, int32_t) const override; + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } - double GetForwardComputationCost(const std::vector&, const std::vector&, + double GetForwardComputationCost(const std::vector &, const std::vector &, int32_t) const override; - double GetBackwardComputationCost(const std::vector&, const std::vector&, + double GetBackwardComputationCost(const std::vector &, const std::vector &, int32_t) const override { return 0.0; } @@ -590,21 +590,21 @@ class GatherV2Cost : public OperatorCost { GatherV2Cost() : OperatorCost(true) {} ~GatherV2Cost() override = default; - double GetCommCost(const std::vector& inputs, const std::vector& outputs, + double GetCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardCommCost(inputs, outputs, stage_id) + GetBackwardCommCost(inputs, outputs, stage_id); } - double GetForwardCommCost(const std::vector& inputs, const std::vector& outputs, + double GetForwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetBackwardCommCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardCommCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override { return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id); } - double GetForwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetForwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t stage_id) const override; - double GetBackwardComputationCost(const std::vector& inputs, const std::vector& outputs, + double GetBackwardComputationCost(const std::vector &inputs, const std::vector &outputs, int32_t) const override; }; diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc index 44d3642b9c..6b438cb670 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc +++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.cc @@ -35,7 +35,7 @@ const TensorParam MakeTensor(int n, int c, int h, int w) { new_tensor.tensor_shape.shape_c = c; new_tensor.tensor_shape.shape_h = h; new_tensor.tensor_shape.shape_w = w; - const TensorParam& tensor = new_tensor; + const TensorParam &tensor = new_tensor; return tensor; } @@ -71,7 +71,7 @@ Graph::NodeType MakeNewOperator(std::vector> ops, return NewOp; } -TensorParam Fill2DTensor(const std::vector>& ops, const size_t iter_ops, +TensorParam Fill2DTensor(const std::vector> &ops, const size_t iter_ops, Graph::NodeType NewTensor) { if (NewTensor.apply.op_type == OperatorType::kRecMatMul) { auto attrs = ops[iter_ops]->attrs(); @@ -94,7 +94,7 @@ TensorParam Fill2DTensor(const std::vector>& ops, return NewTensor.tensor_parm; } -OperatorRec CompleteOperatorInputs(const std::vector>& ops, const size_t iter_ops, +OperatorRec CompleteOperatorInputs(const std::vector> &ops, const size_t iter_ops, Graph::NodeType NewTensor) { for (size_t iter_input_tensors = 0; iter_input_tensors < ops[iter_ops]->inputs_tensor_info().size(); iter_input_tensors++) { @@ -118,7 +118,7 @@ OperatorRec CompleteOperatorInputs(const std::vector>& ops, const size_t iter_ops, +TensorParam Complete2DInputs(const std::vector> &ops, const size_t iter_ops, const size_t iter_input_tensors, Graph::NodeType NewTensor) { if (NewTensor.apply.op_type == OperatorType::kRecMatMul) { auto attrs = ops[iter_ops]->attrs(); @@ -145,8 +145,8 @@ TensorParam Complete2DInputs(const std::vector>& o return NewTensor.apply.arguments[iter_input_tensors]; } -std::shared_ptr ParseGraph(const std::vector>& ops, - const std::vector>& input_tensor_names) { +std::shared_ptr ParseGraph(const std::vector> &ops, + const std::vector> &input_tensor_names) { std::shared_ptr graph(new Graph); if (ops.size() > SIZE_MAX / 2) { MS_LOG(EXCEPTION) << "Total number of operators is bigger than " << SIZE_MAX / 2; @@ -161,7 +161,7 @@ std::shared_ptr ParseGraph(const std::vector>& input_tensor_names, std::shared_ptr graph) { +void MakeEdge(const std::vector> &input_tensor_names, std::shared_ptr graph) { for (size_t iter_i = 0; iter_i < input_tensor_names.size(); iter_i++) { for (size_t iter_j = 1; iter_j < input_tensor_names[iter_i].size(); iter_j++) { size_t head_node_index = GetIndexInInputTensorNames(input_tensor_names, input_tensor_names[iter_i][iter_j]); @@ -173,8 +173,8 @@ void MakeEdge(const std::vector>& input_tensor_names, s } } -size_t GetIndexInInputTensorNames(const std::vector>& input_tensor_name, - const std::string& input_name) { +size_t GetIndexInInputTensorNames(const std::vector> &input_tensor_name, + const std::string &input_name) { for (size_t index = 0; index < input_tensor_name.size(); index++) { if (input_tensor_name[index][0] == input_name) { return index; diff --git a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h index 0d719c33d8..ae50ced418 100644 --- a/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h +++ b/mindspore/ccsrc/parallel/auto_parallel/rec_core/rec_parse_graph.h @@ -45,22 +45,22 @@ const TensorParam MakeTensor(int n, int c, int h, int w); Graph::NodeType MakeNewOperator(std::vector> ops, size_t iter_ops); -TensorParam Fill2DTensor(const std::vector>& ops, const size_t iter_ops, +TensorParam Fill2DTensor(const std::vector> &ops, const size_t iter_ops, Graph::NodeType NewTensor); -OperatorRec CompleteOperatorInputs(const std::vector>& ops, const size_t iter_ops, +OperatorRec CompleteOperatorInputs(const std::vector> &ops, const size_t iter_ops, Graph::NodeType NewTensor); -TensorParam Complete2DInputs(const std::vector>& ops, const size_t iter_ops, +TensorParam Complete2DInputs(const std::vector> &ops, const size_t iter_ops, const size_t iter_input_tensor, Graph::NodeType NewTensor); -std::shared_ptr ParseGraph(const std::vector>& ops, - const std::vector>& input_tensor_names); +std::shared_ptr ParseGraph(const std::vector> &ops, + const std::vector> &input_tensor_names); -void MakeEdge(const std::vector>& input_tensor_names, std::shared_ptr graph); +void MakeEdge(const std::vector> &input_tensor_names, std::shared_ptr graph); -size_t GetIndexInInputTensorNames(const std::vector>& input_tensor_names, - const std::string& input_name); +size_t GetIndexInInputTensorNames(const std::vector> &input_tensor_names, + const std::string &input_name); } // namespace parallel } // namespace mindspore #endif // PARALLEL_AUTO_PARALLEL_REC_PARSE_GRAPH_H_ diff --git a/mindspore/ccsrc/parallel/context.cc b/mindspore/ccsrc/parallel/context.cc index ab216cb22c..bc4aca896b 100644 --- a/mindspore/ccsrc/parallel/context.cc +++ b/mindspore/ccsrc/parallel/context.cc @@ -73,11 +73,11 @@ void ParallelContext::set_cast_before_mirror(bool cast_before_mirror) { cast_bef void ParallelContext::set_loss_repeated_mean(bool loss_repeated_mean) { loss_repeated_mean_ = loss_repeated_mean; } -void ParallelContext::set_communication_backend(const std::string& communication_backend) { +void ParallelContext::set_communication_backend(const std::string &communication_backend) { communication_backend_ = communication_backend; } -bool ParallelContext::set_parallel_mode(const std::string& parallel_mode) { +bool ParallelContext::set_parallel_mode(const std::string ¶llel_mode) { auto iter = std::find(PARALLEL_MODE_LIST.begin(), PARALLEL_MODE_LIST.end(), parallel_mode); if (iter == PARALLEL_MODE_LIST.end()) { MS_LOG(INFO) << "Invalid parallel mode:" << parallel_mode; @@ -87,7 +87,7 @@ bool ParallelContext::set_parallel_mode(const std::string& parallel_mode) { return true; } -bool ParallelContext::set_strategy_search_mode(const std::string& strategy_search_mode) { +bool ParallelContext::set_strategy_search_mode(const std::string &strategy_search_mode) { auto iter = std::find(STRATEGY_SEARCH_MODE_LIST.begin(), STRATEGY_SEARCH_MODE_LIST.end(), strategy_search_mode); if (iter == STRATEGY_SEARCH_MODE_LIST.end()) { MS_LOG(INFO) << "Invalid strategy search mode mode: " << strategy_search_mode; diff --git a/mindspore/ccsrc/parallel/context.h b/mindspore/ccsrc/parallel/context.h index 265f5bac71..64261cb964 100644 --- a/mindspore/ccsrc/parallel/context.h +++ b/mindspore/ccsrc/parallel/context.h @@ -40,8 +40,8 @@ constexpr char RECURSIVE_PROGRAMMING[] = "recursive_programming"; class ParallelContext { public: ~ParallelContext() = default; - ParallelContext(const ParallelContext&) = delete; - ParallelContext& operator=(const ParallelContext&) = delete; + ParallelContext(const ParallelContext &) = delete; + ParallelContext &operator=(const ParallelContext &) = delete; static std::shared_ptr GetInstance(); @@ -60,13 +60,13 @@ class ParallelContext { void set_global_rank(int32_t global_rank); int32_t global_rank() const { return global_rank_; } - void set_communication_backend(const std::string& communication_backend); + void set_communication_backend(const std::string &communication_backend); std::string communication_backend() const { return communication_backend_; } - bool set_parallel_mode(const std::string& parallel_mode); + bool set_parallel_mode(const std::string ¶llel_mode); std::string parallel_mode() const { return parallel_mode_; } - bool set_strategy_search_mode(const std::string& strategy_search_mode); + bool set_strategy_search_mode(const std::string &strategy_search_mode); std::string strategy_search_mode() const { return strategy_search_mode_; } void set_parameter_broadcast(bool parameter_broadcast); diff --git a/mindspore/ccsrc/parallel/costmodel_context.h b/mindspore/ccsrc/parallel/costmodel_context.h index 23c9f7cc8d..9937483051 100644 --- a/mindspore/ccsrc/parallel/costmodel_context.h +++ b/mindspore/ccsrc/parallel/costmodel_context.h @@ -28,8 +28,8 @@ namespace parallel { class CostModelContext { public: ~CostModelContext() = default; - CostModelContext(const CostModelContext&) = delete; - CostModelContext& operator=(const CostModelContext&) = delete; + CostModelContext(const CostModelContext &) = delete; + CostModelContext &operator=(const CostModelContext &) = delete; void ResetCostModel(); void ResetAlgoParameters(); diff --git a/mindspore/ccsrc/parallel/device_manager.cc b/mindspore/ccsrc/parallel/device_manager.cc index 0b34cedc00..45628bec65 100644 --- a/mindspore/ccsrc/parallel/device_manager.cc +++ b/mindspore/ccsrc/parallel/device_manager.cc @@ -30,15 +30,15 @@ namespace mindspore { namespace parallel { DeviceManagerPtr g_device_manager = nullptr; -Stage::Stage(const std::vector& devices, int num, int rank) +Stage::Stage(const std::vector &devices, int num, int rank) : devices_(devices), number_(num), rank_(rank) { gm_ = GroupManager(); } // NOTE: '-1' indicates ERROR -int Stage::global_rank(Group* g) const { return ((g == nullptr) ? rank_ : -1); } +int Stage::global_rank(Group *g) const { return ((g == nullptr) ? rank_ : -1); } -bool InitDevice(int32_t device_num, int32_t global_rank, const std::string& backend) { +bool InitDevice(int32_t device_num, int32_t global_rank, const std::string &backend) { if (device_num <= 0) { MS_LOG(ERROR) << "'device_num' must be positive."; return false; @@ -87,7 +87,7 @@ void CheckGlobalDeviceManager() { } } -int32_t GetListMemberByIndex(size_t index, const RankList& devices) { +int32_t GetListMemberByIndex(size_t index, const RankList &devices) { size_t i = 0; int32_t result = 0; if ((devices.empty()) || (index >= devices.size())) { @@ -104,7 +104,7 @@ int32_t GetListMemberByIndex(size_t index, const RankList& devices) { return result; } -std::shared_ptr GetListMemberByIndex(size_t index, const std::vector>& device_list) { +std::shared_ptr GetListMemberByIndex(size_t index, const std::vector> &device_list) { size_t i = 0; std::shared_ptr result; if ((device_list.empty()) || (index >= device_list.size())) { @@ -123,8 +123,8 @@ std::shared_ptr GetListMemberByIndex(size_t index, const std::vector DeviceManager::GetStageById(int32_t stage_id) { return res; } int32_t index = 0; - for (auto& stage : stages_) { + for (auto &stage : stages_) { if (index == stage_id) return stage; index++; } @@ -224,7 +224,7 @@ RankList DeviceManager::GetDeviceListByStageId(int32_t stage_id) const { << ", is out of the scope of 'stage_devices_': " << stage_devices_.size(); RankList res; int32_t index = 0; - for (auto& stage : stage_devices_) { + for (auto &stage : stage_devices_) { if (index == stage_id) { return stage; } @@ -280,19 +280,19 @@ Device DeviceManager::CreateNewDeviceByRank(int32_t rank) const { return Device( std::vector DeviceManager::CreateDeviceListByRankList(RankList ranks) { std::vector dev_list; - for (auto& rank : ranks) { + for (auto &rank : ranks) { Device one = CreateNewDeviceByRank(rank); dev_list.push_back(one); } return dev_list; } -DeviceManager& DeviceManager::GetInstance() { +DeviceManager &DeviceManager::GetInstance() { static DeviceManager instance = DeviceManager(); return instance; } -std::string DeviceManager::FindRankListNameByHashName(const std::string& hash_name) { +std::string DeviceManager::FindRankListNameByHashName(const std::string &hash_name) { std::string tmp = "WORLD_GROUP"; if ((hash_name == HCCL_WORLD_GROUP) || (hash_name == NCCL_WORLD_GROUP)) { return tmp; @@ -305,7 +305,7 @@ std::string DeviceManager::FindRankListNameByHashName(const std::string& hash_na return iter->second; } -std::string HashName(const std::string& origin_name) { return std::to_string(std::hash{}(origin_name)); } +std::string HashName(const std::string &origin_name) { return std::to_string(std::hash{}(origin_name)); } // Group name is generated using the increasing ranks of the devices. // E.g. the devices' ranks are '<0, 5, 3, 7, 1>', and the generated group name @@ -343,8 +343,8 @@ std::string DeviceManager::GenerateGroupNameByRanks(RankList ranks) { // Create the group with the given devices and the given name. The GroupManager // gm_ will create a new group only if there does not exit a group with the same // name. Otherwise, let the pointer g point to that group. -Group DeviceManager::CreateGroup(const std::string& group_name, - const std::vector& devices) { +Group DeviceManager::CreateGroup(const std::string &group_name, + const std::vector &devices) { if ((world_group() == NCCL_WORLD_GROUP) && (devices.size() != devices_.size())) { MS_LOG(EXCEPTION) << "Do not support sub group for nccl"; } @@ -354,7 +354,7 @@ Group DeviceManager::CreateGroup(const std::string& group_name, } // Create the group with only the given devices' ranks. -Group DeviceManager::CreateGroup(const RankList& dev_ranks) { +Group DeviceManager::CreateGroup(const RankList &dev_ranks) { std::unordered_set rank_set(dev_ranks.begin(), dev_ranks.end()); if (dev_ranks.size() != rank_set.size()) { MS_LOG(EXCEPTION) << "Invalid dev ranks(" << dev_ranks << "), it has the Duplicate elements in list"; diff --git a/mindspore/ccsrc/parallel/device_manager.h b/mindspore/ccsrc/parallel/device_manager.h index e87c1d740f..3afafe6a9c 100644 --- a/mindspore/ccsrc/parallel/device_manager.h +++ b/mindspore/ccsrc/parallel/device_manager.h @@ -53,13 +53,13 @@ class Stage { explicit Stage(std::vector devices) : devices_(std::move(devices)), number_(0), rank_(0) { gm_ = GroupManager(); } - Stage(const std::vector& devices, int num, int rank); + Stage(const std::vector &devices, int num, int rank); ~Stage() = default; int GetStageNum() const { return number_; } size_t GetDevicesNum() const { return devices_.size(); } std::vector GetDevicesList() { return devices_; } - int global_rank(Group* g) const; + int global_rank(Group *g) const; private: std::vector devices_; @@ -70,11 +70,11 @@ class Stage { // This method is used for initializing the global DeviceManager 'g_device_manager', // arguments including 'device_num' and 'global_rank' -bool InitDevice(int32_t device_num, int32_t global_rank, const std::string& backend); +bool InitDevice(int32_t device_num, int32_t global_rank, const std::string &backend); void CheckGlobalDeviceManager(); -std::string HashName(const std::string& rank_list_name); +std::string HashName(const std::string &rank_list_name); class DeviceManager { // This class is used to manage the abstract devices, including group-related and stage-related management. @@ -82,9 +82,9 @@ class DeviceManager { DeviceManager() : local_rank_(0), global_rank_(0), stage_num_(0) { gm_ = GroupManager(); } ~DeviceManager() = default; - Status Init(const RankList& devices, int32_t local_device, const RankList& stage_map, const std::string& backend); + Status Init(const RankList &devices, int32_t local_device, const RankList &stage_map, const std::string &backend); - static DeviceManager& GetInstance(); + static DeviceManager &GetInstance(); RankList GetDeviceListByStageId(int32_t stage_id) const; RankList global_device_list(int32_t stage_id, int32_t rank, int32_t split_num) const; @@ -92,8 +92,8 @@ class DeviceManager { std::vector CreateDeviceListByRankList(RankList ranks); std::string GenerateGroupNameByRanks(RankList dev_ranks); - Group CreateGroup(const std::string& group_name, const std::vector& devices); - Group CreateGroup(const RankList& dev_ranks); + Group CreateGroup(const std::string &group_name, const std::vector &devices); + Group CreateGroup(const RankList &dev_ranks); std::shared_ptr GetStageById(int32_t stage_id); size_t DeviceNum() const { return devices_.size(); } @@ -105,7 +105,7 @@ class DeviceManager { void set_global_rank(int32_t global_rank) { global_rank_ = global_rank; } void Clear(); std::string world_group() const { return gm_.world_group(); } - std::string FindRankListNameByHashName(const std::string& hash_name); + std::string FindRankListNameByHashName(const std::string &hash_name); private: std::vector> devices_; diff --git a/mindspore/ccsrc/parallel/device_matrix.cc b/mindspore/ccsrc/parallel/device_matrix.cc index 3fdc3dd15a..3c9467a223 100644 --- a/mindspore/ccsrc/parallel/device_matrix.cc +++ b/mindspore/ccsrc/parallel/device_matrix.cc @@ -53,7 +53,7 @@ Status DeviceMatrix::CreateGroupList() { return Status::SUCCESS; } -Status DeviceMatrix::GetDevicesAlongDim(const uint32_t& dim, RankList* devices) { +Status DeviceMatrix::GetDevicesAlongDim(const uint32_t &dim, RankList *devices) { if (dim >= dev_shape_.size()) { MS_LOG(EXCEPTION) << "The dimension " << dim << " is out of the size of the device shape!"; } @@ -78,7 +78,7 @@ Status DeviceMatrix::GetDevicesAlongDim(const uint32_t& dim, RankList* devices) for (int32_t i = 0; i < step; i++) { local_group_list.push_back(group); - (void)std::for_each(group.begin(), group.end(), [](int32_t& a) { a++; }); + (void)std::for_each(group.begin(), group.end(), [](int32_t &a) { a++; }); } // higher than dim @@ -88,19 +88,19 @@ Status DeviceMatrix::GetDevicesAlongDim(const uint32_t& dim, RankList* devices) // search rank int32_t target = rank_; for (int32_t i = 0; i < len; i++) { - for (RankList& temp : local_group_list) { + for (RankList &temp : local_group_list) { if (std::any_of(temp.begin(), temp.end(), [target](int32_t a) { return a == target; })) { *devices = temp; return Status::SUCCESS; } - (void)std::for_each(temp.begin(), temp.end(), [step](int32_t& a) { a = a + step; }); + (void)std::for_each(temp.begin(), temp.end(), [step](int32_t &a) { a = a + step; }); } } MS_LOG(ERROR) << "Can't find groups for rank" << rank_ << " in device list!"; return Status::FAILED; } -Shape ConvertRankToCoordinate(int32_t rank, const Shape& dev_shape) { +Shape ConvertRankToCoordinate(int32_t rank, const Shape &dev_shape) { Shape dev_coordinate; for (size_t i = 0; i < dev_shape.size(); ++i) { int32_t size = dev_shape[dev_shape.size() - i - 1]; @@ -115,8 +115,8 @@ Shape ConvertRankToCoordinate(int32_t rank, const Shape& dev_shape) { return dev_coordinate; } -Status DeviceMatrix::GetDevicesByTensorMap(const Shape& tensor_map, RankList* rank_list) { - for (auto& element : tensor_map) { +Status DeviceMatrix::GetDevicesByTensorMap(const Shape &tensor_map, RankList *rank_list) { + for (auto &element : tensor_map) { // -1 means the corresponding dimension is not split. if (element == MAP_NONE) { continue; @@ -127,10 +127,10 @@ Status DeviceMatrix::GetDevicesByTensorMap(const Shape& tensor_map, RankList* ra } Shape current_rank_coordinate = ConvertRankToCoordinate(rank_, dev_shape_); - for (auto& tmp_rank : dev_list_) { + for (auto &tmp_rank : dev_list_) { Shape tmp_rank_coordinate = ConvertRankToCoordinate(tmp_rank, dev_shape_); bool matched = true; - for (auto& map : tensor_map) { + for (auto &map : tensor_map) { if (map == MAP_NONE) { continue; } @@ -148,7 +148,7 @@ Status DeviceMatrix::GetDevicesByTensorMap(const Shape& tensor_map, RankList* ra return SUCCESS; } -std::string ShapeToString(const Shape& shape) { +std::string ShapeToString(const Shape &shape) { std::string str = "["; for (size_t i = 0; i < shape.size(); ++i) { str += std::to_string(shape[i]); @@ -159,9 +159,9 @@ std::string ShapeToString(const Shape& shape) { return str + "]"; } -std::string ListToString(const std::vector& list) { +std::string ListToString(const std::vector &list) { std::string str = "["; - for (auto& element : list) { + for (auto &element : list) { str += std::to_string(element) + ", "; } return str + "]"; diff --git a/mindspore/ccsrc/parallel/device_matrix.h b/mindspore/ccsrc/parallel/device_matrix.h index a912000604..236a7fad08 100644 --- a/mindspore/ccsrc/parallel/device_matrix.h +++ b/mindspore/ccsrc/parallel/device_matrix.h @@ -37,8 +37,8 @@ class DeviceMatrix { ~DeviceMatrix() = default; std::vector group_list() const { return group_list_; } Status CreateGroupList(); - Status GetDevicesByTensorMap(const Shape& tensor_map, RankList* rank_list); - Status GetDevicesAlongDim(const uint32_t& dim, RankList* devices); + Status GetDevicesByTensorMap(const Shape &tensor_map, RankList *rank_list); + Status GetDevicesAlongDim(const uint32_t &dim, RankList *devices); private: int32_t rank_ = -1; @@ -48,8 +48,8 @@ class DeviceMatrix { std::vector group_list_; }; -std::string ShapeToString(const Shape& shape); -std::string ListToString(const std::vector& list); +std::string ShapeToString(const Shape &shape); +std::string ListToString(const std::vector &list); } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/parallel/dynamic_creator.h b/mindspore/ccsrc/parallel/dynamic_creator.h index bad947687d..42ba42cf8a 100644 --- a/mindspore/ccsrc/parallel/dynamic_creator.h +++ b/mindspore/ccsrc/parallel/dynamic_creator.h @@ -28,28 +28,28 @@ namespace mindspore { namespace parallel { #define REGISTER(className) \ - OperatorInfoPtr objectCreator##className(std::string name, Shapes in, Shapes out, PrimitiveAttrs& attrs) { \ + OperatorInfoPtr objectCreator##className(std::string name, Shapes in, Shapes out, PrimitiveAttrs &attrs) { \ return std::make_shared(name, in, out, attrs); \ } \ RegisterAction className##Register(#className, (CreatFn)objectCreator##className); -typedef OperatorInfoPtr (*CreatFn)(const std::string& name, const Shapes& shape_in, const Shapes shape_out, - const PrimitiveAttrs& attrs); +typedef OperatorInfoPtr (*CreatFn)(const std::string &name, const Shapes &shape_in, const Shapes shape_out, + const PrimitiveAttrs &attrs); class DynCreator { public: ~DynCreator() = default; // creat static singleton dyn_creator instance - static DynCreator& Instance() { + static DynCreator &Instance() { static DynCreator fac = DynCreator(); return fac; } // register void Regist(std::string name, CreatFn func) { (void)Function_map_.insert(std::make_pair(name, func)); } // creator - OperatorInfoPtr Creat(const std::string& name, const Shapes& shape_in, const Shapes& shape_out, - const PrimitiveAttrs& attrs, size_t count) { + OperatorInfoPtr Creat(const std::string &name, const Shapes &shape_in, const Shapes &shape_out, + const PrimitiveAttrs &attrs, size_t count) { std::string op_name = name + std::to_string(count); auto iter = Function_map_.find(name); if (iter == Function_map_.end()) { @@ -66,7 +66,7 @@ class DynCreator { class RegisterAction { public: - RegisterAction(const std::string& name, CreatFn creatfn) : name_(name) { + RegisterAction(const std::string &name, CreatFn creatfn) : name_(name) { DynCreator::Instance().Regist(name, creatfn); } ~RegisterAction() = default; diff --git a/mindspore/ccsrc/parallel/graph_util/generate_graph.cc b/mindspore/ccsrc/parallel/graph_util/generate_graph.cc index 43df9fe802..f5f0fe85cb 100644 --- a/mindspore/ccsrc/parallel/graph_util/generate_graph.cc +++ b/mindspore/ccsrc/parallel/graph_util/generate_graph.cc @@ -25,7 +25,7 @@ using mindspore::tensor::Tensor; namespace mindspore { namespace parallel { -std::string GetOpPythonPath(const OperatorName& op_name) { +std::string GetOpPythonPath(const OperatorName &op_name) { // almost all ops are defined in two main paths const std::string ops_module = OP_PATH; py::module mod = py::module::import(common::SafeCStr(ops_module)); @@ -35,7 +35,7 @@ std::string GetOpPythonPath(const OperatorName& op_name) { return ops_module; } -ValuePtr CreatOpInstance(const OperatorAttrs& attrs, const OperatorName& op_name, const std::string& instance_name) { +ValuePtr CreatOpInstance(const OperatorAttrs &attrs, const OperatorName &op_name, const std::string &instance_name) { std::string op_path = GetOpPythonPath(op_name); py::module mod = py::module::import(common::SafeCStr(op_path)); if (!py::hasattr(mod, common::SafeCStr(op_name))) { @@ -44,7 +44,7 @@ ValuePtr CreatOpInstance(const OperatorAttrs& attrs, const OperatorName& op_name } std::vector arg_list; (void)std::transform(attrs.begin(), attrs.end(), std::back_inserter(arg_list), - [](const Attr& attr) { return ValuePtrToPyData(attr.second); }); + [](const Attr &attr) { return ValuePtrToPyData(attr.second); }); py::object obj = parse::python_adapter::CallPyFn(GET_OP_FUNCTION_PATH, GET_OP_FUNCTION, op_name, op_path, instance_name, arg_list); ValuePtr op_instance = nullptr; @@ -56,7 +56,7 @@ ValuePtr CreatOpInstance(const OperatorAttrs& attrs, const OperatorName& op_name return op_instance; } -AnfNodePtr ValuePtrToAnfNodePtr(const ValuePtr& value_ptr) { +AnfNodePtr ValuePtrToAnfNodePtr(const ValuePtr &value_ptr) { auto value_node = NewValueNode(value_ptr); MS_EXCEPTION_IF_NULL(value_node); return value_node->cast(); @@ -85,7 +85,7 @@ AnfNodePtr CreatInt32Imm(int32_t value) { return ValuePtrToAnfNodePtr(value_ptr); } -std::string GetInstanceNameByCNode(const CNodePtr& cnode) { +std::string GetInstanceNameByCNode(const CNodePtr &cnode) { PrimitivePtr prim = GetValueNode(cnode->input(0)); if (!prim) { MS_LOG(EXCEPTION) << "The first input of the cnode is not a PrimitivePtr."; @@ -94,7 +94,7 @@ std::string GetInstanceNameByCNode(const CNodePtr& cnode) { return HashInstanceName(instance_name); } -std::string HashInstanceName(const std::string& name) { +std::string HashInstanceName(const std::string &name) { auto using_hash_name = common::GetEnv(USING_HASH_NAME); std::string instance_name; if ((using_hash_name.empty()) || (using_hash_name == "on")) { @@ -105,7 +105,7 @@ std::string HashInstanceName(const std::string& name) { return instance_name; } -Status GenerateGraph::Init(const CNodePtr& cnode) { +Status GenerateGraph::Init(const CNodePtr &cnode) { if (!cnode) { MS_LOG(ERROR) << "Init:cnode is nullptr"; return FAILED; @@ -133,7 +133,7 @@ Status GenerateGraph::Init(const CNodePtr& cnode) { return SUCCESS; } -AnfNodePtr GenerateGraph::PushBack(const std::vector& inputs) { +AnfNodePtr GenerateGraph::PushBack(const std::vector &inputs) { CNodePtr cnode = func_graph_->NewCNode(inputs); // using NewCNode to creat anfnode MS_EXCEPTION_IF_NULL(cnode); cnode->set_scope(scope_); @@ -146,7 +146,7 @@ AnfNodePtr GenerateGraph::PushBack(const std::vector& inputs) { return new_anf_node_ptr; } -AnfNodePtr GenerateGraph::NewOpInst(const OperatorName& op_name, const OperatorAttrs& attrs) { +AnfNodePtr GenerateGraph::NewOpInst(const OperatorName &op_name, const OperatorAttrs &attrs) { name_idx_++; ValuePtr pyop_instance = CreatOpInstance(attrs, op_name, instance_name_base_ + op_name + std::to_string(name_idx_)); if (pyop_instance == nullptr) { @@ -156,7 +156,7 @@ AnfNodePtr GenerateGraph::NewOpInst(const OperatorName& op_name, const OperatorA return value_node->cast(); } -AnfNodePtr GenerateGraph::NewOpInst(const OperatorName& op_name) { +AnfNodePtr GenerateGraph::NewOpInst(const OperatorName &op_name) { name_idx_++; OperatorAttrs attrs; ValuePtr pyop_instance = CreatOpInstance(attrs, op_name, instance_name_base_ + std::to_string(name_idx_)); diff --git a/mindspore/ccsrc/parallel/graph_util/generate_graph.h b/mindspore/ccsrc/parallel/graph_util/generate_graph.h index c829e67b6a..d5535c7dc2 100644 --- a/mindspore/ccsrc/parallel/graph_util/generate_graph.h +++ b/mindspore/ccsrc/parallel/graph_util/generate_graph.h @@ -33,25 +33,25 @@ namespace mindspore { namespace parallel { #define USING_HASH_NAME "USING_HASH_NAME" // Get the operator's path where the operator has be defined -std::string GetOpPythonPath(const OperatorName& op_name); +std::string GetOpPythonPath(const OperatorName &op_name); // Init python operator Instance -ValuePtr CreatOpInstance(const OperatorAttrs& attrs, const OperatorName& op_name, const std::string& instance_name); +ValuePtr CreatOpInstance(const OperatorAttrs &attrs, const OperatorName &op_name, const std::string &instance_name); AnfNodePtr CreatTypeInt(int32_t value); AnfNodePtr CreatInt32Imm(int32_t value); AnfNodePtr CreateInt32Tensor(int32_t value); -std::string HashInstanceName(const std::string& name); +std::string HashInstanceName(const std::string &name); class GenerateGraph { public: GenerateGraph() : name_idx_(0) {} - Status Init(const CNodePtr& cnode); + Status Init(const CNodePtr &cnode); ~GenerateGraph() = default; AnfNodePtr virtual_input_node() { return virtual_input_node_; } - AnfNodePtr NewOpInst(const OperatorName& op_name, const OperatorAttrs& attrs); - AnfNodePtr NewOpInst(const OperatorName& op_name); - AnfNodePtr PushBack(const std::vector& inputs); + AnfNodePtr NewOpInst(const OperatorName &op_name, const OperatorAttrs &attrs); + AnfNodePtr NewOpInst(const OperatorName &op_name); + AnfNodePtr PushBack(const std::vector &inputs); private: CNodePtr cnode_; diff --git a/mindspore/ccsrc/parallel/graph_util/get_parallel_info.cc b/mindspore/ccsrc/parallel/graph_util/get_parallel_info.cc index 3006cb7680..cbffc10e70 100644 --- a/mindspore/ccsrc/parallel/graph_util/get_parallel_info.cc +++ b/mindspore/ccsrc/parallel/graph_util/get_parallel_info.cc @@ -29,7 +29,7 @@ namespace mindspore { namespace parallel { -py::dict GetParameterLayout(const FuncGraphPtr& graph) { +py::dict GetParameterLayout(const FuncGraphPtr &graph) { MS_EXCEPTION_IF_NULL(graph); py::dict dict; std::vector graph_params = graph->parameters(); @@ -50,7 +50,7 @@ py::dict GetParameterLayout(const FuncGraphPtr& graph) { return dict; } -py::dict GetCNodeStrategy(const FuncGraphPtr& graph) { +py::dict GetCNodeStrategy(const FuncGraphPtr &graph) { MS_EXCEPTION_IF_NULL(graph); py::dict dict; auto ret = graph->get_return(); @@ -75,7 +75,7 @@ py::dict GetCNodeStrategy(const FuncGraphPtr& graph) { return dict; } -py::dict GetAllreduceFusion(const FuncGraphPtr& graph) { +py::dict GetAllreduceFusion(const FuncGraphPtr &graph) { MS_EXCEPTION_IF_NULL(graph); py::dict dict; auto allreduce_prim_list = FindPrimtive(graph, ALL_REDUCE); diff --git a/mindspore/ccsrc/parallel/graph_util/get_parallel_info.h b/mindspore/ccsrc/parallel/graph_util/get_parallel_info.h index 78f597b213..e21b81a557 100644 --- a/mindspore/ccsrc/parallel/graph_util/get_parallel_info.h +++ b/mindspore/ccsrc/parallel/graph_util/get_parallel_info.h @@ -23,9 +23,9 @@ namespace mindspore { namespace parallel { -py::dict GetParameterLayout(const FuncGraphPtr& graph); -py::dict GetCNodeStrategy(const FuncGraphPtr& graph); -py::dict GetAllreduceFusion(const FuncGraphPtr& graph); +py::dict GetParameterLayout(const FuncGraphPtr &graph); +py::dict GetCNodeStrategy(const FuncGraphPtr &graph); +py::dict GetAllreduceFusion(const FuncGraphPtr &graph); } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/parallel/graph_util/graph_info.cc b/mindspore/ccsrc/parallel/graph_util/graph_info.cc index 46c9a37960..175413c0fd 100644 --- a/mindspore/ccsrc/parallel/graph_util/graph_info.cc +++ b/mindspore/ccsrc/parallel/graph_util/graph_info.cc @@ -24,12 +24,12 @@ namespace mindspore { namespace parallel { -std::vector FindPrimtive(const FuncGraphPtr& graph, const std::string& name) { +std::vector FindPrimtive(const FuncGraphPtr &graph, const std::string &name) { AnfNodePtr ret = graph->get_return(); MS_EXCEPTION_IF_NULL(ret); std::vector all_nodes = DeepScopedGraphSearch(ret); std::vector prim_list; - for (auto& node : all_nodes) { + for (auto &node : all_nodes) { if (!IsValueNode(node)) { continue; } @@ -44,7 +44,7 @@ std::vector FindPrimtive(const FuncGraphPtr& graph, const std::str return prim_list; } -void DumpGraph(const FuncGraphPtr& root, const std::string& name) { +void DumpGraph(const FuncGraphPtr &root, const std::string &name) { if (MsContext::GetInstance()->save_graphs_flag()) { draw::Draw(name + ".dot", root); DumpIR(name + ".ir", root); diff --git a/mindspore/ccsrc/parallel/graph_util/graph_info.h b/mindspore/ccsrc/parallel/graph_util/graph_info.h index 96deab2906..de800f0981 100644 --- a/mindspore/ccsrc/parallel/graph_util/graph_info.h +++ b/mindspore/ccsrc/parallel/graph_util/graph_info.h @@ -24,8 +24,8 @@ namespace mindspore { namespace parallel { -std::vector FindPrimtive(const FuncGraphPtr& graph, const std::string& name); -void DumpGraph(const FuncGraphPtr& root, const std::string& name); +std::vector FindPrimtive(const FuncGraphPtr &graph, const std::string &name); +void DumpGraph(const FuncGraphPtr &root, const std::string &name); } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/parallel/graph_util/node_info.cc b/mindspore/ccsrc/parallel/graph_util/node_info.cc index b2ce8ba432..c085d71240 100644 --- a/mindspore/ccsrc/parallel/graph_util/node_info.cc +++ b/mindspore/ccsrc/parallel/graph_util/node_info.cc @@ -23,13 +23,13 @@ namespace mindspore { namespace parallel { -std::string ParameterName(const AnfNodePtr& node_ptr) { +std::string ParameterName(const AnfNodePtr &node_ptr) { auto para_ptr = node_ptr->cast(); MS_EXCEPTION_IF_NULL(para_ptr); return para_ptr->name(); } -bool ParameterRequireGrad(const AnfNodePtr& node_ptr) { +bool ParameterRequireGrad(const AnfNodePtr &node_ptr) { auto para_ptr = node_ptr->cast(); if (para_ptr == nullptr) { return false; diff --git a/mindspore/ccsrc/parallel/graph_util/node_info.h b/mindspore/ccsrc/parallel/graph_util/node_info.h index f4f46d2149..bda268e582 100644 --- a/mindspore/ccsrc/parallel/graph_util/node_info.h +++ b/mindspore/ccsrc/parallel/graph_util/node_info.h @@ -22,9 +22,9 @@ namespace mindspore { namespace parallel { -std::string ParameterName(const AnfNodePtr& node_ptr); +std::string ParameterName(const AnfNodePtr &node_ptr); -bool ParameterRequireGrad(const AnfNodePtr& node_ptr); +bool ParameterRequireGrad(const AnfNodePtr &node_ptr); } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/parallel/group_manager.h b/mindspore/ccsrc/parallel/group_manager.h index 430d2f64ed..f763d483cc 100644 --- a/mindspore/ccsrc/parallel/group_manager.h +++ b/mindspore/ccsrc/parallel/group_manager.h @@ -37,11 +37,11 @@ class Group { public: Group(); ~Group() = default; - Status Init(const std::string& name, const std::vector& devices); + Status Init(const std::string &name, const std::vector &devices); std::vector GetDevicesList() const; std::string name() const { return name_; } bool IsInThisGroup(int32_t device_rank); - Status GetIndex(size_t* index); + Status GetIndex(size_t *index); size_t GetDevNum() const { return devices_.size(); } private: @@ -54,14 +54,14 @@ class GroupManager { GroupManager(); ~GroupManager() = default; - Status CreateGroup(const std::string& name, const std::vector& devices, Group* group); - Status DestroyGroup(Group* group); + Status CreateGroup(const std::string &name, const std::vector &devices, Group *group); + Status DestroyGroup(Group *group); Status DestroyAllGroups(); - Status GetRankID(const std::string& name, unsigned int* rank_id); - Status GetRankSize(const std::string& name, unsigned int* rank_size); - Status FindGroup(const std::string& name, Group** group); + Status GetRankID(const std::string &name, unsigned int *rank_id); + Status GetRankSize(const std::string &name, unsigned int *rank_size); + Status FindGroup(const std::string &name, Group **group); std::string world_group() const { return world_group_; } - void set_world_group(const std::string& name) { world_group_ = name; } + void set_world_group(const std::string &name) { world_group_ = name; } void Clear(); private: diff --git a/mindspore/ccsrc/parallel/node_check.cc b/mindspore/ccsrc/parallel/node_check.cc index e43d03c29c..7fecd307c7 100644 --- a/mindspore/ccsrc/parallel/node_check.cc +++ b/mindspore/ccsrc/parallel/node_check.cc @@ -80,7 +80,7 @@ const std::set BLACK_LIST = {TUPLE_GETITEM, REF_TO_EMBED, STOP_GRADIENT}; -bool IsInBlackList(const PrimitivePtr& prim) { +bool IsInBlackList(const PrimitivePtr &prim) { MS_EXCEPTION_IF_NULL(prim); return (BLACK_LIST.find(prim->name()) != BLACK_LIST.end()); } diff --git a/mindspore/ccsrc/parallel/node_check.h b/mindspore/ccsrc/parallel/node_check.h index 6e5db37069..8b628f31b1 100644 --- a/mindspore/ccsrc/parallel/node_check.h +++ b/mindspore/ccsrc/parallel/node_check.h @@ -21,7 +21,7 @@ namespace mindspore { namespace parallel { -bool IsInBlackList(const PrimitivePtr& prim); +bool IsInBlackList(const PrimitivePtr &prim); } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/parallel/ops_info/activation_info.cc b/mindspore/ccsrc/parallel/ops_info/activation_info.cc index e659759de2..6bc33677a6 100644 --- a/mindspore/ccsrc/parallel/ops_info/activation_info.cc +++ b/mindspore/ccsrc/parallel/ops_info/activation_info.cc @@ -28,7 +28,7 @@ namespace mindspore { namespace parallel { -Status Activation::SetCostUnderStrategy(const StrategyPtr& strategy) { +Status Activation::SetCostUnderStrategy(const StrategyPtr &strategy) { if (SetCostUnderStrategyBase(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << " : Set cost under strategy failed."; @@ -41,7 +41,7 @@ Status Activation::SetCostUnderStrategy(const StrategyPtr& strategy) { return SUCCESS; } -Status Activation::CheckStrategy(const StrategyPtr& strategy) { +Status Activation::CheckStrategy(const StrategyPtr &strategy) { if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << " : Invalid strategy."; @@ -110,7 +110,7 @@ Status Activation::GenerateStrategies(int32_t stage_id) { return FAILED; } size_t success = 0; - for (auto& sp : sp_vector) { + for (auto &sp : sp_vector) { if (SetCostUnderStrategy(sp) == SUCCESS) { success++; MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy"; @@ -120,7 +120,7 @@ Status Activation::GenerateStrategies(int32_t stage_id) { return SUCCESS; } -Status Softmax::CheckStrategy(const StrategyPtr& strategy) { +Status Softmax::CheckStrategy(const StrategyPtr &strategy) { if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << " : Invalid strategy."; @@ -133,7 +133,7 @@ Status Softmax::CheckStrategy(const StrategyPtr& strategy) { std::vector stra = strategy->GetInputDim(); Dimensions input_strategy = stra.at(0); - for (auto& element : axis_) { + for (auto &element : axis_) { int32_t axis_index = element; if (element < 0) { size_t input_dim = inputs_shape_.at(0).size(); @@ -176,7 +176,7 @@ Status Softmax::GetAttrs() { } std::vector value_vector = value_tuple->value(); (void)std::transform(value_vector.begin(), value_vector.end(), std::back_inserter(axis_), - [](const ValuePtr& value) { return static_cast(GetValue(value)); }); + [](const ValuePtr &value) { return static_cast(GetValue(value)); }); if (axis_.empty()) { MS_LOG(ERROR) << name_ << " : The axis tuple is empty."; return FAILED; @@ -205,7 +205,7 @@ Status Softmax::GetAttrs() { return SUCCESS; } -Status Softmax::SetCostUnderStrategy(const StrategyPtr& strategy) { +Status Softmax::SetCostUnderStrategy(const StrategyPtr &strategy) { if (SetCostUnderStrategyBase(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << " : Set cost under strategy failed."; @@ -231,7 +231,7 @@ Status Softmax::GenerateStrategies(int32_t stage_id) { is_auto_parallel_ = true; Shape input0_split; (void)input0_split.insert(input0_split.begin(), inputs_shape_[0].size(), 1); - for (auto& element : axis_) { + for (auto &element : axis_) { int32_t axis_index = element; if (element < 0) { size_t input_dim = inputs_shape_.at(0).size(); @@ -247,7 +247,7 @@ Status Softmax::GenerateStrategies(int32_t stage_id) { return FAILED; } size_t success = 0; - for (auto& sp : sp_vector) { + for (auto &sp : sp_vector) { if (SetCostUnderStrategy(sp) == SUCCESS) { success++; MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy."; @@ -334,7 +334,7 @@ Status ActivationBase::InferTensorInfo() { return SUCCESS; } -Status ActivationBase::Init(const StrategyPtr& strategy) { +Status ActivationBase::Init(const StrategyPtr &strategy) { if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { MS_LOG(ERROR) << name_ << " : Init failed."; return FAILED; @@ -344,7 +344,7 @@ Status ActivationBase::Init(const StrategyPtr& strategy) { return SUCCESS; } -Status ActivationBase::InitForCostModel(const StrategyPtr& strategy) { +Status ActivationBase::InitForCostModel(const StrategyPtr &strategy) { if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << " : Init for cost model failed."; @@ -547,7 +547,7 @@ Status ExpandDimsInfo::InferMirrorOps() { return SUCCESS; } -Status SqueezeInfo::InferAxis(const ValueTuplePtr& value_tuple) { +Status SqueezeInfo::InferAxis(const ValueTuplePtr &value_tuple) { std::vector axis; auto axis_list = value_tuple->value(); if (inputs_shape_.empty()) { @@ -568,7 +568,7 @@ Status SqueezeInfo::InferAxis(const ValueTuplePtr& value_tuple) { } // convert negative axis to positive. - for (auto& dim : axis_list) { + for (auto &dim : axis_list) { if (!dim->isa()) { MS_LOG(ERROR) << name_ << ": The type of axis is not int"; return FAILED; @@ -595,7 +595,7 @@ Status SqueezeInfo::GetAttrs() { return SUCCESS; } -Status SqueezeInfo::InferReplaceOps(const StrategyPtr& strategy) { +Status SqueezeInfo::InferReplaceOps(const StrategyPtr &strategy) { Attr attr = std::make_pair(AXIS, axis_); OperatorAttrs attrs = {attr}; OperatorParams params; @@ -689,7 +689,7 @@ Status SqueezeInfo::InferTensorInfo() { return SUCCESS; } -Status SqueezeInfo::Init(const StrategyPtr& strategy) { +Status SqueezeInfo::Init(const StrategyPtr &strategy) { if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { MS_LOG(ERROR) << name_ << " : Init failed."; } diff --git a/mindspore/ccsrc/parallel/ops_info/activation_info.h b/mindspore/ccsrc/parallel/ops_info/activation_info.h index 887be5ea33..a71c6b6df7 100644 --- a/mindspore/ccsrc/parallel/ops_info/activation_info.h +++ b/mindspore/ccsrc/parallel/ops_info/activation_info.h @@ -31,13 +31,13 @@ namespace mindspore { namespace parallel { class ActivationBase : public OperatorInfo { public: - ActivationBase(const std::string& operator_name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs, OperatorCostPtr cost) + ActivationBase(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs, OperatorCostPtr cost) : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, cost) {} ~ActivationBase() override = default; - Status Init(const StrategyPtr& strategy) override; - Status InitForCostModel(const StrategyPtr& strategy) override; + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; protected: Status InferMirrorOps() override; @@ -49,21 +49,21 @@ class ActivationBase : public OperatorInfo { class Activation : public ActivationBase { public: - Activation(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + Activation(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : ActivationBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} ~Activation() override = default; Status GenerateStrategies(int32_t stage_id) override; - Status SetCostUnderStrategy(const StrategyPtr& strategy) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; protected: - Status CheckStrategy(const StrategyPtr& strategy) override; + Status CheckStrategy(const StrategyPtr &strategy) override; }; class ActivationInfo : public Activation { public: - ActivationInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + ActivationInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : Activation(name, inputs_shape, outputs_shape, attrs) {} ~ActivationInfo() override = default; @@ -73,8 +73,8 @@ class ActivationInfo : public Activation { class ActivationOther : public Activation { public: - ActivationOther(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + ActivationOther(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : Activation(name, inputs_shape, outputs_shape, attrs) {} ~ActivationOther() override = default; @@ -84,31 +84,31 @@ class ActivationOther : public Activation { class GeluInfo : public ActivationOther { public: - GeluInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + GeluInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} ~GeluInfo() override = default; }; class TanhInfo : public ActivationOther { public: - TanhInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + TanhInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} ~TanhInfo() override = default; }; class Softmax : public ActivationBase { public: - explicit Softmax(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + explicit Softmax(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : ActivationBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} ~Softmax() override = default; Status GenerateStrategies(int32_t stage_id) override; - Status SetCostUnderStrategy(const StrategyPtr& strategy) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; protected: - Status CheckStrategy(const StrategyPtr& strategy) override; + Status CheckStrategy(const StrategyPtr &strategy) override; Status GetAttrs() override; private: @@ -117,32 +117,32 @@ class Softmax : public ActivationBase { class SoftmaxInfo : public Softmax { public: - SoftmaxInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + SoftmaxInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : Softmax(name, inputs_shape, outputs_shape, attrs) {} ~SoftmaxInfo() override = default; }; class LogSoftmaxInfo : public Softmax { public: - LogSoftmaxInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + LogSoftmaxInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : Softmax(name, inputs_shape, outputs_shape, attrs) {} ~LogSoftmaxInfo() override = default; }; class ReLUInfo : public ActivationOther { public: - ReLUInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + ReLUInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} ~ReLUInfo() override = default; }; class CastInfo : public ActivationOther { public: - CastInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + CastInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} ~CastInfo() override = default; @@ -152,23 +152,23 @@ class CastInfo : public ActivationOther { class SqrtInfo : public ActivationOther { public: - SqrtInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + SqrtInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} ~SqrtInfo() override = default; }; class NegInfo : public ActivationOther { public: - NegInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs) + NegInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} ~NegInfo() override = default; }; class ExpandDimsInfo : public ActivationOther { public: - ExpandDimsInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + ExpandDimsInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} ~ExpandDimsInfo() override = default; @@ -187,18 +187,18 @@ class ExpandDimsInfo : public ActivationOther { class SqueezeInfo : public ActivationOther { public: - SqueezeInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + SqueezeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} ~SqueezeInfo() override = default; protected: - Status InferAxis(const ValueTuplePtr& value_tuple); + Status InferAxis(const ValueTuplePtr &value_tuple); Status GetAttrs() override; - Status InferReplaceOps(const StrategyPtr& strategy); + Status InferReplaceOps(const StrategyPtr &strategy); Status InferTensorMap() override; Status InferTensorInfo() override; - Status Init(const StrategyPtr& strategy) override; + Status Init(const StrategyPtr &strategy) override; private: ValueTuplePtr axis_; @@ -206,8 +206,8 @@ class SqueezeInfo : public ActivationOther { class SquareInfo : public ActivationOther { public: - SquareInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + SquareInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} ~SquareInfo() override = default; }; diff --git a/mindspore/ccsrc/parallel/ops_info/arithmetic_info.h b/mindspore/ccsrc/parallel/ops_info/arithmetic_info.h index 78dfc23803..27caacc30c 100644 --- a/mindspore/ccsrc/parallel/ops_info/arithmetic_info.h +++ b/mindspore/ccsrc/parallel/ops_info/arithmetic_info.h @@ -31,92 +31,92 @@ namespace mindspore { namespace parallel { class ArithmeticBase : public OperatorInfo { public: - ArithmeticBase(const std::string& operator_name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs, OperatorCostPtr cost) + ArithmeticBase(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs, OperatorCostPtr cost) : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, cost) {} ~ArithmeticBase() override = default; - Status Init(const StrategyPtr& strategy) override; - Status InitForCostModel(const StrategyPtr& strategy) override; + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; Status GenerateStrategies(int32_t) override; - Status SetCostUnderStrategy(const StrategyPtr&) override; + Status SetCostUnderStrategy(const StrategyPtr &) override; void ReComputeBatchSplitFlagList() override; protected: Status GetAttrs() override { return SUCCESS; } - Status CheckStrategy(const StrategyPtr& strategy) override; + Status CheckStrategy(const StrategyPtr &strategy) override; Status InferMirrorOps() override; Status InferForwardCommunication() override { return SUCCESS; } Status InferTensorInfo() override; Status InferDevMatrixShape() override; Status InferTensorMap() override; - Status InferTensorLayout(TensorLayouts* inputs_layout, TensorLayouts* outputs_layout, const Shape& dev_matrix_array); + Status InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout, const Shape &dev_matrix_array); Shapes InferExpendShape(); }; class SubInfo : public ArithmeticBase { public: - SubInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs) + SubInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} ~SubInfo() override = default; }; class TensorAddInfo : public ArithmeticBase { public: - TensorAddInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + TensorAddInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} ~TensorAddInfo() override = default; }; class MulInfo : public ArithmeticBase { public: - MulInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs) + MulInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} ~MulInfo() override = default; }; class DivInfo : public ArithmeticBase { public: - DivInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs) + DivInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} ~DivInfo() override = default; }; class RealDivInfo : public ArithmeticBase { public: - RealDivInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + RealDivInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} ~RealDivInfo() override = default; }; class FloorDivInfo : public ArithmeticBase { public: - FloorDivInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + FloorDivInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} ~FloorDivInfo() override = default; }; class PowInfo : public ArithmeticBase { public: - PowInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs) + PowInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} ~PowInfo() override = default; }; class GreaterInfo : public ArithmeticBase { public: - GreaterInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + GreaterInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} ~GreaterInfo() override = default; }; class AssignSubInfo : public ArithmeticBase { public: - AssignSubInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + AssignSubInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} ~AssignSubInfo() override = default; }; @@ -124,8 +124,8 @@ class AssignSubInfo : public ArithmeticBase { // All dimensions can be split arbitrarily, but the split method of Logits should be the same as that of label. class SigmoidCrossEntropyWithLogitsInfo : public ArithmeticBase { public: - SigmoidCrossEntropyWithLogitsInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + SigmoidCrossEntropyWithLogitsInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} ~SigmoidCrossEntropyWithLogitsInfo() override = default; }; diff --git a/mindspore/ccsrc/parallel/ops_info/batch_parallel_info.cc b/mindspore/ccsrc/parallel/ops_info/batch_parallel_info.cc index 9d356cd573..dac3b0a675 100644 --- a/mindspore/ccsrc/parallel/ops_info/batch_parallel_info.cc +++ b/mindspore/ccsrc/parallel/ops_info/batch_parallel_info.cc @@ -27,7 +27,7 @@ namespace mindspore { namespace parallel { -Status BatchParallelInfo::CheckStrategy(const StrategyPtr& strategy) { +Status BatchParallelInfo::CheckStrategy(const StrategyPtr &strategy) { if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << " : Invalid strategy."; @@ -161,7 +161,7 @@ Status BatchParallelInfo::InferTensorInfo() { Status BatchParallelInfo::GetAttrs() { return SUCCESS; } -Status BatchParallelInfo::Init(const StrategyPtr& strategy) { +Status BatchParallelInfo::Init(const StrategyPtr &strategy) { if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { MS_LOG(ERROR) << name_ << " : Init failed."; return FAILED; @@ -170,7 +170,7 @@ Status BatchParallelInfo::Init(const StrategyPtr& strategy) { return SUCCESS; } -Status BatchParallelInfo::InitForCostModel(const StrategyPtr& strategy) { +Status BatchParallelInfo::InitForCostModel(const StrategyPtr &strategy) { if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << " : Init for cost model failed."; @@ -184,7 +184,7 @@ Status BatchParallelInfo::InitForCostModel(const StrategyPtr& strategy) { return SUCCESS; } -Status BatchParallelInfo::SetCostUnderStrategy(const StrategyPtr& strategy) { +Status BatchParallelInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { if (SetCostUnderStrategyBase(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << " : Set cost under strategy failed."; diff --git a/mindspore/ccsrc/parallel/ops_info/batch_parallel_info.h b/mindspore/ccsrc/parallel/ops_info/batch_parallel_info.h index 4cedb9b7b8..db6cb206d5 100644 --- a/mindspore/ccsrc/parallel/ops_info/batch_parallel_info.h +++ b/mindspore/ccsrc/parallel/ops_info/batch_parallel_info.h @@ -29,22 +29,22 @@ namespace mindspore { namespace parallel { class BatchParallelInfo : public OperatorInfo { public: - BatchParallelInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs, OperatorCostPtr cost) + BatchParallelInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs, OperatorCostPtr cost) : OperatorInfo(name, inputs_shape, outputs_shape, attrs, cost), dev_num_(1) {} - BatchParallelInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + BatchParallelInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)), dev_num_(1) {} ~BatchParallelInfo() override = default; - Status Init(const StrategyPtr& strategy) override; - Status InitForCostModel(const StrategyPtr& strategy) override; + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; Status GenerateStrategies(int32_t stage_id) override; - Status SetCostUnderStrategy(const StrategyPtr& strategy) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; protected: - Status CheckStrategy(const StrategyPtr& strategy) override; + Status CheckStrategy(const StrategyPtr &strategy) override; Status InferMirrorOps() override; Status InferForwardCommunication() override; Status InferTensorInfo() override; @@ -60,8 +60,8 @@ class BatchParallelInfo : public OperatorInfo { class SparseSoftmaxCrossEntropyWithLogitsInfo : public BatchParallelInfo { public: - SparseSoftmaxCrossEntropyWithLogitsInfo(const std::string& name, const Shapes& inputs_shape, - const Shapes& outputs_shape, const PrimitiveAttrs& attrs) + SparseSoftmaxCrossEntropyWithLogitsInfo(const std::string &name, const Shapes &inputs_shape, + const Shapes &outputs_shape, const PrimitiveAttrs &attrs) : BatchParallelInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} ~SparseSoftmaxCrossEntropyWithLogitsInfo() override = default; void ReComputeBatchSplitFlagList() override; diff --git a/mindspore/ccsrc/parallel/ops_info/bias_add_info.h b/mindspore/ccsrc/parallel/ops_info/bias_add_info.h index e792858338..37f555a258 100644 --- a/mindspore/ccsrc/parallel/ops_info/bias_add_info.h +++ b/mindspore/ccsrc/parallel/ops_info/bias_add_info.h @@ -32,26 +32,26 @@ namespace mindspore { namespace parallel { class BiasAddInfo : public OperatorInfo { public: - BiasAddInfo(const std::string& operator_name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + BiasAddInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} ~BiasAddInfo() override = default; - Status Init(const StrategyPtr& strategy) override; - Status InitForCostModel(const StrategyPtr& strategy) override; + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; Status GenerateStrategies(int32_t) override; - Status SetCostUnderStrategy(const StrategyPtr&) override; + Status SetCostUnderStrategy(const StrategyPtr &) override; void ReComputeBatchSplitFlagList() override; protected: Status GetAttrs() override { return SUCCESS; } - Status CheckStrategy(const StrategyPtr& strategy) override; + Status CheckStrategy(const StrategyPtr &strategy) override; Status InferMirrorOps() override; Status InferForwardCommunication() override { return SUCCESS; } Status InferTensorInfo() override; Status InferDevMatrixShape() override; Status InferTensorMap() override; - Status InferTensorLayout(TensorLayouts* inputs_layout, TensorLayouts* outputs_layout, const Shape& dev_matrix_array); + Status InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout, const Shape &dev_matrix_array); }; } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/parallel/ops_info/comparison_function_info.h b/mindspore/ccsrc/parallel/ops_info/comparison_function_info.h index 9ea496e0b0..8dd2976b04 100644 --- a/mindspore/ccsrc/parallel/ops_info/comparison_function_info.h +++ b/mindspore/ccsrc/parallel/ops_info/comparison_function_info.h @@ -30,32 +30,32 @@ namespace mindspore { namespace parallel { class EqualInfo : public ArithmeticBase { public: - EqualInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + EqualInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} ~EqualInfo() override = default; }; class NotEqualInfo : public ArithmeticBase { public: - NotEqualInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + NotEqualInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} ~NotEqualInfo() override = default; }; class MaximumInfo : public ArithmeticBase { public: - MaximumInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + MaximumInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} ~MaximumInfo() override = default; }; class MinimumInfo : public ArithmeticBase { public: - MinimumInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + MinimumInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : ArithmeticBase(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} ~MinimumInfo() override = default; }; diff --git a/mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.cc b/mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.cc index c755cc785d..87b8d15cca 100644 --- a/mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.cc +++ b/mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.cc @@ -32,7 +32,7 @@ namespace mindspore { namespace parallel { static int32_t SEED_NUM = 1; -Status DropoutDoMaskInfo::CheckStrategy(const StrategyPtr& strategy) { +Status DropoutDoMaskInfo::CheckStrategy(const StrategyPtr &strategy) { if (strategy == nullptr) { MS_LOG(ERROR) << name_ << ": The strategy is null"; return FAILED; @@ -129,7 +129,7 @@ Status DropoutDoMaskInfo::InferTensorInfo() { return SUCCESS; } -Status DropoutDoMaskInfo::SetCostUnderStrategy(const StrategyPtr& strategy) { +Status DropoutDoMaskInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { if (SetCostUnderStrategyBase(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << " : Set cost under strategy failed."; @@ -159,7 +159,7 @@ Status DropoutDoMaskInfo::GenerateStrategies(int32_t stage_id) { return FAILED; } size_t success = 0; - for (auto& sp : sp_vector) { + for (auto &sp : sp_vector) { if (SetCostUnderStrategy(sp) == SUCCESS) { success++; MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy"; @@ -178,7 +178,7 @@ std::shared_ptr>> DropoutDoMaskInfo::GenerateBa return std::make_shared>>(strategy_v); } -Status DropoutDoMaskInfo::Init(const StrategyPtr& strategy) { +Status DropoutDoMaskInfo::Init(const StrategyPtr &strategy) { if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { MS_LOG(ERROR) << name_ << ": Init failed."; return FAILED; @@ -188,7 +188,7 @@ Status DropoutDoMaskInfo::Init(const StrategyPtr& strategy) { return SUCCESS; } -Status DropoutDoMaskInfo::InitForCostModel(const StrategyPtr& strategy) { +Status DropoutDoMaskInfo::InitForCostModel(const StrategyPtr &strategy) { if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << ": Init for cost model failed."; @@ -202,7 +202,7 @@ Status DropoutDoMaskInfo::InitForCostModel(const StrategyPtr& strategy) { return SUCCESS; } -PrimitivePtr GetDropoutGenMaskPrim(const CNodePtr& cnode) { +PrimitivePtr GetDropoutGenMaskPrim(const CNodePtr &cnode) { MS_EXCEPTION_IF_NULL(cnode); if (cnode->inputs().size() != DROPOUT_DO_MASK_CNODE_INPUT_SIZE) { MS_LOG(EXCEPTION) << "The size of dropout do mask cnode's inputs must be " << DROPOUT_DO_MASK_CNODE_INPUT_SIZE; @@ -237,7 +237,7 @@ PrimitivePtr GetDropoutGenMaskPrim(const CNodePtr& cnode) { // split. Find the DropoutGenMask node in the anf graph according to DropoutDoMask node, and modify the input shape // of DropoutGenMask according to the strategy of DropoutDoMask. When the DropoutDoMask performs repeated calculation // and both seeds of DropoutGenMask are 0, two new seeds are automatically generated for DropoutGenMask. -Operator DropoutDoMaskInfo::GetDropoutGenMaskReplaceOp(const CNodePtr& cnode) { +Operator DropoutDoMaskInfo::GetDropoutGenMaskReplaceOp(const CNodePtr &cnode) { MS_EXCEPTION_IF_NULL(cnode); PrimitivePtr prim = GetDropoutGenMaskPrim(cnode); MS_EXCEPTION_IF_NULL(prim); diff --git a/mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.h b/mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.h index 3b154bd6db..c0d112f52d 100644 --- a/mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.h +++ b/mindspore/ccsrc/parallel/ops_info/dropout_do_mask_info.h @@ -31,20 +31,20 @@ namespace mindspore { namespace parallel { class DropoutDoMaskInfo : public OperatorInfo { public: - DropoutDoMaskInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + DropoutDoMaskInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} ~DropoutDoMaskInfo() override = default; - Status Init(const StrategyPtr& strategy) override; + Status Init(const StrategyPtr &strategy) override; Status GenerateStrategies(int32_t stage_id) override; - Status SetCostUnderStrategy(const StrategyPtr& strategy) override; - Status InitForCostModel(const StrategyPtr& strategy) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; std::shared_ptr>> GenerateBatchStrategies() override; - Operator GetDropoutGenMaskReplaceOp(const CNodePtr& cnode); + Operator GetDropoutGenMaskReplaceOp(const CNodePtr &cnode); protected: - Status CheckStrategy(const StrategyPtr& strategy) override; + Status CheckStrategy(const StrategyPtr &strategy) override; Status InferMirrorOps() override { return SUCCESS; } Status InferForwardCommunication() override { return SUCCESS; } Status InferTensorMap() override; diff --git a/mindspore/ccsrc/parallel/ops_info/elementary_function_info.h b/mindspore/ccsrc/parallel/ops_info/elementary_function_info.h index 84b8030f37..2172c5cd89 100644 --- a/mindspore/ccsrc/parallel/ops_info/elementary_function_info.h +++ b/mindspore/ccsrc/parallel/ops_info/elementary_function_info.h @@ -29,37 +29,37 @@ namespace mindspore { namespace parallel { class ExpInfo : public ActivationOther { public: - ExpInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs) + ExpInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} ~ExpInfo() override = default; }; class LogInfo : public ActivationOther { public: - LogInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs) + LogInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} ~LogInfo() override = default; }; class CosInfo : public ActivationOther { public: - CosInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs) + CosInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} ~CosInfo() override = default; }; class ACosInfo : public ActivationOther { public: - ACosInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + ACosInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} ~ACosInfo() override = default; }; class LogicalNotInfo : public ActivationOther { public: - LogicalNotInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + LogicalNotInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : ActivationOther(name, inputs_shape, outputs_shape, attrs) {} ~LogicalNotInfo() override = default; }; diff --git a/mindspore/ccsrc/parallel/ops_info/gather_v2_info.cc b/mindspore/ccsrc/parallel/ops_info/gather_v2_info.cc index c315991849..c9e8835f35 100644 --- a/mindspore/ccsrc/parallel/ops_info/gather_v2_info.cc +++ b/mindspore/ccsrc/parallel/ops_info/gather_v2_info.cc @@ -70,7 +70,7 @@ Status GatherV2Info::GetAttrs() { return SUCCESS; } -Status GatherV2Info::CheckStrategy(const StrategyPtr& strategy) { +Status GatherV2Info::CheckStrategy(const StrategyPtr &strategy) { if (inputs_shape_.size() != GATHER_V2_INPUTS_SIZE) { MS_LOG(ERROR) << name_ << ": inputs shape size must be " << GATHER_V2_INPUTS_SIZE << ", but is " << inputs_shape_.size(); @@ -256,7 +256,7 @@ Status GatherV2Info::InferTensorSubOps() { return SUCCESS; } -Status GatherV2Info::Init(const StrategyPtr& strategy) { +Status GatherV2Info::Init(const StrategyPtr &strategy) { if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { MS_LOG(ERROR) << name_ << ": Init failed."; return FAILED; @@ -270,7 +270,7 @@ Status GatherV2Info::Init(const StrategyPtr& strategy) { return SUCCESS; } -Status GatherV2Info::InitForCostModel(const StrategyPtr& strategy) { +Status GatherV2Info::InitForCostModel(const StrategyPtr &strategy) { if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << ": Init for cost model failed."; @@ -301,7 +301,7 @@ Status GatherV2Info::GenerateStrategies(int32_t stage_id) { return FAILED; } size_t success = 0; - for (auto& sp : sp_vector) { + for (auto &sp : sp_vector) { if (SetCostUnderStrategy(sp) == SUCCESS) { success++; MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy"; @@ -311,7 +311,7 @@ Status GatherV2Info::GenerateStrategies(int32_t stage_id) { return SUCCESS; } -Status GatherV2Info::SetCostUnderStrategy(const StrategyPtr& strategy) { +Status GatherV2Info::SetCostUnderStrategy(const StrategyPtr &strategy) { if (SetCostUnderStrategyBase(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << ": Set cost under strategy failed."; diff --git a/mindspore/ccsrc/parallel/ops_info/gather_v2_info.h b/mindspore/ccsrc/parallel/ops_info/gather_v2_info.h index 773d46f429..f7aeb6a0d9 100644 --- a/mindspore/ccsrc/parallel/ops_info/gather_v2_info.h +++ b/mindspore/ccsrc/parallel/ops_info/gather_v2_info.h @@ -38,22 +38,22 @@ constexpr size_t GATHER_V2_INPUTS_VALUE_SIZE = 3; // If Index is a scalar or n-dimension vector(n > 1), the strategy corresponding to axis must be 1. class GatherV2Info : public OperatorInfo { public: - GatherV2Info(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + GatherV2Info(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared()), axis_(-1), index_size_(0), axis_strategy_(1) {} ~GatherV2Info() override = default; - Status Init(const StrategyPtr& strategy) override; - Status InitForCostModel(const StrategyPtr& strategy) override; + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; Status GenerateStrategies(int32_t stage_id) override; - Status SetCostUnderStrategy(const StrategyPtr& strategy) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; std::shared_ptr>> GenerateBatchStrategies() override; protected: - Status CheckStrategy(const StrategyPtr& strategy) override; + Status CheckStrategy(const StrategyPtr &strategy) override; Status InferMirrorOps() override { return SUCCESS; } Status InferForwardCommunication() override { return SUCCESS; } Status InferTensorInfo() override; diff --git a/mindspore/ccsrc/parallel/ops_info/get_next_info.cc b/mindspore/ccsrc/parallel/ops_info/get_next_info.cc index ac9acff41b..29d519fda8 100644 --- a/mindspore/ccsrc/parallel/ops_info/get_next_info.cc +++ b/mindspore/ccsrc/parallel/ops_info/get_next_info.cc @@ -39,7 +39,7 @@ Status GetNextInfo::InferTensorMap() { return SUCCESS; } -Status GetNextInfo::InferTensorLayout(TensorLayouts* outputs_layout) { +Status GetNextInfo::InferTensorLayout(TensorLayouts *outputs_layout) { if (outputs_layout == nullptr) { MS_LOG(ERROR) << name_ << " : The layout is null."; return FAILED; @@ -96,7 +96,7 @@ Status GetNextInfo::InferDevMatrixShape() { return SUCCESS; } -Status GetNextInfo::Init(const StrategyPtr& strategy) { +Status GetNextInfo::Init(const StrategyPtr &strategy) { if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { MS_LOG(ERROR) << name_ << " : Init failed"; return FAILED; @@ -109,7 +109,7 @@ Status GetNextInfo::Init(const StrategyPtr& strategy) { return SUCCESS; } -Status GetNextInfo::CheckStrategy(const StrategyPtr& strategy) { +Status GetNextInfo::CheckStrategy(const StrategyPtr &strategy) { std::vector stras = strategy->GetInputDim(); for (Dimensions stra : stras) { if (stra.size() != 0) { @@ -135,7 +135,7 @@ Status GetNextInfo::GetAttrTypes() { auto iter_cast = iter->second->cast(); MS_EXCEPTION_IF_NULL(iter_cast); auto types = iter_cast->value(); - for (auto& type : types) { + for (auto &type : types) { MS_EXCEPTION_IF_NULL(type); types_.push_back(type->ToString()); } @@ -143,7 +143,7 @@ Status GetNextInfo::GetAttrTypes() { auto iter_cast = iter->second->cast(); MS_EXCEPTION_IF_NULL(iter_cast); auto types = iter_cast->value(); - for (auto& type : types) { + for (auto &type : types) { MS_EXCEPTION_IF_NULL(type); types_.push_back(type->ToString()); } @@ -189,7 +189,7 @@ Status GetNextInfo::GetAttrs() { return SUCCESS; } -Status GetNextInfo::InferReplaceOps(const StrategyPtr&) { +Status GetNextInfo::InferReplaceOps(const StrategyPtr &) { Shapes out_shapes = outputs_shape_; for (size_t i = 0; i < out_shapes.size(); ++i) { if (dev_num_ <= 0) { @@ -214,7 +214,7 @@ Status GetNextInfo::InferReplaceOps(const StrategyPtr&) { return SUCCESS; } -Status GetNextInfo::InitForCostModel(const StrategyPtr& strategy) { +Status GetNextInfo::InitForCostModel(const StrategyPtr &strategy) { if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << " : Init for cost model failed."; @@ -227,7 +227,7 @@ Status GetNextInfo::InitForCostModel(const StrategyPtr& strategy) { return SUCCESS; } -Status GetNextInfo::SetCostUnderStrategy(const StrategyPtr& strategy) { +Status GetNextInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { if (SetCostUnderStrategyBase(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << " : Set cost under strategy failed."; diff --git a/mindspore/ccsrc/parallel/ops_info/l2_normalize_info.cc b/mindspore/ccsrc/parallel/ops_info/l2_normalize_info.cc index 2955f76506..8716997d9f 100644 --- a/mindspore/ccsrc/parallel/ops_info/l2_normalize_info.cc +++ b/mindspore/ccsrc/parallel/ops_info/l2_normalize_info.cc @@ -27,7 +27,7 @@ namespace mindspore { namespace parallel { -Status L2NormalizeInfo::CheckStrategy(const StrategyPtr& strategy) { +Status L2NormalizeInfo::CheckStrategy(const StrategyPtr &strategy) { if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << " : Invalid strategy."; @@ -111,7 +111,7 @@ Status L2NormalizeInfo::GenerateStrategies(int32_t stage_id) { return FAILED; } size_t success = 0; - for (auto& sp : sp_vector) { + for (auto &sp : sp_vector) { if (SetCostUnderStrategy(sp) == SUCCESS) { success++; MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy."; diff --git a/mindspore/ccsrc/parallel/ops_info/l2_normalize_info.h b/mindspore/ccsrc/parallel/ops_info/l2_normalize_info.h index 22ed5a965b..ca063d01d8 100644 --- a/mindspore/ccsrc/parallel/ops_info/l2_normalize_info.h +++ b/mindspore/ccsrc/parallel/ops_info/l2_normalize_info.h @@ -31,8 +31,8 @@ namespace mindspore { namespace parallel { class L2NormalizeInfo : public Activation { public: - L2NormalizeInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + L2NormalizeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : Activation(name, inputs_shape, outputs_shape, attrs) {} ~L2NormalizeInfo() override = default; Status GenerateStrategies(int32_t stage_id) override; @@ -40,7 +40,7 @@ class L2NormalizeInfo : public Activation { protected: Status GetAttrs() override; Status InferMirrorOps() override; - Status CheckStrategy(const StrategyPtr& strategy) override; + Status CheckStrategy(const StrategyPtr &strategy) override; private: int32_t axis_ = 0; // Default value = 0 diff --git a/mindspore/ccsrc/parallel/ops_info/layer_norm_info.h b/mindspore/ccsrc/parallel/ops_info/layer_norm_info.h index c52645ade2..50117b8185 100644 --- a/mindspore/ccsrc/parallel/ops_info/layer_norm_info.h +++ b/mindspore/ccsrc/parallel/ops_info/layer_norm_info.h @@ -38,20 +38,20 @@ constexpr char BEGIN_NORM_AXIS[] = "begin_norm_axis"; // arbitrarily. Gamma and beta should match input to meet the broadcast requirements of mul and add. class LayerNormInfo : public OperatorInfo { public: - LayerNormInfo(const std::string& operator_name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + LayerNormInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared(true)), begin_norm_axis_(0) {} ~LayerNormInfo() override = default; - Status Init(const StrategyPtr& strategy) override; - Status InitForCostModel(const StrategyPtr& strategy) override; + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; Status GenerateStrategies(int32_t) override; - Status SetCostUnderStrategy(const StrategyPtr&) override; + Status SetCostUnderStrategy(const StrategyPtr &) override; protected: Status GetAttrs() override; - Status CheckStrategy(const StrategyPtr& strategy) override; + Status CheckStrategy(const StrategyPtr &strategy) override; Status InferMirrorOps() override; Status InferForwardCommunication() override { return SUCCESS; } Status InferTensorInfo() override; @@ -61,7 +61,7 @@ class LayerNormInfo : public OperatorInfo { Status CreateTensorMap(size_t input_index); Status CreateTensorInfo(size_t input_index); Status CreateMirrorOp(size_t input_index); - Status GenerateGammaAndBetaStrategies(const std::vector& sp_vector); + Status GenerateGammaAndBetaStrategies(const std::vector &sp_vector); Status InitShapes(); private: diff --git a/mindspore/ccsrc/parallel/ops_info/loss_info.cc b/mindspore/ccsrc/parallel/ops_info/loss_info.cc index 28ea19f120..0ba325c0cd 100644 --- a/mindspore/ccsrc/parallel/ops_info/loss_info.cc +++ b/mindspore/ccsrc/parallel/ops_info/loss_info.cc @@ -28,7 +28,7 @@ namespace mindspore { namespace parallel { -Status SoftmaxCrossEntropyWithLogitsInfo::CheckStrategy(const mindspore::parallel::StrategyPtr& strategy) { +Status SoftmaxCrossEntropyWithLogitsInfo::CheckStrategy(const mindspore::parallel::StrategyPtr &strategy) { if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << " : Invalid strategy."; @@ -152,7 +152,7 @@ Status SoftmaxCrossEntropyWithLogitsInfo::InferAsLossDivisor() { return SUCCESS; } -Status SoftmaxCrossEntropyWithLogitsInfo::Init(const StrategyPtr& strategy) { +Status SoftmaxCrossEntropyWithLogitsInfo::Init(const StrategyPtr &strategy) { if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { MS_LOG(ERROR) << name_ << " : Init failed."; return FAILED; @@ -162,7 +162,7 @@ Status SoftmaxCrossEntropyWithLogitsInfo::Init(const StrategyPtr& strategy) { return SUCCESS; } -Status SoftmaxCrossEntropyWithLogitsInfo::InitForCostModel(const StrategyPtr& strategy) { +Status SoftmaxCrossEntropyWithLogitsInfo::InitForCostModel(const StrategyPtr &strategy) { if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << " : Init for cost model failed."; @@ -205,7 +205,7 @@ Status SoftmaxCrossEntropyWithLogitsInfo::GenerateStrategies(int32_t stage_id) { } size_t success = 0; - for (auto& sp : sp_vector) { + for (auto &sp : sp_vector) { if (SetCostUnderStrategy(sp) == SUCCESS) { success++; MS_LOG(INFO) << name_ << " : Successfully generated " << success << " strategy."; @@ -216,7 +216,7 @@ Status SoftmaxCrossEntropyWithLogitsInfo::GenerateStrategies(int32_t stage_id) { return SUCCESS; } -Status SoftmaxCrossEntropyWithLogitsInfo::SetCostUnderStrategy(const StrategyPtr& strategy) { +Status SoftmaxCrossEntropyWithLogitsInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { PrintStrategy(strategy); if (SetCostUnderStrategyBase(strategy) != SUCCESS) { if (is_auto_parallel_) { diff --git a/mindspore/ccsrc/parallel/ops_info/loss_info.h b/mindspore/ccsrc/parallel/ops_info/loss_info.h index 44fe22ce90..2679c2d62b 100644 --- a/mindspore/ccsrc/parallel/ops_info/loss_info.h +++ b/mindspore/ccsrc/parallel/ops_info/loss_info.h @@ -34,20 +34,20 @@ namespace parallel { // output_0 : [a], output_1: [a, b] class SoftmaxCrossEntropyWithLogitsInfo : public OperatorInfo { public: - SoftmaxCrossEntropyWithLogitsInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + SoftmaxCrossEntropyWithLogitsInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} ~SoftmaxCrossEntropyWithLogitsInfo() override = default; - Status Init(const StrategyPtr& strategy) override; - Status InitForCostModel(const StrategyPtr& strategy) override; + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; Status GenerateStrategies(int32_t stage_id) override; - Status SetCostUnderStrategy(const StrategyPtr& strategy) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; void ReComputeBatchSplitFlagList() override; protected: - Status CheckStrategy(const StrategyPtr& strategy) override; + Status CheckStrategy(const StrategyPtr &strategy) override; Status GetAttrs() override; Status InferMirrorOps() override { return SUCCESS; } Status InferForwardCommunication() override { return SUCCESS; } diff --git a/mindspore/ccsrc/parallel/ops_info/matmul_info.cc b/mindspore/ccsrc/parallel/ops_info/matmul_info.cc index 8d1264482b..3f55efb66c 100644 --- a/mindspore/ccsrc/parallel/ops_info/matmul_info.cc +++ b/mindspore/ccsrc/parallel/ops_info/matmul_info.cc @@ -31,8 +31,8 @@ namespace mindspore { namespace parallel { -void SetDevMatrixShape(const Dimensions& mat_a_strategy, const Dimensions& mat_b_strategy, bool transpose_b, - Shape* dev_matrix_shape) { +void SetDevMatrixShape(const Dimensions &mat_a_strategy, const Dimensions &mat_b_strategy, bool transpose_b, + Shape *dev_matrix_shape) { MS_EXCEPTION_IF_NULL(dev_matrix_shape); size_t mat_a_size = mat_a_strategy.size(); size_t mat_b_size = mat_b_strategy.size(); @@ -105,7 +105,7 @@ Status MatMulBase::GetAttrs() { return SUCCESS; } -Status CheckRelevantDimension(const Dimensions& long_strategy, const Dimensions& short_strategy) { +Status CheckRelevantDimension(const Dimensions &long_strategy, const Dimensions &short_strategy) { size_t long_size = long_strategy.size(); size_t short_size = short_strategy.size(); if (long_size < short_size) { @@ -126,7 +126,7 @@ Status CheckRelevantDimension(const Dimensions& long_strategy, const Dimensions& return SUCCESS; } -Status MatMul::CheckStrategy(const StrategyPtr& strategy) { +Status MatMul::CheckStrategy(const StrategyPtr &strategy) { if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << " : Invalid strategy."; @@ -239,7 +239,7 @@ Status MatMulBase::InferForwardCommunication() { } // dev_matrix_shape: [a, b, c, d, e], then output strategy: [a, b, c, e]; -Dimensions GetOutputStrategy(const Shape& dev_matrix_shape, int32_t repeated_calculation_num) { +Dimensions GetOutputStrategy(const Shape &dev_matrix_shape, int32_t repeated_calculation_num) { Dimensions output_strategy = dev_matrix_shape; if (repeated_calculation_num > 1) { // move the first dimension(repeated_calc_num_) @@ -301,7 +301,7 @@ Status MatMulBase::InferTensorMap() { return SUCCESS; } -Status MatMulBase::InferTensorLayout(TensorLayouts* inputs_layout, TensorLayouts* outputs_layout) { +Status MatMulBase::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout) { TensorLayout mat_a_layout, mat_b_layout, output_layout; if ((mat_a_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[0], inputs_shape_[0]) != SUCCESS) || (mat_b_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[1], inputs_shape_[1]) != SUCCESS) || @@ -353,7 +353,7 @@ Status MatMulBase::InferTensorInfo() { return SUCCESS; } -Status MatMulBase::Init(const StrategyPtr& strategy) { +Status MatMulBase::Init(const StrategyPtr &strategy) { if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { MS_LOG(ERROR) << name_ << " : Init failed."; return FAILED; @@ -363,7 +363,7 @@ Status MatMulBase::Init(const StrategyPtr& strategy) { return SUCCESS; } -Status MatMulBase::InitForCostModel(const StrategyPtr& strategy) { +Status MatMulBase::InitForCostModel(const StrategyPtr &strategy) { if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << " : Init for cost model failed."; @@ -377,7 +377,7 @@ Status MatMulBase::InitForCostModel(const StrategyPtr& strategy) { return SUCCESS; } -Status MatMulBase::SwapLastTwoElements(mindspore::parallel::Shape* const input) { +Status MatMulBase::SwapLastTwoElements(mindspore::parallel::Shape *const input) { if (input->size() < 2) { MS_LOG(ERROR) << name_ << " : The size of inputs small than 2."; return FAILED; @@ -463,7 +463,7 @@ Status MatMulBase::GenerateStrategies(int32_t stage_id) { Status MatMulBase::PrepareStrategy(int32_t stage_id, size_t dev_num, mindspore::parallel::Dimensions combined_partitions, size_t input0_shape_size, - size_t input1_shape_size, mindspore::parallel::StrategyPtr* const sp) { + size_t input1_shape_size, mindspore::parallel::StrategyPtr *const sp) { int32_t product = std::accumulate(combined_partitions.begin(), combined_partitions.end(), 1, std::multiplies()); if (!FULLY_USE_DEVICES) { if (IntToSize(product) > dev_num) { @@ -519,7 +519,7 @@ Status MatMulBase::PrepareStrategy(int32_t stage_id, size_t dev_num, return SUCCESS; } -void MatMulBase::InitTensorInfoForCost(std::vector* relica_inputs_tensor_vector) { +void MatMulBase::InitTensorInfoForCost(std::vector *relica_inputs_tensor_vector) { TensorLayout tly; if (transpose_a_) { Shape replica_input0_shape(inputs_tensor_info_[0].shape()); @@ -560,7 +560,7 @@ Status MatMulBase::CheckForTensorSliceValid() const { if (inputs_tensor_info_.empty()) { return FAILED; } - for (auto& one_input_tensor : inputs_tensor_info_) { + for (auto &one_input_tensor : inputs_tensor_info_) { auto slice_shape = one_input_tensor.slice_shape(); if ((IntToSize(slice_shape[LAST_INDEX(slice_shape.size())]) % TENSOR_SLICE_ALIGNMENT_SIZE != 0) || (IntToSize(slice_shape[SECOND_FROM_END(slice_shape.size())]) % TENSOR_SLICE_ALIGNMENT_SIZE != 0)) { @@ -570,7 +570,7 @@ Status MatMulBase::CheckForTensorSliceValid() const { return SUCCESS; } -Status MatMulBase::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr& strategy) { +Status MatMulBase::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr &strategy) { if (InitForCostModel(strategy) == FAILED) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << " : Initialization under the strategy failed."; diff --git a/mindspore/ccsrc/parallel/ops_info/matmul_info.h b/mindspore/ccsrc/parallel/ops_info/matmul_info.h index 8a64fb7206..86a74f78f2 100644 --- a/mindspore/ccsrc/parallel/ops_info/matmul_info.h +++ b/mindspore/ccsrc/parallel/ops_info/matmul_info.h @@ -32,21 +32,21 @@ namespace mindspore { namespace parallel { class MatMulBase : public OperatorInfo { public: - MatMulBase(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + MatMulBase(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} ~MatMulBase() override = default; - Status Init(const StrategyPtr& strategy) override; - Status InitForCostModel(const StrategyPtr& strategy) override; + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; // Generate all strategies and the corresponding cost for this MatMul operator Status GenerateStrategies(int32_t stage_id) override; - Status SetCostUnderStrategy(const StrategyPtr& strategy) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; Status PrepareStrategy(int32_t stage_id, size_t dev_num, Dimensions combined_partitions, size_t input0_shape_size, - size_t input1_shape_size, StrategyPtr* sp); + size_t input1_shape_size, StrategyPtr *sp); - Status SwapLastTwoElements(Shape* shape); + Status SwapLastTwoElements(Shape *shape); protected: Status InferMirrorOps() override; @@ -54,8 +54,8 @@ class MatMulBase : public OperatorInfo { Status InferTensorInfo() override; Status InferDevMatrixShape() override; Status InferTensorMap() override; - Status InferTensorLayout(TensorLayouts* inputs_layout, TensorLayouts* outputs_layout); - void InitTensorInfoForCost(std::vector*); + Status InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout); + void InitTensorInfoForCost(std::vector *); Status CheckForTensorSliceValid() const; Status GetAttrs() override; @@ -67,26 +67,26 @@ class MatMulBase : public OperatorInfo { class MatMul : public MatMulBase { public: - MatMul(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs) + MatMul(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs) : MatMulBase(name, inputs_shape, outputs_shape, attrs) {} ~MatMul() override = default; protected: - Status CheckStrategy(const StrategyPtr& strategy) override; + Status CheckStrategy(const StrategyPtr &strategy) override; }; class MatMulInfo : public MatMul { public: - MatMulInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + MatMulInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : MatMul(name, inputs_shape, outputs_shape, attrs) {} ~MatMulInfo() override = default; }; class BatchMatMulInfo : public MatMul { public: - BatchMatMulInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + BatchMatMulInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : MatMul(name, inputs_shape, outputs_shape, attrs) {} ~BatchMatMulInfo() override = default; }; diff --git a/mindspore/ccsrc/parallel/ops_info/onehot_info.cc b/mindspore/ccsrc/parallel/ops_info/onehot_info.cc index e07609d3c4..2c06a1ace9 100644 --- a/mindspore/ccsrc/parallel/ops_info/onehot_info.cc +++ b/mindspore/ccsrc/parallel/ops_info/onehot_info.cc @@ -54,7 +54,7 @@ Status OneHotInfo::GetAttrs() { return SUCCESS; } -Status OneHotInfo::CheckStrategy(const StrategyPtr& strategy) { +Status OneHotInfo::CheckStrategy(const StrategyPtr &strategy) { if (inputs_shape_.size() != 3) { MS_LOG(ERROR) << name_ << ": inputs_shape_ size must be 3, but is " << inputs_shape_.size(); return FAILED; @@ -185,7 +185,7 @@ Status OneHotInfo::ExtractInputInfo() { return SUCCESS; } -Status OneHotInfo::ComputeReplaceGraph(const CNodePtr& cnode) { +Status OneHotInfo::ComputeReplaceGraph(const CNodePtr &cnode) { if (dev_matrix_shape_.back() == 1) { replace_graph_ = nullptr; return SUCCESS; @@ -222,7 +222,7 @@ Status OneHotInfo::ComputeReplaceGraph(const CNodePtr& cnode) { return SUCCESS; } -ReplaceGraphPtr OneHotInfo::replace_graph(const CNodePtr& cnode) { +ReplaceGraphPtr OneHotInfo::replace_graph(const CNodePtr &cnode) { if (ComputeReplaceGraph(cnode) != SUCCESS) { MS_LOG(ERROR) << name_ << ": ComputeReplaceGraph failed."; return nullptr; @@ -230,7 +230,7 @@ ReplaceGraphPtr OneHotInfo::replace_graph(const CNodePtr& cnode) { return replace_graph_; } -Status OneHotInfo::Init(const StrategyPtr& strategy) { +Status OneHotInfo::Init(const StrategyPtr &strategy) { if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { MS_LOG(ERROR) << name_ << ": Init failed."; return FAILED; @@ -244,7 +244,7 @@ Status OneHotInfo::Init(const StrategyPtr& strategy) { return SUCCESS; } -Status OneHotInfo::InitForCostModel(const StrategyPtr& strategy) { +Status OneHotInfo::InitForCostModel(const StrategyPtr &strategy) { if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << ": Init for cost model failed."; @@ -276,7 +276,7 @@ Status OneHotInfo::GenerateStrategies(int32_t stage_id) { } size_t success = 0; - for (auto& sp : sp_vector) { + for (auto &sp : sp_vector) { if (SetCostUnderStrategy(sp) == SUCCESS) { success++; MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy."; @@ -287,7 +287,7 @@ Status OneHotInfo::GenerateStrategies(int32_t stage_id) { return SUCCESS; } -Status OneHotInfo::SetCostUnderStrategy(const StrategyPtr& strategy) { +Status OneHotInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { if (SetCostUnderStrategyBase(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << ": Set cost under strategy failed."; diff --git a/mindspore/ccsrc/parallel/ops_info/onehot_info.h b/mindspore/ccsrc/parallel/ops_info/onehot_info.h index a4f00ea093..3c8a64f954 100644 --- a/mindspore/ccsrc/parallel/ops_info/onehot_info.h +++ b/mindspore/ccsrc/parallel/ops_info/onehot_info.h @@ -31,20 +31,20 @@ namespace mindspore { namespace parallel { class OneHotInfo : public OperatorInfo { public: - OneHotInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + OneHotInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} ~OneHotInfo() override = default; - Status Init(const StrategyPtr& strategy) override; - Status InitForCostModel(const StrategyPtr& strategy) override; + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; Status GenerateStrategies(int32_t stage_id) override; - Status SetCostUnderStrategy(const StrategyPtr& strategy) override; - ReplaceGraphPtr replace_graph(const CNodePtr& cnode) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; + ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override; std::shared_ptr>> GenerateBatchStrategies() override; protected: - Status CheckStrategy(const StrategyPtr& strategy) override; + Status CheckStrategy(const StrategyPtr &strategy) override; Status GetAttrs() override; Status InferMirrorOps() override { return SUCCESS; } Status InferForwardCommunication() override { return SUCCESS; } @@ -54,7 +54,7 @@ class OneHotInfo : public OperatorInfo { Status ExtractInputInfo(); private: - Status ComputeReplaceGraph(const CNodePtr& cnode); + Status ComputeReplaceGraph(const CNodePtr &cnode); int axis_ = -1; int32_t rank_ = 0; diff --git a/mindspore/ccsrc/parallel/ops_info/operator_info.cc b/mindspore/ccsrc/parallel/ops_info/operator_info.cc index c6115a9fa6..8074f2a32e 100644 --- a/mindspore/ccsrc/parallel/ops_info/operator_info.cc +++ b/mindspore/ccsrc/parallel/ops_info/operator_info.cc @@ -35,7 +35,7 @@ namespace mindspore { namespace parallel { -Status CheckStrategyValue(const StrategyPtr& strategy, const Shapes& inputs_shape, bool is_auto_parallel) { +Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape, bool is_auto_parallel) { if (strategy == nullptr) { MS_LOG(ERROR) << "The strategy is null."; return FAILED; @@ -190,7 +190,7 @@ Operator CreateVirtualDivOp(int32_t div_num) { } // use for forward all reduce -Operator CreateAllReduceOp(const std::string& reduce_op, const std::string& group) { +Operator CreateAllReduceOp(const std::string &reduce_op, const std::string &group) { OperatorName operator_name = ALL_REDUCE; ValuePtr attr0_value = MakeValue(reduce_op); // ReduceOP.SUM ValuePtr attr1_value = MakeValue(group); // group @@ -209,7 +209,7 @@ Operator CreateAllReduceOp(const std::string& reduce_op, const std::string& grou } // use for get tensor slice -Operator CreateGetTensorSliceOp(const TensorLayout& tensor_layout) { +Operator CreateGetTensorSliceOp(const TensorLayout &tensor_layout) { Shape tensor_map = tensor_layout.tensor_map().array(); Shape dev_matrix_shape = tensor_layout.device_arrangement().array(); OperatorName operator_name = GET_TENSOR_SLICE; @@ -228,7 +228,7 @@ Operator CreateGetTensorSliceOp(const TensorLayout& tensor_layout) { return op; } -OperatorVector CreateMirrorOps(const std::string& group_name, size_t dev_num) { +OperatorVector CreateMirrorOps(const std::string &group_name, size_t dev_num) { if ((dev_num == 0) || (dev_num == 1)) { MS_LOG(EXCEPTION) << "Invalid dev num: " << dev_num; } @@ -260,7 +260,7 @@ OperatorVector CreateMirrorOps(const std::string& group_name, size_t dev_num) { return op_for_weight; } -Status OperatorInfo::CreateGroupByTensorMap(const Shape& tensor_map, std::vector* group) { +Status OperatorInfo::CreateGroupByTensorMap(const Shape &tensor_map, std::vector *group) { if (group == nullptr) { MS_LOG(ERROR) << "The group is null."; return FAILED; @@ -283,7 +283,7 @@ Status OperatorInfo::CreateGroupByTensorMap(const Shape& tensor_map, std::vector return SUCCESS; } -Status OperatorInfo::CreateGroupByDim(size_t axis, std::vector* group) { +Status OperatorInfo::CreateGroupByDim(size_t axis, std::vector *group) { if (group == nullptr) { MS_LOG(ERROR) << "The group is null."; return FAILED; @@ -306,7 +306,7 @@ Status OperatorInfo::CreateGroupByDim(size_t axis, std::vector* group) { return SUCCESS; } -Shape GetSliceShape(const Shape& tensor_shape, const Dimensions& strategy) { +Shape GetSliceShape(const Shape &tensor_shape, const Dimensions &strategy) { Shape slice_shape; if (std::any_of(strategy.begin(), strategy.end(), [](int32_t value) { return value <= 0; })) { MS_LOG(ERROR) << "Invalid strategy: " << ShapeToString(strategy) << ", the element is less than or equal to 0"; @@ -318,7 +318,7 @@ Shape GetSliceShape(const Shape& tensor_shape, const Dimensions& strategy) { return slice_shape; } -Status InferSliceShapeByStrategy(const Strategys& strategys, const Shapes& shapes, Shapes* slice_shapes) { +Status InferSliceShapeByStrategy(const Strategys &strategys, const Shapes &shapes, Shapes *slice_shapes) { if (slice_shapes == nullptr) { MS_LOG(ERROR) << "The slice_shapes is null."; return FAILED; @@ -357,8 +357,8 @@ Status InferSliceShapeByStrategy(const Strategys& strategys, const Shapes& shape return SUCCESS; } -Status OperatorInfo::InferSliceShape(const Strategys& inputs_strategy, const Strategys& outputs_strategy, - Shapes* inputs_slice_shape, Shapes* outputs_slice_shape) { +Status OperatorInfo::InferSliceShape(const Strategys &inputs_strategy, const Strategys &outputs_strategy, + Shapes *inputs_slice_shape, Shapes *outputs_slice_shape) { if (inputs_slice_shape == nullptr || outputs_slice_shape == nullptr) { MS_LOG(ERROR) << "The slice_shape is null."; return FAILED; @@ -379,7 +379,7 @@ Status OperatorInfo::InferSliceShape(const Strategys& inputs_strategy, const Str } // method0: auto insert repeated_calculation_num for dev_matrix_shape when repeated_calculation_num > 1 -Status OperatorInfo::InitForCostModelWithAutoRepeatCalc(const StrategyPtr& strategy) { +Status OperatorInfo::InitForCostModelWithAutoRepeatCalc(const StrategyPtr &strategy) { if (strategy == nullptr) { MS_LOG(ERROR) << name_ << ": The strategy is null."; return FAILED; @@ -437,7 +437,7 @@ Status OperatorInfo::InitForCostModelWithAutoRepeatCalc(const StrategyPtr& strat } // method1: manually insert repeated_calculation_num for dev_matrix_shape in InferDevMatrixShape -Status OperatorInfo::InitForCostModelWithManualRepeatCalc(const StrategyPtr& strategy) { +Status OperatorInfo::InitForCostModelWithManualRepeatCalc(const StrategyPtr &strategy) { if (strategy == nullptr) { MS_LOG(ERROR) << name_ << ": The strategy is null."; return FAILED; @@ -485,7 +485,7 @@ Status OperatorInfo::InitForCostModelWithManualRepeatCalc(const StrategyPtr& str return SUCCESS; } -Status OperatorInfo::InitWithAutoRepeatCalc(const StrategyPtr& strategy) { +Status OperatorInfo::InitWithAutoRepeatCalc(const StrategyPtr &strategy) { if (strategy == nullptr) { MS_LOG(ERROR) << name_ << ": The strategy is null."; return FAILED; @@ -513,7 +513,7 @@ Status OperatorInfo::InitWithAutoRepeatCalc(const StrategyPtr& strategy) { return SUCCESS; } -Status OperatorInfo::InitWithManualRepeatCalc(const StrategyPtr& strategy) { +Status OperatorInfo::InitWithManualRepeatCalc(const StrategyPtr &strategy) { if (strategy == nullptr) { MS_LOG(ERROR) << name_ << ": The strategy is null."; return FAILED; @@ -543,12 +543,12 @@ Status OperatorInfo::InitWithManualRepeatCalc(const StrategyPtr& strategy) { std::vector> OperatorInfo::GetAliveSuccEdges() { std::vector> ret; - for (auto& edge : succ_edges_) { + for (auto &edge : succ_edges_) { if ((edge->next_operator()->is_alive()) && (edge->next_operator()->name().find(RELU) != std::string::npos)) { ret.push_back(edge); } } - for (auto& edge : succ_edges_) { + for (auto &edge : succ_edges_) { if ((edge->next_operator()->is_alive()) && (edge->next_operator()->name().find(RELU) == std::string::npos)) { ret.push_back(edge); } @@ -558,7 +558,7 @@ std::vector> OperatorInfo::GetAliveSuccEdges() { std::vector> OperatorInfo::GetAlivePrevEdges() { std::vector> ret; - for (auto& edge : prev_edges_) { + for (auto &edge : prev_edges_) { if (edge->prev_operator()->is_alive()) { ret.push_back(edge); } @@ -566,12 +566,12 @@ std::vector> OperatorInfo::GetAlivePrevEdges() { return ret; } -void OperatorInfo::ReplacePreEdge(const std::shared_ptr& op, const std::shared_ptr& new_edge) { +void OperatorInfo::ReplacePreEdge(const std::shared_ptr &op, const std::shared_ptr &new_edge) { if (op == nullptr) { MS_LOG(ERROR) << name_ << ": ReplacePreEdge: the op is null."; return; } - for (auto& edge : prev_edges_) { + for (auto &edge : prev_edges_) { if (edge->prev_operator() == op) { edge = new_edge; return; @@ -580,12 +580,12 @@ void OperatorInfo::ReplacePreEdge(const std::shared_ptr& op, const MS_LOG(EXCEPTION) << name_ << ": Replace edge failed: no edge has been replaced"; } -void OperatorInfo::ReplaceSuccEdge(const std::shared_ptr& op, const std::shared_ptr& new_edge) { +void OperatorInfo::ReplaceSuccEdge(const std::shared_ptr &op, const std::shared_ptr &new_edge) { if (op == nullptr) { MS_LOG(ERROR) << name_ << ": ReplaceSuccEdge: the op is null."; return; } - for (auto& edge : succ_edges_) { + for (auto &edge : succ_edges_) { if (edge->next_operator() == op) { edge = new_edge; return; @@ -594,13 +594,13 @@ void OperatorInfo::ReplaceSuccEdge(const std::shared_ptr& op, cons MS_LOG(EXCEPTION) << name_ << ": Replace edge failed: no edge has been replaced"; } -void OperatorInfo::ReplacePreEdges(const std::shared_ptr& op, const std::shared_ptr& new_edge) { +void OperatorInfo::ReplacePreEdges(const std::shared_ptr &op, const std::shared_ptr &new_edge) { if (op == nullptr) { MS_LOG(ERROR) << name_ << ": ReplacePreEdges: the op is null."; return; } std::vector> new_pre_edges; - for (auto& edge : prev_edges_) { + for (auto &edge : prev_edges_) { if (edge->prev_operator() != op) { new_pre_edges.push_back(edge); } @@ -609,13 +609,13 @@ void OperatorInfo::ReplacePreEdges(const std::shared_ptr& op, cons prev_edges_ = new_pre_edges; } -void OperatorInfo::ReplaceSuccEdges(const std::shared_ptr& op, const std::shared_ptr& new_edge) { +void OperatorInfo::ReplaceSuccEdges(const std::shared_ptr &op, const std::shared_ptr &new_edge) { if (op == nullptr) { MS_LOG(ERROR) << name_ << ": ReplaceSuccEdges: the op is null"; return; } std::vector> new_succ_edges; - for (auto& edge : succ_edges_) { + for (auto &edge : succ_edges_) { if (edge->next_operator() != op) { new_succ_edges.push_back(edge); } @@ -625,7 +625,7 @@ void OperatorInfo::ReplaceSuccEdges(const std::shared_ptr& op, con } std::shared_ptr>> GenerateBatchStrategiesBySplitFlag( - const Shapes& shapes, const std::vector& split_flag_list) { + const Shapes &shapes, const std::vector &split_flag_list) { if (shapes.size() != split_flag_list.size()) { MS_LOG(ERROR) << "Split_flag_list do not have the same size as inputs shape, " << split_flag_list.size() << " : " << shapes.size(); @@ -665,14 +665,14 @@ void OperatorInfo::ComputeBatchSplitFlagList() { } // This is a common method for checking whether the generated stragegy has the correct number of devuces. -Status PrepareStrategyBase(int32_t stage_id, size_t dev_num, const Shapes& inputs_partitions, StrategyPtr* const sp) { +Status PrepareStrategyBase(int32_t stage_id, size_t dev_num, const Shapes &inputs_partitions, StrategyPtr *const sp) { if (sp == nullptr) { MS_LOG(ERROR) << "The strategy is null."; return FAILED; } int32_t product = 1; - for (auto& input_partition : inputs_partitions) { + for (auto &input_partition : inputs_partitions) { product *= std::accumulate(input_partition.begin(), input_partition.end(), 1, std::multiplies()); } if (!FULLY_USE_DEVICES) { @@ -694,7 +694,7 @@ std::shared_ptr>> OperatorInfo::GenerateBatchSt return GenerateBatchStrategiesBySplitFlag(inputs_shape_, split_flag_list_); } -void PrintStrategy(const StrategyPtr& strategy) { +void PrintStrategy(const StrategyPtr &strategy) { if (strategy == nullptr) { return; } @@ -716,8 +716,8 @@ void PrintStrategy(const StrategyPtr& strategy) { } // generate strategies for that each dimension of input0 and input1 is relevant, such as: ([a, b, c, d], [a, b, c, d]) -Status GenerateStrategiesForTwoEqualInputs(int32_t stage_id, const Shapes& inputs_shape, - const Shapes& splittable_inputs, std::vector* const sp_vector) { +Status GenerateStrategiesForTwoEqualInputs(int32_t stage_id, const Shapes &inputs_shape, + const Shapes &splittable_inputs, std::vector *const sp_vector) { if (sp_vector == nullptr) { MS_LOG(ERROR) << "The sp_vector is null."; return FAILED; @@ -740,7 +740,7 @@ Status GenerateStrategiesForTwoEqualInputs(int32_t stage_id, const Shapes& input return FAILED; } - for (auto& sp : *sp_vector) { + for (auto &sp : *sp_vector) { sp->ExpandInputDimFromOneToTwo(); } @@ -749,8 +749,8 @@ Status GenerateStrategiesForTwoEqualInputs(int32_t stage_id, const Shapes& input // generate strategies for that input0 and input1 have relevant dimensions, and input0 needs to broadcast // such as: ([b, c, d], [a, b, c, d]) or ([1, c, d], [a, b, c, d]) -Status GenerateStrategiesForBroadcastLeft(int32_t stage_id, const Shapes& inputs_shape, const Shapes& splittable_inputs, - std::vector* const sp_vector) { +Status GenerateStrategiesForBroadcastLeft(int32_t stage_id, const Shapes &inputs_shape, const Shapes &splittable_inputs, + std::vector *const sp_vector) { if (sp_vector == nullptr) { MS_LOG(ERROR) << "The sp_vector is null."; return FAILED; @@ -770,7 +770,7 @@ Status GenerateStrategiesForBroadcastLeft(int32_t stage_id, const Shapes& inputs } // second, get the correct strategy for input0 - for (auto& sp : *sp_vector) { + for (auto &sp : *sp_vector) { std::vector tmp_strategy; Dimensions input0_strategy = sp->GetInputDim()[0]; size_t size_diff = inputs_shape[1].size() - inputs_shape[0].size(); @@ -798,8 +798,8 @@ Status GenerateStrategiesForBroadcastLeft(int32_t stage_id, const Shapes& inputs // generate strategies for that input0 and input1 have relevant dimensions, and input1 needs to broadcast // such as: ([a, b, c, d], [b, c, d]) or ([a, b, c, d], [1, c, d]) -Status GenerateStrategiesForBroadcastRight(int32_t stage_id, const Shapes& inputs_shape, - const Shapes& splittable_inputs, std::vector* const sp_vector) { +Status GenerateStrategiesForBroadcastRight(int32_t stage_id, const Shapes &inputs_shape, + const Shapes &splittable_inputs, std::vector *const sp_vector) { if (sp_vector == nullptr) { MS_LOG(ERROR) << "The sp_vector is null."; return FAILED; @@ -819,7 +819,7 @@ Status GenerateStrategiesForBroadcastRight(int32_t stage_id, const Shapes& input } // second, get the correct strategy for input1 - for (auto& sp : *sp_vector) { + for (auto &sp : *sp_vector) { std::vector tmp_strategy; tmp_strategy.push_back(sp->GetInputDim()[0]); // input0 @@ -848,8 +848,8 @@ Status GenerateStrategiesForBroadcastRight(int32_t stage_id, const Shapes& input // generate strategies for that input0 and input1 have same size, and input0 or input1 needs to broadcast // such as: ([a, 1], [1, b]) or ([a, b, c, d], [1, b, c, d]) or ([a, b, c, 1], [1, b, c, d]) -Status GenerateStrategiesForBroadcastBoth(int32_t stage_id, const Shapes& inputs_shape, const Shapes& splittable_inputs, - std::vector* const sp_vector) { +Status GenerateStrategiesForBroadcastBoth(int32_t stage_id, const Shapes &inputs_shape, const Shapes &splittable_inputs, + std::vector *const sp_vector) { if (sp_vector == nullptr) { MS_LOG(ERROR) << "The sp_vector is null."; return FAILED; @@ -881,7 +881,7 @@ Status GenerateStrategiesForBroadcastBoth(int32_t stage_id, const Shapes& inputs } // step3: reset the strategy if the dimension is 1 - for (auto& sp : *sp_vector) { + for (auto &sp : *sp_vector) { Dimensions input0_strategy = sp->GetInputDim()[0]; Dimensions input1_strategy = sp->GetInputDim()[1]; for (size_t i = 0; i < inputs_shape[0].size(); ++i) { @@ -904,9 +904,9 @@ Status GenerateStrategiesForBroadcastBoth(int32_t stage_id, const Shapes& inputs // dimension is splittable. 'inputs_partitions' is the result of partitions. // NOTE: This implementation would partition all splittable dimensions in all inputs. Some operators requiring // specific dimensions in inputs have the identical partition should have individual implementation. -Status GenerateStrategiesForIndependentInputs(int32_t stage_id, const Shapes& inputs_shape, - const Shapes& splittable_inputs, - std::vector* const sp_vector) { +Status GenerateStrategiesForIndependentInputs(int32_t stage_id, const Shapes &inputs_shape, + const Shapes &splittable_inputs, + std::vector *const sp_vector) { if (sp_vector == nullptr) { MS_LOG(ERROR) << "The sp_vector is null."; return FAILED; @@ -932,7 +932,7 @@ Status GenerateStrategiesForIndependentInputs(int32_t stage_id, const Shapes& in MS_LOG(DEBUG) << "The value of combined_splittable_inputs.size is: " << combined_splittable_inputs.size(); Shapes inputs_partitions; size_t global_index = 0; - for (auto& shape : inputs_shape) { + for (auto &shape : inputs_shape) { Shape tmp_partition; for (size_t j = 0; j < shape.size(); ++j) { tmp_partition.push_back(combined_partitions[global_index]); @@ -974,8 +974,8 @@ Status GenerateStrategiesForIndependentInputs(int32_t stage_id, const Shapes& in // such as: ([a, b, c, d], [a, b, c, d]) or ([b, c, d], [a, b, c, d]) or ([1, c, d], [a, b, c, d]) // or ([a, b, c, d], [b, c, d]) or ([a, b, c, d], [1, c, d]) // or ([a, 1], [1, b]) or ([a, b, c, d], [1, b, c, d]) or ([a, b, c, 1], [1, b, c, d]) -Status GenerateStrategiesWithBroadcast(int32_t stage_id, const Shapes& inputs_shape, const Shapes& splittable_inputs, - std::vector* const sp_vector) { +Status GenerateStrategiesWithBroadcast(int32_t stage_id, const Shapes &inputs_shape, const Shapes &splittable_inputs, + std::vector *const sp_vector) { if (sp_vector == nullptr) { MS_LOG(ERROR) << "The sp_vector is null."; return FAILED; @@ -1025,7 +1025,7 @@ Status GenerateStrategiesWithBroadcast(int32_t stage_id, const Shapes& inputs_sh return SUCCESS; } -Status OperatorInfo::SetCostUnderStrategyBase(const StrategyPtr& strategy) { +Status OperatorInfo::SetCostUnderStrategyBase(const StrategyPtr &strategy) { if (InitForCostModel(strategy) == FAILED) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << ": Initialization under the strategy failed."; @@ -1063,8 +1063,8 @@ int OperatorInfo::ComputeOpAndPrevEdgeParameterInvolved() { return is_output_parameter_involve_; } is_parameter_involve_ = is_parameter_; - const auto& prev_edges = this->GetAlivePrevEdges(); - for (auto& p_edge : prev_edges) { + const auto &prev_edges = this->GetAlivePrevEdges(); + for (auto &p_edge : prev_edges) { auto input_index = p_edge->next_op_input_index(); auto prev_op_para = p_edge->prev_operator()->ComputeOpAndPrevEdgeParameterInvolved(); if (input_index >= is_parameter_involve_.size()) { @@ -1090,7 +1090,7 @@ int OperatorInfo::ComputeOpAndPrevEdgeParameterInvolved() { return is_output_parameter_involve_; } -Status OperatorInfo::set_is_parameter(const std::vector& is_parameter) { +Status OperatorInfo::set_is_parameter(const std::vector &is_parameter) { if (is_parameter.size() != inputs_shape_.size()) { MS_LOG(ERROR) << "Is_parameter: " << is_parameter.size() << " do not have the same number of inputs_shape_: " << inputs_shape_.size(); @@ -1111,7 +1111,7 @@ Status OperatorInfo::CalculateMemoryCost() { operator_cost()->set_is_parameter_involve(is_parameter_involve_); operator_cost()->set_output_parameter_involve(is_output_parameter_involve_); // Set the memory cost in the 'strategy_cost_' - for (auto& swc : strategy_cost_) { + for (auto &swc : strategy_cost_) { auto mem_cost = operator_cost()->GetMemoryCost(swc->inputs_ptr, swc->outputs_ptr); swc->cost_list[0]->memory_with_reuse_ = mem_cost; } @@ -1119,7 +1119,7 @@ Status OperatorInfo::CalculateMemoryCost() { } Status OperatorInfo::CorrectMemoryCost(size_t input_index) { - for (auto& swc : strategy_cost_) { + for (auto &swc : strategy_cost_) { double parameter_mem_cost = ListProduct(swc->inputs_ptr[input_index].slice_shape()) * static_cast(operator_cost()->inputs_type_lengths()[input_index]); swc->cost_list[0]->memory_with_reuse_ -= parameter_mem_cost; @@ -1132,13 +1132,13 @@ Status OperatorInfo::CorrectMemoryCost(size_t input_index) { return SUCCESS; } -int32_t ComputeRepeatDeviceNumByTensorMap(const Shape& dev_matrix_shape, const Shape& tensor_map) { +int32_t ComputeRepeatDeviceNumByTensorMap(const Shape &dev_matrix_shape, const Shape &tensor_map) { int32_t ret = -1; // The number of repetitions is equal to the number of all devices divided by the number of devices use for // tensor map. int32_t device_num = std::accumulate(dev_matrix_shape.begin(), dev_matrix_shape.end(), 1, std::multiplies()); - for (auto& element : tensor_map) { + for (auto &element : tensor_map) { // -1 means the corresponding dimension is not split. if (element == MAP_NONE) { continue; @@ -1211,8 +1211,8 @@ Status OperatorInfo::InferVirtualDivOps() { return SUCCESS; } -Status OperatorInfo::SetInputAndOutputTypeLength(const std::vector& input_lengths, - const std::vector& output_lengths) { +Status OperatorInfo::SetInputAndOutputTypeLength(const std::vector &input_lengths, + const std::vector &output_lengths) { if (input_lengths.size() != inputs_shape_.size()) { MS_LOG(ERROR) << "Input_lengths: " << input_lengths.size() << " do not have the same number of inputs shape: " << inputs_shape_.size(); @@ -1229,7 +1229,7 @@ Status OperatorInfo::SetInputAndOutputTypeLength(const std::vector& inpu return SUCCESS; } -Status OperatorInfo::set_outputs_type(const std::vector& outputs_type) { +Status OperatorInfo::set_outputs_type(const std::vector &outputs_type) { if (outputs_type.size() != outputs_shape_.size()) { MS_LOG(ERROR) << "Outputs type: " << outputs_type.size() << " do not have the same number of outputs shape: " << outputs_shape_.size(); @@ -1239,7 +1239,7 @@ Status OperatorInfo::set_outputs_type(const std::vector& outputs_type) return SUCCESS; } -void OperatorInfo::BreakingTiesForPerferringDataParallel(const StrategyPtr& stra, const CostPtr& cost) { +void OperatorInfo::BreakingTiesForPerferringDataParallel(const StrategyPtr &stra, const CostPtr &cost) { if (!stra->GetInputDim().empty() && !stra->GetInputDim()[0].empty()) { CheckGlobalDeviceManager(); auto total_device_num = g_device_manager->GetDeviceListByStageId(stra->GetInputStage()).size(); diff --git a/mindspore/ccsrc/parallel/ops_info/operator_info.h b/mindspore/ccsrc/parallel/ops_info/operator_info.h index 19e0eeeda1..347da7e573 100644 --- a/mindspore/ccsrc/parallel/ops_info/operator_info.h +++ b/mindspore/ccsrc/parallel/ops_info/operator_info.h @@ -69,23 +69,23 @@ class OperatorInfo { virtual ~OperatorInfo() = default; - Status set_is_parameter(const std::vector& is_parameter); - Status SetInputAndOutputTypeLength(const std::vector& input_lengths, - const std::vector& output_lengths); + Status set_is_parameter(const std::vector &is_parameter); + Status SetInputAndOutputTypeLength(const std::vector &input_lengths, + const std::vector &output_lengths); // Set outputs dtype. // If only one output, outputs_type.size() is 1. // If output is tuple, outputs_type.size() is greater than 1. - Status set_outputs_type(const std::vector& outputs_type); - const std::vector& outputs_type() const { return outputs_type_; } - virtual Status Init(const StrategyPtr& strategy) = 0; - virtual Status InitForCostModel(const StrategyPtr& strategy) = 0; // only init the necessary parts + Status set_outputs_type(const std::vector &outputs_type); + const std::vector &outputs_type() const { return outputs_type_; } + virtual Status Init(const StrategyPtr &strategy) = 0; + virtual Status InitForCostModel(const StrategyPtr &strategy) = 0; // only init the necessary parts // Given the stage_id (which indicates the number of devices), // generate all strategies for this operator virtual Status GenerateStrategies(int32_t stage_id) = 0; - const OperatorCostPtr& operator_cost() const { return operator_cost_; } - void set_cost(const OperatorCostPtr& cost) { operator_cost_ = cost; } - virtual Status SetCostUnderStrategy(const StrategyPtr& strategy) = 0; + const OperatorCostPtr &operator_cost() const { return operator_cost_; } + void set_cost(const OperatorCostPtr &cost) { operator_cost_ = cost; } + virtual Status SetCostUnderStrategy(const StrategyPtr &strategy) = 0; virtual std::shared_ptr>> GenerateBatchStrategies(); virtual void ReComputeBatchSplitFlagList(); @@ -94,7 +94,7 @@ class OperatorInfo { double GetForwardMemoryCostFromCNode(); // This is a common method for setting operator cost for a given strategy, in which the validity of this strategy // is checked - Status SetCostUnderStrategyBase(const StrategyPtr& strategy); + Status SetCostUnderStrategyBase(const StrategyPtr &strategy); std::vector> GetStrategyCost() { return strategy_cost_; } // When the input of a operator contains WEIGHT or a output from other operators involving WEIGHT, then these input // should stay in memory until it is used in the backward phase, which is kept in memory at the end of forward phase. @@ -104,61 +104,61 @@ class OperatorInfo { ForwardOp forward_op() const { return forward_op_; } ForwardOp replace_op() const { return replace_op_; } OutPutInfoVector replace_op_info() const { return replace_op_info_; } - virtual ReplaceGraphPtr replace_graph(const CNodePtr&) { return replace_graph_; } + virtual ReplaceGraphPtr replace_graph(const CNodePtr &) { return replace_graph_; } MirrorOps mirror_ops() const { return mirror_ops_; } Ops sub_ops() const { return sub_ops_; } VirtualDivOp virtual_div_op() const { return virtual_div_op_; } Shape dev_matrix_shape() const { return dev_matrix_shape_; } std::vector inputs_tensor_info() const { return inputs_tensor_info_; } std::vector outputs_tensor_info() const { return outputs_tensor_info_; } - const std::string& name() const { return name_; } - void set_name(const std::string& name) { name_ = name; } + const std::string &name() const { return name_; } + void set_name(const std::string &name) { name_ = name; } RankList global_device_list() const { return global_device_list_; } - void AddSuccEdge(const std::shared_ptr& e) { succ_edges_.push_back(e); } - void AddPrevEdge(const std::shared_ptr& e) { prev_edges_.push_back(e); } + void AddSuccEdge(const std::shared_ptr &e) { succ_edges_.push_back(e); } + void AddPrevEdge(const std::shared_ptr &e) { prev_edges_.push_back(e); } std::vector> succ_edges() const { return succ_edges_; } std::vector> prev_edges() const { return prev_edges_; } std::vector> GetAliveSuccEdges(); std::vector> GetAlivePrevEdges(); - void ReplacePreEdge(const std::shared_ptr& op, const std::shared_ptr& new_edge); - void ReplaceSuccEdge(const std::shared_ptr& op, const std::shared_ptr& new_edge); - void ReplacePreEdges(const std::shared_ptr& op, const std::shared_ptr& new_edge); - void ReplaceSuccEdges(const std::shared_ptr& op, const std::shared_ptr& new_edge); + void ReplacePreEdge(const std::shared_ptr &op, const std::shared_ptr &new_edge); + void ReplaceSuccEdge(const std::shared_ptr &op, const std::shared_ptr &new_edge); + void ReplacePreEdges(const std::shared_ptr &op, const std::shared_ptr &new_edge); + void ReplaceSuccEdges(const std::shared_ptr &op, const std::shared_ptr &new_edge); std::vector GetOutputTypeLengths() const { return operator_cost()->outputs_type_lengths(); } - void SetSelectedStrategyAndCost(const StrategyPtr& s_strategy, const CostPtr& cost) { + void SetSelectedStrategyAndCost(const StrategyPtr &s_strategy, const CostPtr &cost) { selected_strategy_ = s_strategy; selected_cost_ = cost; } StrategyPtr selected_strategy() const { return selected_strategy_; } CostPtr selected_cost() const { return selected_cost_; } - Status InitSelectedStrategy(const StrategyPtr& s_strategy) { return Init(s_strategy); } - void set_input_value(const std::vector& input_value) { input_value_ = input_value; } - void set_outputs_dtype(const TypePtr& dtype) { outputs_dtype_ = dtype; } - void set_cnode(const CNodePtr& cnode) { cnode_ = cnode; } + Status InitSelectedStrategy(const StrategyPtr &s_strategy) { return Init(s_strategy); } + void set_input_value(const std::vector &input_value) { input_value_ = input_value; } + void set_outputs_dtype(const TypePtr &dtype) { outputs_dtype_ = dtype; } + void set_cnode(const CNodePtr &cnode) { cnode_ = cnode; } bool is_alive() const { return is_alive_; } void SetNotAlive() { is_alive_ = false; } StrategyPtr strategy() const { return strategy_; } - void set_strategy(const StrategyPtr& strategy) { strategy_ = strategy; } + void set_strategy(const StrategyPtr &strategy) { strategy_ = strategy; } void set_refkey_parameter_name(std::string p_name) { refkey_parameter_name_ = std::move(p_name); } - const std::string& refkey_parameter_name() const { return refkey_parameter_name_; } + const std::string &refkey_parameter_name() const { return refkey_parameter_name_; } // When the output of a Parameter (require_grad) being used by multiple operators, the Parameter's cost is calculated // multiple times. This method is to correct this, and makes the cost is calulated only once. Status CorrectMemoryCost(size_t input_index); int is_output_parameter_involve() const { return is_output_parameter_involve_; } int used_devices() const { return used_devices_; } // needed by rec_parser - void set_type(const std::string& type) { type_ = type; } - const std::string& type() const { return type_; } - void set_cnode_name(const std::string& cnode_name) { cnode_name_ = cnode_name; } - const std::string& cnode_name() const { return cnode_name_; } - const std::unordered_map& attrs() const { return attrs_; } + void set_type(const std::string &type) { type_ = type; } + const std::string &type() const { return type_; } + void set_cnode_name(const std::string &cnode_name) { cnode_name_ = cnode_name; } + const std::string &cnode_name() const { return cnode_name_; } + const std::unordered_map &attrs() const { return attrs_; } protected: // needed by rec_parser std::string type_; std::string cnode_name_; - virtual Status CheckStrategy(const StrategyPtr& strategy) = 0; + virtual Status CheckStrategy(const StrategyPtr &strategy) = 0; virtual Status InferTensorMap() = 0; virtual Status InferForwardCommunication() = 0; virtual Status InferMirrorOps() = 0; @@ -167,14 +167,14 @@ class OperatorInfo { virtual Status InferDevMatrixShape() = 0; void SetDeviceListByStrategy(); void SetRepeatedCalcDevMatrix(); - Status CreateGroupByTensorMap(const Shape& tensor_map, std::vector* group); - Status CreateGroupByDim(size_t axis, std::vector* group); + Status CreateGroupByTensorMap(const Shape &tensor_map, std::vector *group); + Status CreateGroupByDim(size_t axis, std::vector *group); Status InferAttrs(); void ResetQueueMember(); - Status InitWithAutoRepeatCalc(const StrategyPtr& strategy); - Status InitWithManualRepeatCalc(const StrategyPtr& strategy); - Status InitForCostModelWithAutoRepeatCalc(const StrategyPtr& strategy); - Status InitForCostModelWithManualRepeatCalc(const StrategyPtr& strategy); + Status InitWithAutoRepeatCalc(const StrategyPtr &strategy); + Status InitWithManualRepeatCalc(const StrategyPtr &strategy); + Status InitForCostModelWithAutoRepeatCalc(const StrategyPtr &strategy); + Status InitForCostModelWithManualRepeatCalc(const StrategyPtr &strategy); Status InferRepeatedCalcInfo(); Status InferVirtualDivOps(); @@ -182,9 +182,9 @@ class OperatorInfo { // The tensor map of Outputs[0] is used by default. If there are multiple outputs, need to identify which output // is used for grad and overload the function. If the output is a scalar, need to override the function too. virtual Status InferAsLossDivisor(); - Status InferSliceShape(const Strategys& inputs_strategy, const Strategys& outputs_strategy, - Shapes* inputs_slice_shape, Shapes* outputs_slice_shape); - void BreakingTiesForPerferringDataParallel(const StrategyPtr&, const CostPtr&); + Status InferSliceShape(const Strategys &inputs_strategy, const Strategys &outputs_strategy, + Shapes *inputs_slice_shape, Shapes *outputs_slice_shape); + void BreakingTiesForPerferringDataParallel(const StrategyPtr &, const CostPtr &); std::string name_; Shapes inputs_shape_; @@ -242,29 +242,29 @@ class OperatorInfo { std::vector outputs_type_; }; -Shape GetSliceShape(const Shape& tensor_shape, const Dimensions& strategy); -Status CheckStrategyValue(const StrategyPtr& strategy, const Shapes& inputs_shape, bool); +Shape GetSliceShape(const Shape &tensor_shape, const Dimensions &strategy); +Status CheckStrategyValue(const StrategyPtr &strategy, const Shapes &inputs_shape, bool); Operator CreateVirtualDivOp(int32_t div_num); -Operator CreateAllReduceOp(const std::string& reduce_op, const std::string& group); -Operator CreateGetTensorSliceOp(const TensorLayout& tensor_layout); -OperatorVector CreateMirrorOps(const std::string& group_name, size_t dev_num); -int32_t ComputeRepeatDeviceNumByTensorMap(const Shape& dev_matrix_shape, const Shape& tensor_map); +Operator CreateAllReduceOp(const std::string &reduce_op, const std::string &group); +Operator CreateGetTensorSliceOp(const TensorLayout &tensor_layout); +OperatorVector CreateMirrorOps(const std::string &group_name, size_t dev_num); +int32_t ComputeRepeatDeviceNumByTensorMap(const Shape &dev_matrix_shape, const Shape &tensor_map); std::shared_ptr>> GenerateBatchStrategiesBySplitFlag( - const Shapes& shapes, const std::vector& split_flag_list); + const Shapes &shapes, const std::vector &split_flag_list); -void PrintStrategy(const StrategyPtr& strategy); +void PrintStrategy(const StrategyPtr &strategy); // generate strategies for that all inputs' dimensions are independent, such as: ([a, b, c, d]) -Status GenerateStrategiesForIndependentInputs(int32_t stage_id, const Shapes& inputs_shape, - const Shapes& splittable_inputs, std::vector* sp_vector); +Status GenerateStrategiesForIndependentInputs(int32_t stage_id, const Shapes &inputs_shape, + const Shapes &splittable_inputs, std::vector *sp_vector); // generate strategies for that have two inputs, and input0 or input1 maybe broadcast, // and the corresponding dimensions that are not broadcast are all relevant dimensions // such as: ([a, b, c, d], [a, b, c, d]) or ([b, c, d], [a, b, c, d]) or ([1, c, d], [a, b, c, d]) // or ([a, b, c, d], [b, c, d]) or ([a, b, c, d], [1, c, d]) // or ([a, 1], [1, b]) or ([a, b, c, d], [1, b, c, d]) or ([a, b, c, 1], [1, b, c, d]) -Status GenerateStrategiesWithBroadcast(int32_t stage_id, const Shapes& inputs_shape, const Shapes& splittable_inputs, - std::vector* sp_vector); +Status GenerateStrategiesWithBroadcast(int32_t stage_id, const Shapes &inputs_shape, const Shapes &splittable_inputs, + std::vector *sp_vector); -Shapes GetRefKeyNodeShape(const AnfNodePtr& node, const FuncGraphPtr& func_graph); +Shapes GetRefKeyNodeShape(const AnfNodePtr &node, const FuncGraphPtr &func_graph); } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/parallel/ops_info/prelu_info.cc b/mindspore/ccsrc/parallel/ops_info/prelu_info.cc index a4d601dbe9..fed361616b 100644 --- a/mindspore/ccsrc/parallel/ops_info/prelu_info.cc +++ b/mindspore/ccsrc/parallel/ops_info/prelu_info.cc @@ -34,7 +34,7 @@ namespace parallel { * w: Float Tensor, w > 0: there is only two shapes are legitimate: 1, or the number of channels at input. * the strategy of w should equal to the channel dimension of strategy of A */ -Status PReLUInfo::CheckStrategy(const StrategyPtr& strategy) { +Status PReLUInfo::CheckStrategy(const StrategyPtr &strategy) { if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << ": Invalid strategy."; @@ -119,7 +119,7 @@ Dimensions PReLUInfo::GetOutputStrategy() { return output_strategy; } -Status PReLUInfo::InferTensorLayout(TensorLayouts* inputs_layout, TensorLayouts* outputs_layout) { +Status PReLUInfo::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout) { if (inputs_layout == nullptr || outputs_layout == nullptr) { MS_LOG(ERROR) << name_ << ": InferTensorLayout: the layout is null."; return FAILED; @@ -181,7 +181,7 @@ Status PReLUInfo::GetAttrs() { return SUCCESS; } -Status PReLUInfo::Init(const StrategyPtr& strategy) { +Status PReLUInfo::Init(const StrategyPtr &strategy) { if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { MS_LOG(ERROR) << name_ << ": Init failed."; return FAILED; @@ -190,7 +190,7 @@ Status PReLUInfo::Init(const StrategyPtr& strategy) { return SUCCESS; } -Status PReLUInfo::InitForCostModel(const StrategyPtr& strategy) { +Status PReLUInfo::InitForCostModel(const StrategyPtr &strategy) { if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << ": Init for cost model failed."; @@ -224,7 +224,7 @@ Status PReLUInfo::GenerateStrategies(int32_t stage_id) { return FAILED; } size_t success = 0; - for (auto& sp : sp_vector) { + for (auto &sp : sp_vector) { if (SetCostUnderStrategy(sp) == SUCCESS) { success++; MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy."; @@ -234,7 +234,7 @@ Status PReLUInfo::GenerateStrategies(int32_t stage_id) { return SUCCESS; } -Status PReLUInfo::SetCostUnderStrategy(const StrategyPtr& strategy) { +Status PReLUInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { if (SetCostUnderStrategyBase(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << ": Set cost under strategy failed."; diff --git a/mindspore/ccsrc/parallel/ops_info/prelu_info.h b/mindspore/ccsrc/parallel/ops_info/prelu_info.h index 396407c1ee..28e149fad7 100644 --- a/mindspore/ccsrc/parallel/ops_info/prelu_info.h +++ b/mindspore/ccsrc/parallel/ops_info/prelu_info.h @@ -33,24 +33,24 @@ namespace parallel { */ class PReLUInfo : public OperatorInfo { public: - PReLUInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + PReLUInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(true)) {} ~PReLUInfo() override = default; - Status Init(const StrategyPtr& strategy) override; - Status InitForCostModel(const StrategyPtr& strategy) override; + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; Status GenerateStrategies(int32_t stage_id) override; - Status SetCostUnderStrategy(const StrategyPtr& strategy) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; protected: - Status CheckStrategy(const StrategyPtr& strategy) override; + Status CheckStrategy(const StrategyPtr &strategy) override; Status InferMirrorOps() override; Status InferForwardCommunication() override; Status InferTensorInfo() override; Status InferDevMatrixShape() override; Status InferTensorMap() override; - Status InferTensorLayout(TensorLayouts* inputs_layout, TensorLayouts* outputs_layout); + Status InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout); Status GetAttrs() override; Dimensions GetOutputStrategy(); diff --git a/mindspore/ccsrc/parallel/ops_info/reshape_info.cc b/mindspore/ccsrc/parallel/ops_info/reshape_info.cc index 4cb81ee769..d6e1c277ef 100644 --- a/mindspore/ccsrc/parallel/ops_info/reshape_info.cc +++ b/mindspore/ccsrc/parallel/ops_info/reshape_info.cc @@ -27,7 +27,7 @@ namespace mindspore { namespace parallel { -Status ReshapeInfo::CheckStrategy(const StrategyPtr& strategy) { +Status ReshapeInfo::CheckStrategy(const StrategyPtr &strategy) { if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << ": Invalid strategy."; @@ -137,7 +137,7 @@ Status ReshapeInfo::GetParameterInput() { return FAILED; } - for (auto& element : elements) { + for (auto &element : elements) { MS_EXCEPTION_IF_NULL(element); if (element->isa()) { int32_t axis = element->cast()->value(); @@ -216,7 +216,7 @@ Strategys ReshapeInfo::GetOutputsStrategy() { return outputs_strategy; } -Status ReshapeInfo::InferTensorLayout(TensorLayouts* inputs_layout, TensorLayouts* outputs_layout) { +Status ReshapeInfo::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout) { if (inputs_layout == nullptr || outputs_layout == nullptr) { MS_LOG(ERROR) << name_ << ": InferTensorLayout: the layout is null."; return FAILED; @@ -302,7 +302,7 @@ void ReshapeInfo::InferTensorInfoByLayout() { */ Status ReshapeInfo::GetAttrs() { return GetParameterInput(); } -void ReshapeInfo::device_number(const StrategyPtr& strategy) { +void ReshapeInfo::device_number(const StrategyPtr &strategy) { int32_t stage = 0; if (strategy != nullptr) { stage = strategy->GetInputStage(); @@ -313,7 +313,7 @@ void ReshapeInfo::device_number(const StrategyPtr& strategy) { MS_ASSERT(dev_num_ > 0); } -Status ReshapeInfo::InferDefaultLayout(const Shape& shape, TensorLayout* const layout) { +Status ReshapeInfo::InferDefaultLayout(const Shape &shape, TensorLayout *const layout) { std::vector tensor_map_index; for (size_t i = 0; i < shape.size(); i++) { tensor_map_index.push_back(MAP_NONE); @@ -326,7 +326,7 @@ Status ReshapeInfo::InferDefaultLayout(const Shape& shape, TensorLayout* const l return Status::SUCCESS; } -Status ReshapeInfo::Init(const StrategyPtr& strategy) { +Status ReshapeInfo::Init(const StrategyPtr &strategy) { ResetQueueMember(); device_number(strategy); if (strategy) { @@ -375,7 +375,7 @@ Status ReshapeInfo::Init(const StrategyPtr& strategy) { return SUCCESS; } -Status ReshapeInfo::InitForCostModel(const StrategyPtr& strategy) { +Status ReshapeInfo::InitForCostModel(const StrategyPtr &strategy) { if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << ": Init for cost model failed."; @@ -389,7 +389,7 @@ Status ReshapeInfo::InitForCostModel(const StrategyPtr& strategy) { return SUCCESS; } -Status ReshapeInfo::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr& strategy) { +Status ReshapeInfo::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr &strategy) { if (SetCostUnderStrategyBase(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << ": Set cost under strategy failed."; @@ -423,7 +423,7 @@ Status ReshapeInfo::GenerateStrategies(int32_t stage_id) { return FAILED; } size_t success = 0; - for (auto& sp : sp_vector) { + for (auto &sp : sp_vector) { if (SetCostUnderStrategy(sp) == SUCCESS) { success++; MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy."; diff --git a/mindspore/ccsrc/parallel/ops_info/reshape_info.h b/mindspore/ccsrc/parallel/ops_info/reshape_info.h index 3864d2b93d..99ee014175 100644 --- a/mindspore/ccsrc/parallel/ops_info/reshape_info.h +++ b/mindspore/ccsrc/parallel/ops_info/reshape_info.h @@ -34,34 +34,34 @@ namespace parallel { */ class ReshapeInfo : public OperatorInfo { public: - ReshapeInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + ReshapeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)), dev_num_(0), input_layout_set_flag_(false), output_layout_set_flag_(false) {} ~ReshapeInfo() override = default; - Status Init(const StrategyPtr& strategy) override; - void SetInputLayout(const TensorLayout& input_layout) { + Status Init(const StrategyPtr &strategy) override; + void SetInputLayout(const TensorLayout &input_layout) { input_layout_ = input_layout; input_layout_set_flag_ = true; } - void SetOutputLayout(const TensorLayout& output_layout) { + void SetOutputLayout(const TensorLayout &output_layout) { output_layout_ = output_layout; output_layout_set_flag_ = true; } - Status InitForCostModel(const StrategyPtr& strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; Status GenerateStrategies(int32_t stage_id) override; - Status SetCostUnderStrategy(const StrategyPtr& strategy) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; protected: - Status CheckStrategy(const StrategyPtr& strategy) override; + Status CheckStrategy(const StrategyPtr &strategy) override; Status InferMirrorOps() override; Status InferForwardCommunication() override; Status InferTensorMap() override; Status InferTensorInfo() override; Status InferDevMatrixShape() override; - Status InferTensorLayout(TensorLayouts* inputs_layout, TensorLayouts* outputs_layout); + Status InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout); Status GetAttrs() override; Strategys GetOutputsStrategy(); @@ -69,8 +69,8 @@ class ReshapeInfo : public OperatorInfo { Status GetParameterInput(); Status ComputeReplaceOp(); void InferTensorInfoByLayout(); - void device_number(const StrategyPtr& strategy); - Status InferDefaultLayout(const Shape& shape, TensorLayout* const layout); + void device_number(const StrategyPtr &strategy); + Status InferDefaultLayout(const Shape &shape, TensorLayout *const layout); int32_t dev_num_; std::vector parameter_input_v_; diff --git a/mindspore/ccsrc/parallel/ops_info/tmp_identity_info.h b/mindspore/ccsrc/parallel/ops_info/tmp_identity_info.h index 3682fe334f..f7895d0511 100644 --- a/mindspore/ccsrc/parallel/ops_info/tmp_identity_info.h +++ b/mindspore/ccsrc/parallel/ops_info/tmp_identity_info.h @@ -32,19 +32,19 @@ class TmpIdentityInfo : public OperatorInfo { // consider this parameter tensor as TmpIdentityInfo operator. TmpIdentityInfo operator tasks as input a tensor, // and outputs the same tensor. After the transformation, subsequent operators can share the output tensor. public: - TmpIdentityInfo(const Shapes& inputs_shape, const Shapes& outputs_shape, const PrimitiveAttrs& attrs, - const std::string& name = IDENTITY_INFO) + TmpIdentityInfo(const Shapes &inputs_shape, const Shapes &outputs_shape, const PrimitiveAttrs &attrs, + const std::string &name = IDENTITY_INFO) : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} ~TmpIdentityInfo() override = default; - Status Init(const StrategyPtr& strategy) override; - Status InitForCostModel(const StrategyPtr& strategy) override; + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; Status GenerateStrategies(int32_t stage_id) override; - Status SetCostUnderStrategy(const StrategyPtr& strategy) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; protected: - Status CheckStrategy(const StrategyPtr& strategy) override; + Status CheckStrategy(const StrategyPtr &strategy) override; Status GetAttrs() override { return SUCCESS; } Status InferMirrorOps() override { return SUCCESS; } Status InferForwardCommunication() override { return SUCCESS; } diff --git a/mindspore/ccsrc/parallel/ops_info/transpose_info.cc b/mindspore/ccsrc/parallel/ops_info/transpose_info.cc index 84333a1337..49bbae0cb4 100644 --- a/mindspore/ccsrc/parallel/ops_info/transpose_info.cc +++ b/mindspore/ccsrc/parallel/ops_info/transpose_info.cc @@ -27,7 +27,7 @@ namespace mindspore { namespace parallel { -Status TransposeInfo::CheckStrategy(const StrategyPtr& strategy) { +Status TransposeInfo::CheckStrategy(const StrategyPtr &strategy) { if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << ": Invalid strategy."; @@ -43,7 +43,7 @@ Status TransposeInfo::CheckStrategy(const StrategyPtr& strategy) { Status TransposeInfo::InferDevMatrixShape() { std::vector stra = strategy_->GetInputDim(); input_strategy_ = stra.at(0); - for (auto& iter : input_strategy_) { + for (auto &iter : input_strategy_) { dev_matrix_shape_.push_back(iter); } return SUCCESS; @@ -77,7 +77,7 @@ Status TransposeInfo::ComputeAxis() { return FAILED; } axis_v_.clear(); - for (auto& element : elements) { + for (auto &element : elements) { MS_EXCEPTION_IF_NULL(element); if (element->isa()) { int32_t axis = element->cast()->value(); @@ -130,7 +130,7 @@ Strategys TransposeInfo::GetOutputsStrategy() { return outputs_strategy; } -Status TransposeInfo::InferTensorLayout(TensorLayouts* inputs_layout, TensorLayouts* outputs_layout) { +Status TransposeInfo::InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout) { if ((inputs_layout == nullptr) || (outputs_layout == nullptr)) { MS_LOG(ERROR) << name_ << ": InferTensorLayout: the layout is null."; return FAILED; @@ -179,7 +179,7 @@ Status TransposeInfo::InferTensorInfo() { // compute axis_v_ during this method Status TransposeInfo::GetAttrs() { return ComputeAxis(); } -Status TransposeInfo::Init(const StrategyPtr& strategy) { +Status TransposeInfo::Init(const StrategyPtr &strategy) { if (InitWithAutoRepeatCalc(strategy) != SUCCESS) { MS_LOG(ERROR) << name_ << ": Init failed."; return FAILED; @@ -188,7 +188,7 @@ Status TransposeInfo::Init(const StrategyPtr& strategy) { return SUCCESS; } -Status TransposeInfo::InitForCostModel(const StrategyPtr& strategy) { +Status TransposeInfo::InitForCostModel(const StrategyPtr &strategy) { if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << ": Init for cost model failed."; @@ -202,7 +202,7 @@ Status TransposeInfo::InitForCostModel(const StrategyPtr& strategy) { return SUCCESS; } -Status TransposeInfo::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr& strategy) { +Status TransposeInfo::SetCostUnderStrategy(const mindspore::parallel::StrategyPtr &strategy) { if (SetCostUnderStrategyBase(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(ERROR) << name_ << ": Set cost under strategy failed."; @@ -234,7 +234,7 @@ Status TransposeInfo::GenerateStrategies(int32_t stage_id) { return FAILED; } size_t success = 0; - for (auto& sp : sp_vector) { + for (auto &sp : sp_vector) { if (SetCostUnderStrategy(sp) == SUCCESS) { success++; MS_LOG(INFO) << name_ << ": Successfully generated " << success << "strategy."; diff --git a/mindspore/ccsrc/parallel/ops_info/transpose_info.h b/mindspore/ccsrc/parallel/ops_info/transpose_info.h index e4e2b90b7b..50b76bde65 100644 --- a/mindspore/ccsrc/parallel/ops_info/transpose_info.h +++ b/mindspore/ccsrc/parallel/ops_info/transpose_info.h @@ -33,23 +33,23 @@ namespace parallel { */ class TransposeInfo : public OperatorInfo { public: - TransposeInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + TransposeInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} ~TransposeInfo() override = default; - Status Init(const StrategyPtr& strategy) override; - Status InitForCostModel(const StrategyPtr& strategy) override; + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; Status GenerateStrategies(int32_t stage_id) override; - Status SetCostUnderStrategy(const StrategyPtr& strategy) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; protected: - Status CheckStrategy(const StrategyPtr& strategy) override; + Status CheckStrategy(const StrategyPtr &strategy) override; Status InferMirrorOps() override; Status InferForwardCommunication() override; Status InferTensorInfo() override; Status InferDevMatrixShape() override; Status InferTensorMap() override; - Status InferTensorLayout(TensorLayouts* inputs_layout, TensorLayouts* outputs_layout); + Status InferTensorLayout(TensorLayouts *inputs_layout, TensorLayouts *outputs_layout); Status GetAttrs() override; Strategys GetOutputsStrategy(); diff --git a/mindspore/ccsrc/parallel/ops_info/virtual_dataset_info.cc b/mindspore/ccsrc/parallel/ops_info/virtual_dataset_info.cc index cd3b40315c..4b695ba62d 100644 --- a/mindspore/ccsrc/parallel/ops_info/virtual_dataset_info.cc +++ b/mindspore/ccsrc/parallel/ops_info/virtual_dataset_info.cc @@ -27,7 +27,7 @@ namespace mindspore { namespace parallel { -Status VirtualDatasetInfo::CheckStrategy(const StrategyPtr& strategy) { +Status VirtualDatasetInfo::CheckStrategy(const StrategyPtr &strategy) { if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << ": Invalid strategy."; @@ -171,7 +171,7 @@ Status VirtualDatasetInfo::InferTensorInfo() { Status VirtualDatasetInfo::GetAttrs() { return SUCCESS; } -Status VirtualDatasetInfo::Init(const StrategyPtr& strategy) { +Status VirtualDatasetInfo::Init(const StrategyPtr &strategy) { if (InitWithManualRepeatCalc(strategy) != SUCCESS) { MS_LOG(ERROR) << name_ << ": Init failed."; return FAILED; @@ -179,7 +179,7 @@ Status VirtualDatasetInfo::Init(const StrategyPtr& strategy) { return SUCCESS; } -Status VirtualDatasetInfo::InitForCostModel(const StrategyPtr& strategy) { +Status VirtualDatasetInfo::InitForCostModel(const StrategyPtr &strategy) { if (InitForCostModelWithManualRepeatCalc(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << ": Init for cost model failed."; @@ -199,7 +199,7 @@ void VirtualDatasetInfo::ReComputeBatchSplitFlagList() { } } -Status VirtualDatasetInfo::SetCostUnderStrategy(const StrategyPtr& strategy) { +Status VirtualDatasetInfo::SetCostUnderStrategy(const StrategyPtr &strategy) { if (SetCostUnderStrategyBase(strategy) != SUCCESS) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << ": Set cost under strategy failed."; @@ -223,7 +223,7 @@ Status VirtualDatasetInfo::GenerateStrategies(int32_t stage_id) { size_t total_dev_num = g_device_manager->GetDeviceListByStageId(stage_id).size(); StrategyPtr sp; std::vector strategy; - for (auto& shape : inputs_shape_) { + for (auto &shape : inputs_shape_) { Shape temp; temp.emplace_back(SizeToInt(total_dev_num)); (void)temp.insert(temp.end(), shape.size() - 1, 1); diff --git a/mindspore/ccsrc/parallel/ops_info/virtual_dataset_info.h b/mindspore/ccsrc/parallel/ops_info/virtual_dataset_info.h index 398bae3585..312ac7a6a4 100644 --- a/mindspore/ccsrc/parallel/ops_info/virtual_dataset_info.h +++ b/mindspore/ccsrc/parallel/ops_info/virtual_dataset_info.h @@ -30,19 +30,19 @@ namespace mindspore { namespace parallel { class VirtualDatasetInfo : public OperatorInfo { public: - VirtualDatasetInfo(const std::string& name, const Shapes& inputs_shape, const Shapes& outputs_shape, - const PrimitiveAttrs& attrs) + VirtualDatasetInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, + const PrimitiveAttrs &attrs) : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared(false)) {} ~VirtualDatasetInfo() override = default; - Status Init(const StrategyPtr& strategy) override; - Status InitForCostModel(const StrategyPtr& strategy) override; + Status Init(const StrategyPtr &strategy) override; + Status InitForCostModel(const StrategyPtr &strategy) override; Status GenerateStrategies(int32_t stage_id) override; - Status SetCostUnderStrategy(const StrategyPtr& strategy) override; + Status SetCostUnderStrategy(const StrategyPtr &strategy) override; void ReComputeBatchSplitFlagList() override; protected: - Status CheckStrategy(const StrategyPtr& strategy) override; + Status CheckStrategy(const StrategyPtr &strategy) override; Status InferMirrorOps() override; Status InferForwardCommunication() override; Status InferTensorInfo() override; diff --git a/mindspore/ccsrc/parallel/step_parallel.cc b/mindspore/ccsrc/parallel/step_parallel.cc index bcd4dc3763..d1390db899 100644 --- a/mindspore/ccsrc/parallel/step_parallel.cc +++ b/mindspore/ccsrc/parallel/step_parallel.cc @@ -76,7 +76,7 @@ void SetCommunicationOpGroupLabel(std::vector new_node_input) { } } -std::vector CreateInput(const Operator& op, const AnfNodePtr& node, const std::string& instance_name) { +std::vector CreateInput(const Operator &op, const AnfNodePtr &node, const std::string &instance_name) { MS_EXCEPTION_IF_NULL(node); OperatorArgs arg_forward = op.second; ValuePtr pyop_instance = CreatOpInstance(arg_forward.first, op.first, instance_name); @@ -85,7 +85,7 @@ std::vector CreateInput(const Operator& op, const AnfNodePtr& node, std::vector new_node_input = {NewValueNode(pyop_instance), node}; if (!params.empty()) { - for (auto& param : params) { + for (auto ¶m : params) { AnfNodePtr val = NewValueNode(param.first.second); MS_EXCEPTION_IF_NULL(val); int32_t position = param.second; @@ -98,8 +98,8 @@ std::vector CreateInput(const Operator& op, const AnfNodePtr& node, return new_node_input; } -void InsertNode(const Operator& op, const CNodePtr& node, size_t index, const AnfNodePtr& pre_node, - const FuncGraphPtr& func_graph, const std::string& instance_name) { +void InsertNode(const Operator &op, const CNodePtr &node, size_t index, const AnfNodePtr &pre_node, + const FuncGraphPtr &func_graph, const std::string &instance_name) { // insert new node before the node FuncGraphManagerPtr manager = func_graph->manager(); MS_EXCEPTION_IF_NULL(manager); @@ -121,7 +121,7 @@ void InsertNode(const Operator& op, const CNodePtr& node, size_t index, const An manager->SetEdge(node, SizeToInt(index), new_node); } -std::string CreateInstanceName(const CNodePtr& node, size_t index) { +std::string CreateInstanceName(const CNodePtr &node, size_t index) { MS_EXCEPTION_IF_NULL(node); if (!IsValueNode(node->input(0))) { MS_LOG(EXCEPTION) << "CreateInstanceName: " << node->ToString() << " doesn't have primitive"; @@ -132,7 +132,7 @@ std::string CreateInstanceName(const CNodePtr& node, size_t index) { return instance_name; } -void ForwardCommunication(OperatorVector forward_op, const CNodePtr& node) { +void ForwardCommunication(OperatorVector forward_op, const CNodePtr &node) { MS_EXCEPTION_IF_NULL(node); // step1:get graph manager distribute_operator FuncGraphPtr func_graph = node->func_graph(); @@ -141,7 +141,7 @@ void ForwardCommunication(OperatorVector forward_op, const CNodePtr& node) { MS_EXCEPTION_IF_NULL(manager); auto uses_set = manager->node_users()[node]; CNodePtr node_to_insert = node; - for (auto& uses_pair : uses_set) { + for (auto &uses_pair : uses_set) { auto uses_cnode = uses_pair.first->cast(); MS_EXCEPTION_IF_NULL(uses_cnode); if (!IsValueNode(uses_cnode->input(0))) { @@ -175,7 +175,7 @@ void ForwardCommunication(OperatorVector forward_op, const CNodePtr& node) { } } -CNodePtr InsertMakeTuple(const AnfNodePtr& prev, uint32_t num, const FuncGraphPtr& func_graph) { +CNodePtr InsertMakeTuple(const AnfNodePtr &prev, uint32_t num, const FuncGraphPtr &func_graph) { MS_EXCEPTION_IF_NULL(prev); MS_EXCEPTION_IF_NULL(func_graph); std::vector make_tuple_inputs; @@ -195,8 +195,8 @@ CNodePtr InsertMakeTuple(const AnfNodePtr& prev, uint32_t num, const FuncGraphPt return make_tuple; } -void InsertRedistribution(const RedistributionOpListPtr& redistribution_oplist_ptr, const CNodePtr& node, - const FuncGraphPtr& func_graph, int pos, const CNodePtr& pre_node) { +void InsertRedistribution(const RedistributionOpListPtr &redistribution_oplist_ptr, const CNodePtr &node, + const FuncGraphPtr &func_graph, int pos, const CNodePtr &pre_node) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(pre_node); MS_EXCEPTION_IF_NULL(func_graph); @@ -226,8 +226,8 @@ void InsertRedistribution(const RedistributionOpListPtr& redistribution_oplist_p } } -void InsertGetTensorSliceOp(const Operator& op, const CNodePtr& node, const FuncGraphPtr& func_graph, int pos, - const std::string& instance_name) { +void InsertGetTensorSliceOp(const Operator &op, const CNodePtr &node, const FuncGraphPtr &func_graph, int pos, + const std::string &instance_name) { if (func_graph == nullptr) { MS_LOG(EXCEPTION) << "InsertGetTensorSliceOp: the graph is null, the instance name is " << instance_name; } @@ -244,8 +244,8 @@ void InsertGetTensorSliceOp(const Operator& op, const CNodePtr& node, const Func InsertNode(op, node, IntToSize(pos), pre_node, func_graph, instance_name); } -TensorLayout GetTensorInLayout(const CNodePtr& middle_node, const PrimitivePtr& middle_prim, - const OperatorInfoPtr& distribute_operator) { +TensorLayout GetTensorInLayout(const CNodePtr &middle_node, const PrimitivePtr &middle_prim, + const OperatorInfoPtr &distribute_operator) { TensorInfo tensorinfo_in; if (middle_prim->name() == TUPLE_GETITEM) { auto value_node = middle_node->input(2)->cast(); @@ -265,7 +265,7 @@ TensorLayout GetTensorInLayout(const CNodePtr& middle_node, const PrimitivePtr& return tensorinfo_in.tensor_layout(); } -OperatorInfoPtr GetDistributeOperator(const CNodePtr& node) { +OperatorInfoPtr GetDistributeOperator(const CNodePtr &node) { MS_EXCEPTION_IF_NULL(node); if (!IsParallelCareNode(node)) { return nullptr; @@ -277,9 +277,9 @@ OperatorInfoPtr GetDistributeOperator(const CNodePtr& node) { return distribute_operator; } -void Redistribution(const std::pair& node_pair, const OperatorInfoPtr& distribute_operator, - const CNodePtr& middle_node, int index, TensorRedistribution tensor_redistribution, - const CNodePtr& pre_node) { +void Redistribution(const std::pair &node_pair, const OperatorInfoPtr &distribute_operator, + const CNodePtr &middle_node, int index, TensorRedistribution tensor_redistribution, + const CNodePtr &pre_node) { FuncGraphPtr func_graph = middle_node->func_graph(); if (func_graph == nullptr) { MS_LOG(EXCEPTION) << "Redistribution:get graph failed"; @@ -333,13 +333,13 @@ bool StrategyFound(std::unordered_map attrs) { return !((iter == attrs.end()) || (iter->second->type_name() == NONE)); } -bool IsCommunicationOp(const PrimitivePtr& prim) { +bool IsCommunicationOp(const PrimitivePtr &prim) { MS_EXCEPTION_IF_NULL(prim); return (COMMUNICATION_OPS.find(prim->name()) != COMMUNICATION_OPS.end()); } -bool FindCommunicationOp(const std::vector& all_nodes) { - for (auto& node : all_nodes) { +bool FindCommunicationOp(const std::vector &all_nodes) { + for (auto &node : all_nodes) { MS_EXCEPTION_IF_NULL(node); if (!node->isa()) { continue; @@ -364,7 +364,7 @@ bool FindCommunicationOp(const std::vector& all_nodes) { return false; } -bool IsParallelCareNode(const CNodePtr& cnode) { +bool IsParallelCareNode(const CNodePtr &cnode) { MS_EXCEPTION_IF_NULL(cnode); ValueNodePtr prim_node = cnode->input(0)->cast(); if (prim_node == nullptr) { @@ -389,8 +389,8 @@ bool IsParallelCareNode(const CNodePtr& cnode) { return cnode->in_forward_flag(); } -void StepRedistribution(const CNodePtr& node, const OperatorInfoPtr& distribute_operator, const CNodePtr& insert_node, - const TensorRedistribution& tensor_redistribution, const CNodePtr& pre_node) { +void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_operator, const CNodePtr &insert_node, + const TensorRedistribution &tensor_redistribution, const CNodePtr &pre_node) { MS_EXCEPTION_IF_NULL(node->func_graph()); FuncGraphManagerPtr manager = node->func_graph()->manager(); MS_EXCEPTION_IF_NULL(manager); @@ -406,7 +406,7 @@ void StepRedistribution(const CNodePtr& node, const OperatorInfoPtr& distribute_ insert_node_new = insert_node; } MS_EXCEPTION_IF_NULL(insert_node_new); - for (auto& node_pair : node_set) { + for (auto &node_pair : node_set) { CNodePtr use_cnode = node_pair.first->cast(); MS_EXCEPTION_IF_NULL(use_cnode); if (!IsValueNode(use_cnode->input(0))) { @@ -429,7 +429,7 @@ void StepRedistribution(const CNodePtr& node, const OperatorInfoPtr& distribute_ } } -void SplitTensor(const AnfNodePtr& node, const CNodePtr& next_node, int index) { +void SplitTensor(const AnfNodePtr &node, const CNodePtr &next_node, int index) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(next_node); OperatorInfoPtr op_info = next_node->operator_info(); @@ -474,11 +474,11 @@ void SplitTensor(const AnfNodePtr& node, const CNodePtr& next_node, int index) { } } -void StepSplitTensor(const AnfNodePtr& node, const FuncGraphManagerPtr& manager) { +void StepSplitTensor(const AnfNodePtr &node, const FuncGraphManagerPtr &manager) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(manager); AnfNodeIndexSet node_set = manager->node_users()[node]; - for (auto& node_pair : node_set) { + for (auto &node_pair : node_set) { CNodePtr use_cnode = node_pair.first->cast(); if (use_cnode == nullptr || !IsValueNode(use_cnode->input(0))) { continue; @@ -496,8 +496,8 @@ void StepSplitTensor(const AnfNodePtr& node, const FuncGraphManagerPtr& manager) } } -std::vector ReplaceOpInput(const Operator& replace_op, const std::string& instance_name, - const CNodePtr& node) { +std::vector ReplaceOpInput(const Operator &replace_op, const std::string &instance_name, + const CNodePtr &node) { OperatorArgs arg_replace_op = replace_op.second; ValuePtr pyop_instance = CreatOpInstance(arg_replace_op.first, replace_op.first, instance_name); if (pyop_instance == nullptr) { @@ -518,7 +518,7 @@ std::vector ReplaceOpInput(const Operator& replace_op, const std::st if (first_position == 1) { replace_input.pop_back(); } - for (auto& param : params) { + for (auto ¶m : params) { AnfNodePtr val = NewValueNode(param.first.second); if (val == nullptr) { MS_LOG(EXCEPTION) << "Failure:val is nullptr"; @@ -531,7 +531,7 @@ std::vector ReplaceOpInput(const Operator& replace_op, const std::st return replace_input; } -void ReplaceOneOp(const Operator& replace_op, const CNodePtr& node) { +void ReplaceOneOp(const Operator &replace_op, const CNodePtr &node) { FuncGraphPtr func_graph = node->func_graph(); MS_EXCEPTION_IF_NULL(func_graph); FuncGraphManagerPtr manager = func_graph->manager(); @@ -551,7 +551,7 @@ void ReplaceOneOp(const Operator& replace_op, const CNodePtr& node) { (void)manager->Replace(node, replace_node); } -void StepReplaceOp(OperatorVector replace_op, const CNodePtr& node) { +void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node) { // step1:get graph manager distribute_operator OperatorInfoPtr distribute_operator = node->operator_info(); if (distribute_operator == nullptr) { @@ -599,15 +599,15 @@ void StepReplaceOp(OperatorVector replace_op, const CNodePtr& node) { MS_LOG(INFO) << "Insert ReplaceOp success for " << distribute_operator->name(); } -bool IsSomePrimitive(const CNodePtr& cnode, const std::string& name) { +bool IsSomePrimitive(const CNodePtr &cnode, const std::string &name) { ValueNodePtr anf_node = cnode->input(0)->cast(); MS_EXCEPTION_IF_NULL(anf_node); PrimitivePtr prim = anf_node->value()->cast(); return (prim->name() == name); } -void StepReplaceGraph(const std::shared_ptr, AnfNodePtr>>& replace_graph, - const CNodePtr& node) { +void StepReplaceGraph(const std::shared_ptr, AnfNodePtr>> &replace_graph, + const CNodePtr &node) { MS_EXCEPTION_IF_NULL(replace_graph); MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(replace_graph->second); @@ -627,7 +627,7 @@ void StepReplaceGraph(const std::shared_ptr, A if (replace_graph->first.size() != 2) { MS_LOG(EXCEPTION) << "Failure:replace_graph->first.size() must be 2 for OneHot Primitive!"; } - for (auto& replace_input : replace_graph->first) { + for (auto &replace_input : replace_graph->first) { MS_EXCEPTION_IF_NULL(replace_input); manager->SetEdge(replace_input, 1, pre_node); CNodePtr replace_input_cnode = replace_input->cast(); @@ -645,7 +645,7 @@ void StepReplaceGraph(const std::shared_ptr, A replace_output_cnode->set_in_forward_flag(true); // mark this new cnode is forward node } -int32_t GetTupleGetItemIndex(const CNodePtr& cnode) { +int32_t GetTupleGetItemIndex(const CNodePtr &cnode) { MS_EXCEPTION_IF_NULL(cnode); if (cnode->inputs().size() != 3) { MS_LOG(EXCEPTION) << cnode->ToString() << " size( " << cnode->inputs().size() << " ) is not 3"; @@ -666,7 +666,7 @@ int32_t GetTupleGetItemIndex(const CNodePtr& cnode) { // Judge whether the node is a loss, and if there are multiple outputs, // get which output is a grad according to the tuple getitem. // Currently, it is not supported that the sens is a tuple. -LossNodeInfo GetLossNodeInfo(const AnfNodePtr& loss_node) { +LossNodeInfo GetLossNodeInfo(const AnfNodePtr &loss_node) { MS_EXCEPTION_IF_NULL(loss_node); FuncGraphPtr sub_graph = loss_node->func_graph(); MS_EXCEPTION_IF_NULL(sub_graph); @@ -718,7 +718,7 @@ LossNodeInfo GetLossNodeInfo(const AnfNodePtr& loss_node) { MS_LOG(EXCEPTION) << "Invalid loss"; } -void InsertVirtualDivOp(const VirtualDivOp& virtual_div_op, const CNodePtr& node) { +void InsertVirtualDivOp(const VirtualDivOp &virtual_div_op, const CNodePtr &node) { MS_EXCEPTION_IF_NULL(node); size_t node_size = node->inputs().size(); FuncGraphPtr func_graph = node->func_graph(); @@ -742,7 +742,7 @@ void InsertVirtualDivOp(const VirtualDivOp& virtual_div_op, const CNodePtr& node } } -std::pair FindParameter(const AnfNodePtr& node, const FuncGraphPtr& func_graph) { +std::pair FindParameter(const AnfNodePtr &node, const FuncGraphPtr &func_graph) { if (!node->isa() && !node->isa() && !node->isa()) { return std::make_pair(nullptr, false); } else if (node->isa()) { @@ -790,7 +790,7 @@ std::pair FindParameter(const AnfNodePtr& node, const FuncGrap return std::make_pair(nullptr, false); } -std::pair FindCNode(const AnfNodePtr& anode, const std::string& name, const FuncGraphPtr& func_graph) { +std::pair FindCNode(const AnfNodePtr &anode, const std::string &name, const FuncGraphPtr &func_graph) { MS_EXCEPTION_IF_NULL(anode); MS_EXCEPTION_IF_NULL(anode->func_graph()); FuncGraphManagerPtr manager = anode->func_graph()->manager(); @@ -798,7 +798,7 @@ std::pair FindCNode(const AnfNodePtr& anode, const std::string& AnfNodeIndexSet node_set = manager->node_users()[anode]; bool result = false; CNodePtr cnode_return = nullptr; - for (auto& node_pair : node_set) { + for (auto &node_pair : node_set) { CNodePtr use_apply = node_pair.first->cast(); if (use_apply == nullptr || !IsValueNode(use_apply->input(0))) { continue; @@ -820,7 +820,7 @@ std::pair FindCNode(const AnfNodePtr& anode, const std::string& return std::make_pair(result, cnode_return); } -bool IsCastBeforMirror(const CNodePtr& node, size_t index) { +bool IsCastBeforMirror(const CNodePtr &node, size_t index) { // only if cast_before_mirror is true, pre node is cast and type is not float32 return true if (!ParallelContext::GetInstance()->cast_before_mirror()) { return false; @@ -850,7 +850,7 @@ bool IsCastBeforMirror(const CNodePtr& node, size_t index) { return (type_id != kNumberTypeFloat32); } -void InsertMirrorOps(const MirrorOps& mirror_ops, const CNodePtr& node) { +void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) { MS_EXCEPTION_IF_NULL(node); size_t node_size = node->inputs().size(); FuncGraphPtr func_graph = node->func_graph(); @@ -887,7 +887,7 @@ void InsertMirrorOps(const MirrorOps& mirror_ops, const CNodePtr& node) { } std::string instance_name = MIRROR_OP; if (IsCastBeforMirror(node, index)) { - for (auto& op : backward_op) { + for (auto &op : backward_op) { // insert new node before the node CNodePtr cnode = node->input(index)->cast(); MS_EXCEPTION_IF_NULL(cnode); @@ -895,7 +895,7 @@ void InsertMirrorOps(const MirrorOps& mirror_ops, const CNodePtr& node) { InsertNode(op, cnode, size_t(1), pre_node, func_graph, instance_name); } } else { - for (auto& op : backward_op) { + for (auto &op : backward_op) { AnfNodePtr pre_node = node->input(index); InsertNode(op, node, index, pre_node, func_graph, instance_name); } @@ -903,7 +903,7 @@ void InsertMirrorOps(const MirrorOps& mirror_ops, const CNodePtr& node) { } } -void BackwardCommunication(const OperatorInfoPtr& distribute_operator, const CNodePtr& node, bool is_loss_node) { +void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNodePtr &node, bool is_loss_node) { MS_EXCEPTION_IF_NULL(distribute_operator); MS_EXCEPTION_IF_NULL(node); MirrorOps mirror_ops = distribute_operator->mirror_ops(); @@ -920,7 +920,7 @@ void BackwardCommunication(const OperatorInfoPtr& distribute_operator, const CNo } } -std::string GetDisOpName(const std::string& prim_name) { +std::string GetDisOpName(const std::string &prim_name) { std::string op_name = prim_name; if (!prim_name.empty() && (prim_name[0] == '_')) { op_name = prim_name.substr(1); @@ -928,8 +928,8 @@ std::string GetDisOpName(const std::string& prim_name) { return op_name + "Info"; } -OperatorInfoPtr OperatorInstanceByName(const std::string& name, const PrimitiveAttrs& attrs, - const std::vector& shape_list) { +OperatorInfoPtr OperatorInstanceByName(const std::string &name, const PrimitiveAttrs &attrs, + const std::vector &shape_list) { if (shape_list.size() != 2) { MS_LOG(ERROR) << "The size of shape list is not 2"; return nullptr; @@ -951,8 +951,8 @@ OperatorInfoPtr OperatorInstanceByName(const std::string& name, const PrimitiveA return operator_; } -OperatorInfoPtr OperatorInstance(const PrimitivePtr& prim, const PrimitiveAttrs& attrs, - const std::vector& shape_list) { +OperatorInfoPtr OperatorInstance(const PrimitivePtr &prim, const PrimitiveAttrs &attrs, + const std::vector &shape_list) { MS_EXCEPTION_IF_NULL(prim); OperatorInfoPtr operator_ = OperatorInstanceByName(prim->name(), attrs, shape_list); if (operator_ == nullptr) { @@ -963,7 +963,7 @@ OperatorInfoPtr OperatorInstance(const PrimitivePtr& prim, const PrimitiveAttrs& return operator_; } -OperatorInfoPtr NewOperatorInstance(const PrimitivePtr& prim, const PrimitiveAttrs& attrs, +OperatorInfoPtr NewOperatorInstance(const PrimitivePtr &prim, const PrimitiveAttrs &attrs, std::vector shape_list) { OperatorInfoPtr operator_ = OperatorInstance(prim, attrs, shape_list); for (size_t i = 0; i < shape_list[0].size(); ++i) { @@ -992,7 +992,7 @@ StrategyPtr ExtractStrategy(std::unordered_map attrs) { std::vector value_vector = value_tuple->value(); (void)std::transform(value_vector.begin(), value_vector.end(), std::back_inserter(dim), - [](const ValuePtr& value) { return static_cast(GetValue(value)); }); + [](const ValuePtr &value) { return static_cast(GetValue(value)); }); strategy.push_back(dim); } else { MS_LOG(EXCEPTION) << "Failure:Strategy's format is wrong! Need ValueSequeue"; @@ -1007,7 +1007,7 @@ StrategyPtr ExtractStrategy(std::unordered_map attrs) { return strategyPtr; } -Shapes GetNodeShape(const AnfNodePtr& node) { +Shapes GetNodeShape(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); Shapes shapes; BaseShapePtr base_shape_ptr = node->Shape(); @@ -1039,7 +1039,7 @@ Shapes GetNodeShape(const AnfNodePtr& node) { auto tuple_shape_ptr = dyn_cast(base_shape_ptr); if (tuple_shape_ptr != nullptr) { auto tuple_shape = tuple_shape_ptr->shape(); - for (auto& shape : tuple_shape) { + for (auto &shape : tuple_shape) { auto each_shape = dyn_cast(shape); MS_EXCEPTION_IF_NULL(each_shape); shapes.push_back(each_shape->shape()); @@ -1052,7 +1052,7 @@ Shapes GetNodeShape(const AnfNodePtr& node) { return shapes; } -std::vector FindParameterByRefKeyNode(const AnfNodePtr& node, const FuncGraphPtr& func_graph) { +std::vector FindParameterByRefKeyNode(const AnfNodePtr &node, const FuncGraphPtr &func_graph) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(func_graph); std::vector parameters; @@ -1075,7 +1075,7 @@ std::vector FindParameterByRefKeyNode(const AnfNodePtr& node, const FuncGraphPtr root_g = roots.back(); MS_EXCEPTION_IF_NULL(root_g); - for (auto& param_node : root_g->parameters()) { + for (auto ¶m_node : root_g->parameters()) { auto param = param_node->cast(); if (param && (name == param->name())) { parameters.push_back(param_node); @@ -1088,7 +1088,7 @@ std::vector FindParameterByRefKeyNode(const AnfNodePtr& node, const return parameters; } -Shapes GetRefKeyNodeShape(const AnfNodePtr& node, const FuncGraphPtr& func_graph) { +Shapes GetRefKeyNodeShape(const AnfNodePtr &node, const FuncGraphPtr &func_graph) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(func_graph); @@ -1107,7 +1107,7 @@ Shapes GetRefKeyNodeShape(const AnfNodePtr& node, const FuncGraphPtr& func_graph return input_shapes; } -std::vector ExtractShape(const CNodePtr& node) { +std::vector ExtractShape(const CNodePtr &node) { MS_EXCEPTION_IF_NULL(node); Shapes shape_inputs, shape_outputs; std::vector shape_all; @@ -1145,14 +1145,14 @@ std::vector ExtractShape(const CNodePtr& node) { return shape_all; } -std::pair FindParallelCareNode(const AnfNodePtr& node) { +std::pair FindParallelCareNode(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); FuncGraphPtr func_graph = node->func_graph(); MS_EXCEPTION_IF_NULL(func_graph); FuncGraphManagerPtr manager = func_graph->manager(); MS_EXCEPTION_IF_NULL(manager); AnfNodeIndexSet node_set = manager->node_users()[node]; - for (auto& node_pair : node_set) { + for (auto &node_pair : node_set) { CNodePtr cnode = node_pair.first->cast(); MS_EXCEPTION_IF_NULL(cnode); if (!IsValueNode(cnode->input(0))) { @@ -1174,7 +1174,7 @@ std::pair FindParallelCareNode(const AnfNodePtr& node) { return std::make_pair(nullptr, 0); } -std::pair FindSubGraph(const FuncGraphPtr& graph, const AnfNodePtr& parameter) { +std::pair FindSubGraph(const FuncGraphPtr &graph, const AnfNodePtr ¶meter) { MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(parameter); FuncGraphManagerPtr manager = graph->manager(); @@ -1184,7 +1184,7 @@ std::pair FindSubGraph(const FuncGraphPtr& graph, const AnfNode return prim_anf_node_pair; } else { AnfNodeIndexSet param_sub_set = manager->node_users()[parameter]; - for (auto& param_pair : param_sub_set) { + for (auto ¶m_pair : param_sub_set) { CNodePtr graph_cnode = param_pair.first->cast(); if ((graph_cnode == nullptr) || !graph_cnode->input(0)->isa()) { continue; @@ -1208,7 +1208,7 @@ std::pair FindSubGraph(const FuncGraphPtr& graph, const AnfNode return std::make_pair(nullptr, 0); } -void SetParallelShape(const AnfNodePtr& parameter, const std::pair& res) { +void SetParallelShape(const AnfNodePtr ¶meter, const std::pair &res) { MS_EXCEPTION_IF_NULL(parameter); AbstractBasePtr abstract = parameter->abstract(); MS_EXCEPTION_IF_NULL(abstract); @@ -1237,10 +1237,10 @@ void SetParallelShape(const AnfNodePtr& parameter, const std::pairset_tensor_layout(std::make_shared(tensor_layout)); } -void CoverSliceShape(const FuncGraphPtr& root) { +void CoverSliceShape(const FuncGraphPtr &root) { MS_EXCEPTION_IF_NULL(root); auto parameters = root->parameters(); - for (auto& parameter : parameters) { + for (auto ¶meter : parameters) { MS_EXCEPTION_IF_NULL(parameter->Shape()); auto iter = g_RefMap.find(parameter); if (iter != g_RefMap.end()) { @@ -1258,7 +1258,7 @@ void CoverSliceShape(const FuncGraphPtr& root) { g_RefMap.clear(); } -bool ParameterIsCloned(const FuncGraphPtr& root, const AnfNodePtr& parameter_node) { +bool ParameterIsCloned(const FuncGraphPtr &root, const AnfNodePtr ¶meter_node) { MS_EXCEPTION_IF_NULL(root); MS_EXCEPTION_IF_NULL(parameter_node); FuncGraphManagerPtr manager = root->manager(); @@ -1281,9 +1281,9 @@ bool ParameterIsCloned(const FuncGraphPtr& root, const AnfNodePtr& parameter_nod return true; } -void SetClonedTensorShapeForOptimizer(const FuncGraphPtr& root) { +void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) { MS_EXCEPTION_IF_NULL(root); - for (auto& cloned_parameter_node : root->parameters()) { + for (auto &cloned_parameter_node : root->parameters()) { MS_EXCEPTION_IF_NULL(cloned_parameter_node); auto cloned_parameter = cloned_parameter_node->cast(); MS_EXCEPTION_IF_NULL(cloned_parameter); @@ -1300,7 +1300,7 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr& root) { bool found_be_cloned_parameter = false; ParameterPtr cloned_from_parameter = nullptr; AnfNodePtr cloned_from_node = nullptr; - for (auto& be_cloned_parameter_node : root->parameters()) { + for (auto &be_cloned_parameter_node : root->parameters()) { MS_EXCEPTION_IF_NULL(be_cloned_parameter_node); auto be_cloned_parameter = be_cloned_parameter_node->cast(); MS_EXCEPTION_IF_NULL(be_cloned_parameter); @@ -1315,7 +1315,7 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr& root) { // get the be cloned index py::list be_cloned_index = parse::python_adapter::GetPyObjAttr(be_cloned_info, BE_CLONED_INDEX); - for (auto& index : be_cloned_index) { + for (auto &index : be_cloned_index) { if (cloned_index == py::cast(index)) { found_be_cloned_parameter = true; cloned_from_parameter = be_cloned_parameter; @@ -1341,7 +1341,7 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr& root) { } } -void SetVirtualDatasetStrategy(const CNodePtr& node) { +void SetVirtualDatasetStrategy(const CNodePtr &node) { MS_EXCEPTION_IF_NULL(node); PrimitivePtr prim = GetValueNode(node->input(0)); MS_EXCEPTION_IF_NULL(prim); @@ -1370,8 +1370,8 @@ void SetVirtualDatasetStrategy(const CNodePtr& node) { } } -void ExtractInformation(const std::vector& all_nodes) { - for (auto& node : all_nodes) { +void ExtractInformation(const std::vector &all_nodes) { + for (auto &node : all_nodes) { auto cnode = node->cast(); if ((cnode == nullptr) || !IsValueNode(cnode->input(0))) { continue; @@ -1390,7 +1390,7 @@ void ExtractInformation(const std::vector& all_nodes) { if (operator_ == nullptr) { MS_LOG(EXCEPTION) << "Failure:Primitive " << prim->name() << " OperatorInstance failed"; } - auto& inputs = cnode->inputs(); + auto &inputs = cnode->inputs(); std::vector input_value; for (size_t index = 1; index < inputs.size(); ++index) { if (inputs[index]->isa()) { @@ -1440,7 +1440,7 @@ void ExtractInformation(const std::vector& all_nodes) { } } -TensorLayout GetInputLayoutFromCNode(const std::pair& node_pair) { +TensorLayout GetInputLayoutFromCNode(const std::pair &node_pair) { CNodePtr cnode = node_pair.first->cast(); MS_EXCEPTION_IF_NULL(cnode); OperatorInfoPtr distribute_operator = GetDistributeOperator(cnode); @@ -1456,13 +1456,13 @@ TensorLayout GetInputLayoutFromCNode(const std::pair& node_pair } // if reshape's output connect to several primitive, return the first layout found -std::shared_ptr FindNextLayout(const CNodePtr& cnode) { +std::shared_ptr FindNextLayout(const CNodePtr &cnode) { MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode->func_graph()); FuncGraphManagerPtr manager = cnode->func_graph()->manager(); MS_EXCEPTION_IF_NULL(manager); AnfNodeIndexSet node_set = manager->node_users()[cnode]; - for (auto& node_pair : node_set) { + for (auto &node_pair : node_set) { CNodePtr use_apply = node_pair.first->cast(); if (use_apply == nullptr || !IsValueNode(use_apply->input(0))) { continue; @@ -1492,7 +1492,7 @@ std::shared_ptr FindNextLayout(const CNodePtr& cnode) { return nullptr; } -std::shared_ptr GetOutputLayoutFromCNode(const CNodePtr& cnode, size_t output_index) { +std::shared_ptr GetOutputLayoutFromCNode(const CNodePtr &cnode, size_t output_index) { MS_EXCEPTION_IF_NULL(cnode); OperatorInfoPtr distribute_operator = GetDistributeOperator(cnode); MS_EXCEPTION_IF_NULL(distribute_operator); @@ -1505,7 +1505,7 @@ std::shared_ptr GetOutputLayoutFromCNode(const CNodePtr& cnode, si return std::make_shared(tensorlayout_out); } -std::shared_ptr FindPrevParallelCareNodeLayout(const AnfNodePtr& node, size_t output_index) { +std::shared_ptr FindPrevParallelCareNodeLayout(const AnfNodePtr &node, size_t output_index) { if (!node->isa()) { return nullptr; } @@ -1523,7 +1523,7 @@ std::shared_ptr FindPrevParallelCareNodeLayout(const AnfNodePtr& n return nullptr; } -std::shared_ptr FindPrevLayout(const AnfNodePtr& node) { +std::shared_ptr FindPrevLayout(const AnfNodePtr &node) { if (node->isa()) { MS_LOG(EXCEPTION) << "Failure: parameter before reshape is not supported temporary"; } @@ -1567,8 +1567,8 @@ std::shared_ptr FindPrevLayout(const AnfNodePtr& node) { return nullptr; } -void ReshapeInit(const std::vector& all_nodes) { - for (auto& node : all_nodes) { +void ReshapeInit(const std::vector &all_nodes) { + for (auto &node : all_nodes) { auto cnode = node->cast(); if ((cnode == nullptr) || !IsValueNode(cnode->input(0))) { continue; @@ -1608,7 +1608,7 @@ void ReshapeInit(const std::vector& all_nodes) { } // Sens node satisfies the following conditions: cnode(sens)-->cnode(tuple_getitem)-->cnode-->cnode(J) -bool IsGradSensNode(const AnfNodePtr& node) { +bool IsGradSensNode(const AnfNodePtr &node) { if (!node->isa()) { return false; } @@ -1660,7 +1660,7 @@ bool IsGradSensNode(const AnfNodePtr& node) { return (expect_j_prim->name() == J); } -TensorLayouts GetLossNodeGradOutputLayout(const CNodePtr& loss_cnode) { +TensorLayouts GetLossNodeGradOutputLayout(const CNodePtr &loss_cnode) { MS_EXCEPTION_IF_NULL(loss_cnode); AnfNodePtr node = loss_cnode->cast(); MS_EXCEPTION_IF_NULL(node); @@ -1700,7 +1700,7 @@ TensorLayouts GetLossNodeGradOutputLayout(const CNodePtr& loss_cnode) { return ret; } -void SplitSens(const AnfNodePtr& grad_sens_node, const TensorLayout& loss_grad_layout) { +void SplitSens(const AnfNodePtr &grad_sens_node, const TensorLayout &loss_grad_layout) { MS_EXCEPTION_IF_NULL(grad_sens_node); auto cnode = grad_sens_node->cast(); @@ -1752,7 +1752,7 @@ void SplitSens(const AnfNodePtr& grad_sens_node, const TensorLayout& loss_grad_l InsertGetTensorSliceOp(op, cnode, func_graph, 1, SPLIT_SENS); } -void InsertForwardOps(const OperatorInfoPtr& distribute_operator, const CNodePtr& cnode) { +void InsertForwardOps(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) { MS_EXCEPTION_IF_NULL(distribute_operator); MS_EXCEPTION_IF_NULL(cnode); OperatorVector forward_op = distribute_operator->forward_op(); @@ -1762,7 +1762,7 @@ void InsertForwardOps(const OperatorInfoPtr& distribute_operator, const CNodePtr } } -void StepReplace(const OperatorInfoPtr& distribute_operator, const CNodePtr& cnode) { +void StepReplace(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) { MS_EXCEPTION_IF_NULL(distribute_operator); MS_EXCEPTION_IF_NULL(cnode); // StepReplaceOp @@ -1783,7 +1783,7 @@ void StepReplace(const OperatorInfoPtr& distribute_operator, const CNodePtr& cno } } -void HandleDropoutNode(const OperatorInfoPtr& distribute_operator, const CNodePtr& cnode) { +void HandleDropoutNode(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) { MS_EXCEPTION_IF_NULL(distribute_operator); MS_EXCEPTION_IF_NULL(cnode); @@ -1801,12 +1801,12 @@ void HandleDropoutNode(const OperatorInfoPtr& distribute_operator, const CNodePt ReplaceOneOp(replace_op, cnode->input(DROPOUT_GEN_MASK_INDEX)->cast()); } -void HandleSpecialNode(const OperatorInfoPtr& distribute_operator, const CNodePtr& cnode) { +void HandleSpecialNode(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) { HandleDropoutNode(distribute_operator, cnode); } -void ParallelCommunication(const FuncGraphPtr& root, const std::vector& all_nodes, - const FuncGraphManagerPtr& manager) { +void ParallelCommunication(const FuncGraphPtr &root, const std::vector &all_nodes, + const FuncGraphManagerPtr &manager) { MS_EXCEPTION_IF_NULL(root); MS_EXCEPTION_IF_NULL(manager); TensorRedistribution tensor_redistribution; @@ -1817,7 +1817,7 @@ void ParallelCommunication(const FuncGraphPtr& root, const std::vector(node); MS_EXCEPTION_IF_NULL(symbolic_key); auto all_upstream_node = root->manager()->node_users()[node]; - for (auto& upstream_node : all_upstream_node) { + for (auto &upstream_node : all_upstream_node) { FuncGraphPtr fg = upstream_node.first->func_graph(); if (symbolic_key->node()->isa()) { - for (auto& param : root->parameters()) { + for (auto ¶m : root->parameters()) { if (*param == *symbolic_key->node()) { AnfNodePtr reverted_node = root->NewCNode({NewValueNode(prim::kPrimEmbed), param}); MS_EXCEPTION_IF_NULL(reverted_node); @@ -1889,9 +1889,9 @@ void RevertSymbolicKeyInstance(const FuncGraphPtr& root, const AnfNodePtr& node) } } // namespace -void HandleSymbolicKeyInstance(const FuncGraphPtr& root, const std::vector& all_nodes) { +void HandleSymbolicKeyInstance(const FuncGraphPtr &root, const std::vector &all_nodes) { MS_EXCEPTION_IF_NULL(root); - for (auto& node : all_nodes) { + for (auto &node : all_nodes) { // revert back SymbolicKeyInstance to embed() primitive if (IsValueNode(node)) { RevertSymbolicKeyInstance(root, node); @@ -1900,13 +1900,13 @@ void HandleSymbolicKeyInstance(const FuncGraphPtr& root, const std::vectorget_return(); auto all_nodes = DeepScopedGraphSearch(ret); - for (auto& node : all_nodes) { + for (auto &node : all_nodes) { MS_EXCEPTION_IF_NULL(node); auto cnode = node->cast(); if ((cnode == nullptr) || !IsValueNode(cnode->input(0))) { @@ -1931,7 +1931,7 @@ void CheckpointStrategy(const FuncGraphPtr& func_graph) { } } -void RestoreStrategy(const FuncGraphPtr& func_graph) { +void RestoreStrategy(const FuncGraphPtr &func_graph) { MS_EXCEPTION_IF_NULL(func_graph); MS_LOG(INFO) << "Extract strategy from checkpoint begin"; StrategyMap straMap; @@ -1943,7 +1943,7 @@ void RestoreStrategy(const FuncGraphPtr& func_graph) { } auto ret = func_graph->get_return(); auto all_nodes = DeepScopedGraphSearch(ret); - for (auto& node : all_nodes) { + for (auto &node : all_nodes) { MS_EXCEPTION_IF_NULL(node); auto cnode = node->cast(); if ((cnode == nullptr) || !IsValueNode(cnode->input(0))) { @@ -1968,8 +1968,8 @@ void RestoreStrategy(const FuncGraphPtr& func_graph) { } } -void SetForwardFlag(const std::vector& all_nodes) { - for (auto& node : all_nodes) { +void SetForwardFlag(const std::vector &all_nodes) { + for (auto &node : all_nodes) { MS_EXCEPTION_IF_NULL(node); if (!node->isa()) { continue; @@ -1986,8 +1986,8 @@ void SetForwardFlag(const std::vector& all_nodes) { } } -void SetForwardFlag(const AnfNodeSet& all_nodes) { - for (auto& node : all_nodes) { +void SetForwardFlag(const AnfNodeSet &all_nodes) { + for (auto &node : all_nodes) { MS_EXCEPTION_IF_NULL(node); if (!node->isa()) { continue; @@ -2003,7 +2003,7 @@ void SetForwardFlag(const AnfNodeSet& all_nodes) { } } -CNodePtr FindLossCNode(const FuncGraphPtr& func_graph) { +CNodePtr FindLossCNode(const FuncGraphPtr &func_graph) { MS_EXCEPTION_IF_NULL(func_graph); CNodePtr return_node = func_graph->get_return(); MS_EXCEPTION_IF_NULL(return_node); @@ -2059,8 +2059,8 @@ CNodePtr FindLossCNode(const FuncGraphPtr& func_graph) { return pre_cnode; } -FuncGraphPtr FindForwardGraphByRootNodes(const AnfNodeSet& root_all_nodes) { - for (auto& node : root_all_nodes) { +FuncGraphPtr FindForwardGraphByRootNodes(const AnfNodeSet &root_all_nodes) { + for (auto &node : root_all_nodes) { MS_EXCEPTION_IF_NULL(node); if (!node->isa()) { continue; @@ -2088,11 +2088,11 @@ FuncGraphPtr FindForwardGraphByRootNodes(const AnfNodeSet& root_all_nodes) { return nullptr; } -CNodePtr FindLossCNodeFromRoot(const FuncGraphPtr& root) { +CNodePtr FindLossCNodeFromRoot(const FuncGraphPtr &root) { MS_EXCEPTION_IF_NULL(root); AnfNodePtr root_return_node = root->get_return(); MS_EXCEPTION_IF_NULL(root_return_node); - const auto& all_nodes = root->nodes(); + const auto &all_nodes = root->nodes(); FuncGraphPtr func_graph = FindForwardGraphByRootNodes(all_nodes); if (func_graph == nullptr) { return FindLossCNode(root); @@ -2101,12 +2101,12 @@ CNodePtr FindLossCNodeFromRoot(const FuncGraphPtr& root) { } } -FuncGraphPtr ForwardGraph(const FuncGraphPtr& root) { +FuncGraphPtr ForwardGraph(const FuncGraphPtr &root) { FuncGraphPtr forward_graph = root; MS_EXCEPTION_IF_NULL(root); AnfNodePtr root_return_node = root->get_return(); MS_EXCEPTION_IF_NULL(root_return_node); - const auto& all_nodes = root->nodes(); + const auto &all_nodes = root->nodes(); FuncGraphPtr func_graph = FindForwardGraphByRootNodes(all_nodes); if (func_graph != nullptr) { forward_graph = func_graph; @@ -2114,11 +2114,11 @@ FuncGraphPtr ForwardGraph(const FuncGraphPtr& root) { return forward_graph; } -void MarkForwardCNode(const FuncGraphPtr& root) { +void MarkForwardCNode(const FuncGraphPtr &root) { MS_EXCEPTION_IF_NULL(root); AnfNodePtr root_return_node = root->get_return(); MS_EXCEPTION_IF_NULL(root_return_node); - auto& all_nodes = root->nodes(); + auto &all_nodes = root->nodes(); FuncGraphPtr func_graph = FindForwardGraphByRootNodes(all_nodes); if (func_graph == nullptr) { @@ -2178,7 +2178,7 @@ Status ParallelInit() { return SUCCESS; } -bool StepParallel(const FuncGraphPtr& root, const opt::OptimizerPtr& optimizer) { +bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) { MS_EXCEPTION_IF_NULL(root); MS_EXCEPTION_IF_NULL(optimizer); MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance()); @@ -2258,12 +2258,12 @@ bool StepParallel(const FuncGraphPtr& root, const opt::OptimizerPtr& optimizer) } // Needed by rec_parser -std::vector ExtractInputsTensorName(const CNodePtr& node) { +std::vector ExtractInputsTensorName(const CNodePtr &node) { std::vector name_inputs; std::vector all_inputs = node->inputs(); std::vector node_inputs{all_inputs.begin() + 1, all_inputs.end()}; - for (auto& input : node_inputs) { + for (auto &input : node_inputs) { std::string name; if (IsValueNode(input) || input->isa() || input->isa()) { name = input->ToString(); diff --git a/mindspore/ccsrc/parallel/step_parallel.h b/mindspore/ccsrc/parallel/step_parallel.h index fd47a59bf5..184d11d173 100644 --- a/mindspore/ccsrc/parallel/step_parallel.h +++ b/mindspore/ccsrc/parallel/step_parallel.h @@ -41,114 +41,114 @@ struct LossNodeInfo { int dout_index = 0; // now don't support the sens is a tuple }; -std::vector CreateInput(const Operator& op, const AnfNodePtr& node, const std::string& instance_name); -std::string CreateInstanceName(const CNodePtr& node, size_t index); -void ForwardCommunication(OperatorVector forward_op, const CNodePtr& node); +std::vector CreateInput(const Operator &op, const AnfNodePtr &node, const std::string &instance_name); +std::string CreateInstanceName(const CNodePtr &node, size_t index); +void ForwardCommunication(OperatorVector forward_op, const CNodePtr &node); -void InsertRedistribution(const RedistributionOpListPtr& redistribution_oplist_ptr, const CNodePtr& node, - const FuncGraphPtr& func_graph, int pos, const CNodePtr& pre_node); +void InsertRedistribution(const RedistributionOpListPtr &redistribution_oplist_ptr, const CNodePtr &node, + const FuncGraphPtr &func_graph, int pos, const CNodePtr &pre_node); -TensorLayout GetTensorInLayout(const CNodePtr& pre_node, const PrimitivePtr& pre_prim, - const OperatorInfoPtr& distribute_operator_pre); +TensorLayout GetTensorInLayout(const CNodePtr &pre_node, const PrimitivePtr &pre_prim, + const OperatorInfoPtr &distribute_operator_pre); -OperatorInfoPtr GetDistributeOperator(const CNodePtr& node); +OperatorInfoPtr GetDistributeOperator(const CNodePtr &node); -void Redistribution(const std::pair& node_pair, const OperatorInfoPtr& distribute_operator, - const CNodePtr& middle_node, int index, TensorRedistribution tensor_redistribution, - const CNodePtr& pre_node); +void Redistribution(const std::pair &node_pair, const OperatorInfoPtr &distribute_operator, + const CNodePtr &middle_node, int index, TensorRedistribution tensor_redistribution, + const CNodePtr &pre_node); bool StrategyFound(std::unordered_map attrs); -bool IsParallelCareNode(const CNodePtr& cnode); +bool IsParallelCareNode(const CNodePtr &cnode); -void MarkForwardCNode(const FuncGraphPtr& root); +void MarkForwardCNode(const FuncGraphPtr &root); -bool FindCommunicationOp(const std::vector& all_nodes); +bool FindCommunicationOp(const std::vector &all_nodes); -void StepRedistribution(const CNodePtr& node, const OperatorInfoPtr& distribute_operator, const CNodePtr& insert_node, - const TensorRedistribution& tensor_redistribution, const CNodePtr& pre_node); +void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_operator, const CNodePtr &insert_node, + const TensorRedistribution &tensor_redistribution, const CNodePtr &pre_node); -std::vector ReplaceOpInput(const Operator& replace_op, const std::string& instance_name, - const CNodePtr& node); +std::vector ReplaceOpInput(const Operator &replace_op, const std::string &instance_name, + const CNodePtr &node); -void StepReplaceOp(OperatorVector replace_op, const CNodePtr& node); +void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node); -void InsertVirtualDivOp(const VirtualDivOp& virtual_div_op, const CNodePtr& node); +void InsertVirtualDivOp(const VirtualDivOp &virtual_div_op, const CNodePtr &node); -std::pair FindParameter(const AnfNodePtr& node, const FuncGraphPtr& func_graph); +std::pair FindParameter(const AnfNodePtr &node, const FuncGraphPtr &func_graph); -std::pair FindCNode(const AnfNodePtr& anode, const std::string& name, const FuncGraphPtr& func_graph); +std::pair FindCNode(const AnfNodePtr &anode, const std::string &name, const FuncGraphPtr &func_graph); -void InsertMirrorOps(const MirrorOps& mirror_ops, const CNodePtr& node); +void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node); -void BackwardCommunication(const OperatorInfoPtr& distribute_operator, const CNodePtr& node, bool is_loss_node); +void BackwardCommunication(const OperatorInfoPtr &distribute_operator, const CNodePtr &node, bool is_loss_node); // Generate and init parallel operator -OperatorInfoPtr OperatorInstance(const PrimitivePtr& prim, const PrimitiveAttrs& attrs, - const std::vector& shape_list); +OperatorInfoPtr OperatorInstance(const PrimitivePtr &prim, const PrimitiveAttrs &attrs, + const std::vector &shape_list); // Generate without initing parallel operator -OperatorInfoPtr NewOperatorInstance(const PrimitivePtr& prim, const PrimitiveAttrs& attrs, +OperatorInfoPtr NewOperatorInstance(const PrimitivePtr &prim, const PrimitiveAttrs &attrs, std::vector shape_list); // Extract strategy from attr StrategyPtr ExtractStrategy(std::unordered_map attrs); -Shapes GetNodeShape(const AnfNodePtr& node); +Shapes GetNodeShape(const AnfNodePtr &node); -std::vector FindParameterByRefKeyNode(const AnfNodePtr& node, const FuncGraphPtr& func_graph); +std::vector FindParameterByRefKeyNode(const AnfNodePtr &node, const FuncGraphPtr &func_graph); // Extract shape from anfnode -std::vector ExtractShape(const CNodePtr& node); +std::vector ExtractShape(const CNodePtr &node); -std::pair FindParallelCareNode(const AnfNodePtr& node); +std::pair FindParallelCareNode(const AnfNodePtr &node); // Find finally sub graph -std::pair FindSubGraph(const FuncGraphPtr& func_graph, const AnfNodePtr& parameter); +std::pair FindSubGraph(const FuncGraphPtr &func_graph, const AnfNodePtr ¶meter); // Set distribute shape for parameters abstract -void SetParallelShape(const AnfNodePtr& parameter, const std::pair& res); +void SetParallelShape(const AnfNodePtr ¶meter, const std::pair &res); // change parameters'shape in resource -void CoverSliceShape(const FuncGraphPtr& root); +void CoverSliceShape(const FuncGraphPtr &root); -void SetVirtualDatasetStrategy(const CNodePtr& node); +void SetVirtualDatasetStrategy(const CNodePtr &node); // Creat parallel operator for primitive node(has strategy) -void ExtractInformation(const std::vector& all_nodes); +void ExtractInformation(const std::vector &all_nodes); -TensorLayout GetInputLayoutFromCNode(const std::pair& node_pair); +TensorLayout GetInputLayoutFromCNode(const std::pair &node_pair); -std::shared_ptr FindNextLayout(const CNodePtr& node); +std::shared_ptr FindNextLayout(const CNodePtr &node); -std::shared_ptr GetOutputLayoutFromCNode(const CNodePtr& cnode, size_t output_index); +std::shared_ptr GetOutputLayoutFromCNode(const CNodePtr &cnode, size_t output_index); -std::shared_ptr FindPrevParallelCareNodeLayout(const AnfNodePtr& node, size_t output_index); +std::shared_ptr FindPrevParallelCareNodeLayout(const AnfNodePtr &node, size_t output_index); -std::shared_ptr FindPrevLayout(const AnfNodePtr& node); +std::shared_ptr FindPrevLayout(const AnfNodePtr &node); -void ReshapeInit(const std::vector& all_nodes); +void ReshapeInit(const std::vector &all_nodes); // Add node for whole graph -void ParallelCommunication(const FuncGraphPtr& root, const std::vector& all_nodes, - const FuncGraphManagerPtr& manager); +void ParallelCommunication(const FuncGraphPtr &root, const std::vector &all_nodes, + const FuncGraphManagerPtr &manager); -void RestoreStrategy(const FuncGraphPtr& func_graph); +void RestoreStrategy(const FuncGraphPtr &func_graph); -void CheckpointStrategy(const FuncGraphPtr& func_graph); +void CheckpointStrategy(const FuncGraphPtr &func_graph); // main step of Parallel -bool StepParallel(const FuncGraphPtr& func_graph, const opt::OptimizerPtr& optimizer); +bool StepParallel(const FuncGraphPtr &func_graph, const opt::OptimizerPtr &optimizer); -int32_t GetTupleGetItemIndex(const CNodePtr& cnode); +int32_t GetTupleGetItemIndex(const CNodePtr &cnode); -CNodePtr FindLossCNodeFromRoot(const FuncGraphPtr& root); +CNodePtr FindLossCNodeFromRoot(const FuncGraphPtr &root); Status ParallelInit(); -std::vector ExtractInputsTensorName(const CNodePtr& node); +std::vector ExtractInputsTensorName(const CNodePtr &node); -FuncGraphPtr ForwardGraph(const FuncGraphPtr& root); +FuncGraphPtr ForwardGraph(const FuncGraphPtr &root); } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/parallel/strategy.h b/mindspore/ccsrc/parallel/strategy.h index 93d4d4dff1..fce99305a5 100644 --- a/mindspore/ccsrc/parallel/strategy.h +++ b/mindspore/ccsrc/parallel/strategy.h @@ -46,7 +46,7 @@ class Strategy { inputs_.push_back(inputs_[0]); } } - void ResetInputs(const std::vector& input) { inputs_ = input; } + void ResetInputs(const std::vector &input) { inputs_ = input; } private: const int32_t stage_; @@ -55,7 +55,7 @@ class Strategy { std::vector inputs_; }; -inline StrategyPtr NewStrategy(const int32_t stage, const std::vector& inputs) { +inline StrategyPtr NewStrategy(const int32_t stage, const std::vector &inputs) { return std::make_shared(stage, inputs); } } // namespace parallel diff --git a/mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc b/mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc index 9e3573eee2..dd518dc76c 100644 --- a/mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc +++ b/mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc @@ -27,7 +27,7 @@ namespace mindspore { namespace parallel { -StrategyCheckpoint& StrategyCheckpoint::GetInstance() { +StrategyCheckpoint &StrategyCheckpoint::GetInstance() { static StrategyCheckpoint instance = StrategyCheckpoint(); return instance; } @@ -47,7 +47,7 @@ Status StrategyCheckpoint::RemoveCheckPoint() const { return FAILED; } -Status StrategyCheckpoint::Load(StrategyMap* strategy_map) { +Status StrategyCheckpoint::Load(StrategyMap *strategy_map) { if (strategy_map == nullptr) { MS_LOG(EXCEPTION) << "Failure:strategy_map is nullptr"; } @@ -82,18 +82,18 @@ Status StrategyCheckpoint::Load(StrategyMap* strategy_map) { return SUCCESS; } -Status StrategyCheckpoint::Save(const StrategyMap& strategy_map) { +Status StrategyCheckpoint::Save(const StrategyMap &strategy_map) { straspb::ParallelStrategyMap parallel_strategy_map; parallel_strategy_map.set_train_time(IntToUint(++current_train_time_)); - for (auto& node_stra : strategy_map) { - straspb::ParallelStrategyItem* parallel_strategy_item = parallel_strategy_map.add_parallel_strategy_item(); + for (auto &node_stra : strategy_map) { + straspb::ParallelStrategyItem *parallel_strategy_item = parallel_strategy_map.add_parallel_strategy_item(); MS_EXCEPTION_IF_NULL(parallel_strategy_item); parallel_strategy_item->set_node_name(node_stra.first); - straspb::ParallelStrategys* parallel_strategys = parallel_strategy_item->mutable_parallel_strategys(); + straspb::ParallelStrategys *parallel_strategys = parallel_strategy_item->mutable_parallel_strategys(); MS_EXCEPTION_IF_NULL(parallel_strategys); parallel_strategys->set_stage(IntToUint(node_stra.second->GetInputStage())); - for (auto& dims : node_stra.second->GetInputDim()) { - straspb::ParallelStrategy* parallel_strategy = parallel_strategys->add_parallel_strategy(); + for (auto &dims : node_stra.second->GetInputDim()) { + straspb::ParallelStrategy *parallel_strategy = parallel_strategys->add_parallel_strategy(); MS_EXCEPTION_IF_NULL(parallel_strategy); for (auto dim : dims) { parallel_strategy->add_dim(IntToUint(dim)); diff --git a/mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h b/mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h index b5d3626f53..c871ea6eef 100644 --- a/mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h +++ b/mindspore/ccsrc/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h @@ -32,11 +32,11 @@ class StrategyCheckpoint { StrategyCheckpoint() : path_(DEFAULT_CHECKPOINT_PATH), current_train_time_(1) { train_times_ = 1; checkpoint_on_ = false; - const char* train_times_str = std::getenv("PARALLEL_TRAIN_TIMES"); + const char *train_times_str = std::getenv("PARALLEL_TRAIN_TIMES"); if (train_times_str != nullptr && std::stoi(train_times_str) > 0) { train_times_ = std::stoi(train_times_str); } - const char* checkpoint_on_str = std::getenv("PARALLEL_CHECKPOINT_ON"); + const char *checkpoint_on_str = std::getenv("PARALLEL_CHECKPOINT_ON"); if (checkpoint_on_str != nullptr) { checkpoint_on_ = (std::string(checkpoint_on_str) == "on"); } @@ -44,10 +44,10 @@ class StrategyCheckpoint { ~StrategyCheckpoint() = default; bool CheckPointExit() const; Status RemoveCheckPoint() const; - Status Load(StrategyMap* strategy_map); - Status Save(const StrategyMap& strategy_map); + Status Load(StrategyMap *strategy_map); + Status Save(const StrategyMap &strategy_map); - static StrategyCheckpoint& GetInstance(); + static StrategyCheckpoint &GetInstance(); int32_t GetTrainTimes() const { return train_times_; } int32_t GetCurrentTrainTime() const { return current_train_time_; } bool CheckPointOn() const { return checkpoint_on_; } diff --git a/mindspore/ccsrc/parallel/tensor_layout/arrangement.cc b/mindspore/ccsrc/parallel/tensor_layout/arrangement.cc index b42ba30242..235ab00302 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/arrangement.cc +++ b/mindspore/ccsrc/parallel/tensor_layout/arrangement.cc @@ -26,7 +26,7 @@ namespace mindspore { namespace parallel { -Status Arrangement::Init(const std::vector& array) { +Status Arrangement::Init(const std::vector &array) { Status status = Array::Init(array); if (status != Status::SUCCESS) { return Status::FAILED; @@ -45,7 +45,7 @@ bool Arrangement::IsValidArrangement() { void Arrangement::ComputeSize() { size_ = 1; - for (auto& value : array_) { + for (auto &value : array_) { size_ *= value; } } @@ -84,7 +84,7 @@ std::vector Arrangement::GetFrontElementByValue(int32_t value) const { } std::shared_ptr Arrangement::GetExpandedShapeByExpandListRemoveLeft( - const std::vector& expand_list) const { + const std::vector &expand_list) const { if (expand_list.size() != GetDimSize()) { return nullptr; } @@ -108,7 +108,7 @@ std::shared_ptr Arrangement::GetExpandedShapeByExpandListRemoveLeft * array_ = [8, 4], * arrangement_list = [[4, 2], [2, 2]] */ -std::shared_ptr> Arrangement::GetExpandShapeList(const Arrangement& expand_shape) const { +std::shared_ptr> Arrangement::GetExpandShapeList(const Arrangement &expand_shape) const { int32_t size = 1; uint32_t ind = 0; std::vector arrangement_list; @@ -140,7 +140,7 @@ std::shared_ptr> Arrangement::GetExpandShapeList(const } std::shared_ptr, Arrangement>> Arrangement::GetExpandShapeListPair( - const Arrangement& expand_shape) const { + const Arrangement &expand_shape) const { std::shared_ptr> expand_shape_list_ptr = GetExpandShapeList(expand_shape); if (expand_shape_list_ptr == nullptr) { return nullptr; @@ -148,7 +148,7 @@ std::shared_ptr, Arrangement>> Arrangement::G std::vector expand_num_list_shape; (void)std::transform(expand_shape_list_ptr->begin(), expand_shape_list_ptr->end(), std::back_inserter(expand_num_list_shape), - [](const Arrangement& arr) { return SizeToInt(arr.GetDimSize()); }); + [](const Arrangement &arr) { return SizeToInt(arr.GetDimSize()); }); Arrangement expand_num_list; Status status = expand_num_list.Init(expand_num_list_shape); if (status != Status::SUCCESS) { @@ -169,7 +169,7 @@ std::vector Arrangement::ComputeReverseAccumulateSumInReverseOrder() co } std::shared_ptr Arrangement::GetExpandedShapeByExpandListReserveLeft( - const std::vector& expand_list) const { + const std::vector &expand_list) const { if (expand_list.size() != GetDimSize()) { return nullptr; } @@ -191,7 +191,7 @@ std::shared_ptr Arrangement::GetExpandedShapeByExpandListReserveLef return std::make_shared(arrangement_new); } -std::shared_ptr Arrangement::GetUnifiedShape(const Arrangement& in2) const { +std::shared_ptr Arrangement::GetUnifiedShape(const Arrangement &in2) const { std::vector in1_accum; Status status = ShapeToAccumulateProduct(array_, &in1_accum); if (status != Status::SUCCESS) { diff --git a/mindspore/ccsrc/parallel/tensor_layout/arrangement.h b/mindspore/ccsrc/parallel/tensor_layout/arrangement.h index 2dc13038c1..ca71b05c91 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/arrangement.h +++ b/mindspore/ccsrc/parallel/tensor_layout/arrangement.h @@ -32,18 +32,18 @@ class Arrangement : public Array { public: Arrangement() : size_(1) {} ~Arrangement() override = default; - Status Init(const std::vector& array) override; + Status Init(const std::vector &array) override; int32_t size() const { return size_; } std::vector GetFrontElementByValue(int32_t value) const; - std::shared_ptr> GetExpandShapeList(const Arrangement& expand_shape) const; + std::shared_ptr> GetExpandShapeList(const Arrangement &expand_shape) const; std::vector ComputeReverseAccumulateSumInReverseOrder() const; std::shared_ptr GetExpandedShapeByExpandListReserveLeft( - const std::vector& expand_list) const; + const std::vector &expand_list) const; std::shared_ptr GetExpandedShapeByExpandListRemoveLeft( - const std::vector& expand_list) const; + const std::vector &expand_list) const; std::shared_ptr, Arrangement>> GetExpandShapeListPair( - const Arrangement& expand_shape) const; - std::shared_ptr GetUnifiedShape(const Arrangement& in2) const; + const Arrangement &expand_shape) const; + std::shared_ptr GetUnifiedShape(const Arrangement &in2) const; std::vector GetSqueezeIdx() const; Arrangement GetSqueezeArrangement() const; diff --git a/mindspore/ccsrc/parallel/tensor_layout/array.cc b/mindspore/ccsrc/parallel/tensor_layout/array.cc index ba3858ae00..ef358e7cde 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/array.cc +++ b/mindspore/ccsrc/parallel/tensor_layout/array.cc @@ -24,14 +24,14 @@ namespace parallel { std::string Array::ToString() const { std::ostringstream buffer; buffer << "[ "; - for (auto& element : array_) { + for (auto &element : array_) { buffer << std::to_string(element) + " "; } buffer << "]"; return buffer.str(); } -Status Array::Init(const std::vector& array) { +Status Array::Init(const std::vector &array) { array_ = array; return IsvalidArray() ? Status::SUCCESS : Status::FAILED; } @@ -54,7 +54,7 @@ int32_t Array::GetDimByReverseIdx(uint32_t idx) const { return array_[GetDimSize() - 1 - mod_idx]; } -bool Array::operator==(const Array& shape) const { +bool Array::operator==(const Array &shape) const { if (GetDimSize() != shape.GetDimSize()) { return false; } diff --git a/mindspore/ccsrc/parallel/tensor_layout/array.h b/mindspore/ccsrc/parallel/tensor_layout/array.h index f7d9c3c673..5aa3bdb138 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/array.h +++ b/mindspore/ccsrc/parallel/tensor_layout/array.h @@ -31,13 +31,13 @@ class Array { Array() = default; virtual ~Array() = default; std::string ToString() const; - virtual Status Init(const std::vector& array); + virtual Status Init(const std::vector &array); bool IsvalidArray() const; std::vector array() const { return array_; } size_t GetDimSize() const { return array_.size(); } int32_t GetDimByIdx(uint32_t idx) const; int32_t GetDimByReverseIdx(uint32_t idx) const; - bool operator==(const Array& a1) const; + bool operator==(const Array &a1) const; protected: std::vector array_; diff --git a/mindspore/ccsrc/parallel/tensor_layout/construct_operator.cc b/mindspore/ccsrc/parallel/tensor_layout/construct_operator.cc index 829c056fc2..b5ca5ed60a 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/construct_operator.cc +++ b/mindspore/ccsrc/parallel/tensor_layout/construct_operator.cc @@ -21,7 +21,7 @@ namespace mindspore { namespace parallel { -Status ConstructOperator::Init(const RankList& dev_list, const Shape& dev_matrix_shape) { +Status ConstructOperator::Init(const RankList &dev_list, const Shape &dev_matrix_shape) { dev_size_ = dev_matrix_shape.size(); dev_matrix_shape_ = dev_matrix_shape; dev_list_ = dev_list; @@ -46,7 +46,7 @@ Status ConstructOperator::ReshapeOP(Shape shape) { return Status::SUCCESS; } -Operator CreateStridedSliceOp(int32_t value, const Shape& begin, const Shape& end, const Shape& strides) { +Operator CreateStridedSliceOp(int32_t value, const Shape &begin, const Shape &end, const Shape &strides) { ValuePtr attr_value = MakeValue(value); Attr attr_begin_mask = std::make_pair(BEGIN_MASK, attr_value); Attr attr_end_mask = std::make_pair(END_MASK, attr_value); @@ -230,7 +230,7 @@ Status ConstructOperator::AlltoAllOP(Args args) { return Status::SUCCESS; } -Status ConstructOperator::CreateGroupByDim(size_t axis, std::vector* group) { +Status ConstructOperator::CreateGroupByDim(size_t axis, std::vector *group) { MS_EXCEPTION_IF_NULL(group); CheckGlobalDeviceManager(); MS_EXCEPTION_IF_NULL(g_device_manager); diff --git a/mindspore/ccsrc/parallel/tensor_layout/construct_operator.h b/mindspore/ccsrc/parallel/tensor_layout/construct_operator.h index cf6cff456a..1a69638fb6 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/construct_operator.h +++ b/mindspore/ccsrc/parallel/tensor_layout/construct_operator.h @@ -34,7 +34,7 @@ class ConstructOperator { const int32_t DEFAULT = 0; ConstructOperator() : dev_size_(0) {} ~ConstructOperator() = default; - Status Init(const RankList& dev_list, const Shape& dev_matrix_shape); + Status Init(const RankList &dev_list, const Shape &dev_matrix_shape); Status ReshapeOP(Shape shape); Status StridedSliceOP(Args args); Status AllGatherOP(int32_t dev_dim); @@ -42,7 +42,7 @@ class ConstructOperator { Status ConcatOP(int32_t concat_dim); Status AlltoAllOP(Args args); Operator GetOperator() const { return op_; } - void UpdateTensorShape(const Shape& tensor_shape) { tensor_shape_ = tensor_shape; } + void UpdateTensorShape(const Shape &tensor_shape) { tensor_shape_ = tensor_shape; } private: Operator op_; @@ -50,7 +50,7 @@ class ConstructOperator { Shape tensor_shape_; RankList dev_list_; Shape dev_matrix_shape_; - Status CreateGroupByDim(size_t axis, std::vector* group); + Status CreateGroupByDim(size_t axis, std::vector *group); }; } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/parallel/tensor_layout/layout_transfer.cc b/mindspore/ccsrc/parallel/tensor_layout/layout_transfer.cc index 190a5846ba..84c0580ba8 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/layout_transfer.cc +++ b/mindspore/ccsrc/parallel/tensor_layout/layout_transfer.cc @@ -29,7 +29,7 @@ std::string LayoutTransfer::ToString() const { LayoutTransfer::~LayoutTransfer() = default; -Status LayoutTransfer::Init(const TensorLayout& from_in, const TensorLayout& to_in) { +Status LayoutTransfer::Init(const TensorLayout &from_in, const TensorLayout &to_in) { from_in_ = from_in; to_in_ = to_in; MS_LOG(DEBUG) << "LayoutTransfer " << this->ToString(); diff --git a/mindspore/ccsrc/parallel/tensor_layout/layout_transfer.h b/mindspore/ccsrc/parallel/tensor_layout/layout_transfer.h index b05128f5b8..c4da4b728f 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/layout_transfer.h +++ b/mindspore/ccsrc/parallel/tensor_layout/layout_transfer.h @@ -28,7 +28,7 @@ class LayoutTransfer { LayoutTransfer() = default; virtual ~LayoutTransfer() = 0; std::string ToString() const; - Status Init(const TensorLayout& from_in, const TensorLayout& to_in); + Status Init(const TensorLayout &from_in, const TensorLayout &to_in); TensorLayout from_in() const { return from_in_; } TensorLayout to_in() const { return to_in_; } diff --git a/mindspore/ccsrc/parallel/tensor_layout/map.cc b/mindspore/ccsrc/parallel/tensor_layout/map.cc index 320dbe6ebd..669920fc44 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/map.cc +++ b/mindspore/ccsrc/parallel/tensor_layout/map.cc @@ -26,7 +26,7 @@ namespace mindspore { namespace parallel { -Status Map::Init(const std::vector& array) { +Status Map::Init(const std::vector &array) { Status status = Array::Init(array); if (status != Status::SUCCESS) { return Status::FAILED; @@ -46,7 +46,7 @@ bool Map::IsValidMap() { std::vector sorted_array = array_; std::sort(sorted_array.begin(), sorted_array.end()); int32_t value = MAP_NONE; - for (auto& element : sorted_array) { + for (auto &element : sorted_array) { if (element == MAP_NONE) { continue; } @@ -78,7 +78,7 @@ int32_t Map::GetIndexByValue(int32_t value) const { /* * expand.size() should be equal to array_.size() */ -std::shared_ptr Map::ExpandMapByNone(const Arrangement& expand_num_list) const { +std::shared_ptr Map::ExpandMapByNone(const Arrangement &expand_num_list) const { if (expand_num_list.GetDimSize() != GetDimSize()) { return nullptr; } @@ -105,7 +105,7 @@ std::shared_ptr Map::ExpandMapByNone(const Arrangement& expand_num_list) co /* * expand.size() should be equal to array_.size() */ -std::shared_ptr Map::ExpandMapByDecreaseNumber(const Arrangement& expand_num_list) const { +std::shared_ptr Map::ExpandMapByDecreaseNumber(const Arrangement &expand_num_list) const { if (GetMaxItem() >= static_cast(expand_num_list.GetDimSize())) { return nullptr; } @@ -126,7 +126,7 @@ std::shared_ptr Map::ExpandMapByDecreaseNumber(const Arrangement& expand_nu return map_new; } -std::shared_ptr> Map::ReMapVector(const std::vector& input_vector) const { +std::shared_ptr> Map::ReMapVector(const std::vector &input_vector) const { if (GetMaxItem() >= static_cast(input_vector.size())) { return nullptr; } @@ -143,7 +143,7 @@ std::shared_ptr> Map::ReMapVector(const std::vector idx_list) const { - for (auto& value : idx_list) { + for (auto &value : idx_list) { if (GetDimByIdx(SizeToUint(value)) != MAP_NONE) { return false; } diff --git a/mindspore/ccsrc/parallel/tensor_layout/map.h b/mindspore/ccsrc/parallel/tensor_layout/map.h index 3f839ef198..8c8bba2775 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/map.h +++ b/mindspore/ccsrc/parallel/tensor_layout/map.h @@ -34,12 +34,12 @@ class Map : public Array { public: Map() = default; ~Map() override = default; - Status Init(const std::vector& array) override; + Status Init(const std::vector &array) override; int32_t GetMaxItem() const; int32_t GetIndexByValue(int32_t value) const; - std::shared_ptr ExpandMapByNone(const Arrangement& expand_num_list) const; - std::shared_ptr ExpandMapByDecreaseNumber(const Arrangement& expand_num_list) const; - std::shared_ptr> ReMapVector(const std::vector& input_vector) const; + std::shared_ptr ExpandMapByNone(const Arrangement &expand_num_list) const; + std::shared_ptr ExpandMapByDecreaseNumber(const Arrangement &expand_num_list) const; + std::shared_ptr> ReMapVector(const std::vector &input_vector) const; bool CheckNoneByIdxList(std::vector idx_list) const; Map SqueezeMapByIdxList(std::vector idx_list) const; diff --git a/mindspore/ccsrc/parallel/tensor_layout/redistribution_operator_infer.cc b/mindspore/ccsrc/parallel/tensor_layout/redistribution_operator_infer.cc index ac768c19f9..946620ec4c 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/redistribution_operator_infer.cc +++ b/mindspore/ccsrc/parallel/tensor_layout/redistribution_operator_infer.cc @@ -22,7 +22,7 @@ namespace mindspore { namespace parallel { -Status RedistributionOperatorInfer::Init(const TensorLayout& tensor_layout, const Map& out_tensor_map, +Status RedistributionOperatorInfer::Init(const TensorLayout &tensor_layout, const Map &out_tensor_map, RankList dev_list, bool is_cost_model) { in_tensor_map_ = tensor_layout.tensor_map(); dev_mat_ = tensor_layout.device_arrangement(); @@ -105,7 +105,7 @@ Status RedistributionOperatorInfer::InferSplitByAxis() { } if (in_dim == NONE && !std::any_of(map_.begin(), map_.end(), - [out_dim](const RedistributionOperatorMap::value_type& a) { return a.second == out_dim; })) { + [out_dim](const RedistributionOperatorMap::value_type &a) { return a.second == out_dim; })) { Args args = {dev_mat_.GetDimByReverseIdx(IntToUint(out_dim)), UintToInt(index), out_dim}; if (InsertOperator(SPLIT_BY_AXIS, args) == Status::FAILED) { MS_LOG(ERROR) << "Insert SplitByAxis Error!"; @@ -130,7 +130,7 @@ Status RedistributionOperatorInfer::InferPermuteByAxis() { } if (in_dim == NONE && std::any_of(map_.begin(), map_.end(), - [out_dim](const RedistributionOperatorMap::value_type& a) { return a.second == out_dim; })) { + [out_dim](const RedistributionOperatorMap::value_type &a) { return a.second == out_dim; })) { int32_t cat_dim = in_tensor_map_.GetIndexByValue(out_dim); int32_t dev_num = dev_mat_.GetDimByReverseIdx(IntToUint(out_dim)); if (is_cost_model_) { diff --git a/mindspore/ccsrc/parallel/tensor_layout/redistribution_operator_infer.h b/mindspore/ccsrc/parallel/tensor_layout/redistribution_operator_infer.h index 8fd953572a..a96097a1d3 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/redistribution_operator_infer.h +++ b/mindspore/ccsrc/parallel/tensor_layout/redistribution_operator_infer.h @@ -40,7 +40,7 @@ class RedistributionOperatorInfer { public: const int NONE = -1; explicit RedistributionOperatorInfer(bool construct_op_flag = true) : construct_op_flag_(construct_op_flag) {} - Status Init(const TensorLayout& tensor_layout, const Map& out_tensor_map, RankList dev_list, + Status Init(const TensorLayout &tensor_layout, const Map &out_tensor_map, RankList dev_list, bool is_cost_model = false); ~RedistributionOperatorInfer() = default; OperatorList operator_list() const { return operator_list_; } diff --git a/mindspore/ccsrc/parallel/tensor_layout/reshape_layout_transfer.cc b/mindspore/ccsrc/parallel/tensor_layout/reshape_layout_transfer.cc index 39a6bef92d..f6c90e9d46 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/reshape_layout_transfer.cc +++ b/mindspore/ccsrc/parallel/tensor_layout/reshape_layout_transfer.cc @@ -104,7 +104,7 @@ std::shared_ptr ReshapeLayoutTransfer::ExchangeFromAndTo( } std::shared_ptr ReshapeLayoutTransfer::ExpandFromTensorShapeAndExpandToDeviceArrangement( - const Arrangement& expand_shape) const { + const Arrangement &expand_shape) const { std::shared_ptr extend_tensor_shape_from_ptr = from_in_.ExpandTensorShape(expand_shape); if (extend_tensor_shape_from_ptr == nullptr) { return nullptr; diff --git a/mindspore/ccsrc/parallel/tensor_layout/reshape_layout_transfer.h b/mindspore/ccsrc/parallel/tensor_layout/reshape_layout_transfer.h index 8aae71631d..ed62cb59da 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/reshape_layout_transfer.h +++ b/mindspore/ccsrc/parallel/tensor_layout/reshape_layout_transfer.h @@ -33,7 +33,7 @@ class ReshapeLayoutTransfer : public LayoutTransfer { std::shared_ptr ExtendFromTensorShapeByExpandedTensorShape() const; std::shared_ptr ExtendToTensorShapeByExpandedTensorShape() const; std::shared_ptr ExpandFromTensorShapeAndExpandToDeviceArrangement( - const Arrangement& expand_shape) const; + const Arrangement &expand_shape) const; std::shared_ptr ExchangeFromAndTo() const; private: diff --git a/mindspore/ccsrc/parallel/tensor_layout/shape_util.cc b/mindspore/ccsrc/parallel/tensor_layout/shape_util.cc index a26627fb3c..e8f208708c 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/shape_util.cc +++ b/mindspore/ccsrc/parallel/tensor_layout/shape_util.cc @@ -26,7 +26,7 @@ namespace parallel { * shape = [2, 8, 32] * shape_accum = [2, 2 * 8, 2 * 8 * 32] */ -Status ShapeToAccumulateProduct(const std::vector& shape, std::vector* shape_accum) { +Status ShapeToAccumulateProduct(const std::vector &shape, std::vector *shape_accum) { MS_EXCEPTION_IF_NULL(shape_accum); shape_accum->clear(); int64_t size = 1; @@ -47,7 +47,7 @@ Status ShapeToAccumulateProduct(const std::vector& shape, std::vector& shape, std::vector* shape_accum) { +Status ShapeToAccumulateProductReverse(const std::vector &shape, std::vector *shape_accum) { MS_EXCEPTION_IF_NULL(shape_accum); shape_accum->clear(); int64_t size = 1; @@ -68,7 +68,7 @@ Status ShapeToAccumulateProductReverse(const std::vector& shape, std::v * shape = [2, 8, 32] * */ -Status AccumulateProductToShape(const std::vector& shape_accum, std::vector* shape) { +Status AccumulateProductToShape(const std::vector &shape_accum, std::vector *shape) { MS_EXCEPTION_IF_NULL(shape); shape->clear(); int64_t value = 1; @@ -92,7 +92,7 @@ Status AccumulateProductToShape(const std::vector& shape_accum, std::ve * shape_accum_reverse = [2 * 8 * 32, 8 * 32, 32] * shape = [2, 8, 32] */ -Status AccumulateProductReverseToShape(const std::vector& shape_accum_reverse, std::vector* shape) { +Status AccumulateProductReverseToShape(const std::vector &shape_accum_reverse, std::vector *shape) { MS_EXCEPTION_IF_NULL(shape); shape->clear(); int64_t value = 1; @@ -122,8 +122,8 @@ Status AccumulateProductReverseToShape(const std::vector& shape_accum_r * in2 = [8, 16] * *out = [2, 4, 8, 16] */ -Status UnifyAccumulateProduct(const std::vector& in1_accum, const std::vector& in2_accum, - std::vector* out_accum) { +Status UnifyAccumulateProduct(const std::vector &in1_accum, const std::vector &in2_accum, + std::vector *out_accum) { MS_EXCEPTION_IF_NULL(out_accum); out_accum->clear(); auto in1_iter = in1_accum.begin(); @@ -159,7 +159,7 @@ Status UnifyAccumulateProduct(const std::vector& in1_accum, const std:: * in2 = [2, 16] * out = [2, 4, 4] */ -Status UnifyShape(const std::vector& in1, const std::vector& in2, std::vector* out) { +Status UnifyShape(const std::vector &in1, const std::vector &in2, std::vector *out) { MS_EXCEPTION_IF_NULL(out); std::vector in1_accum; Status status = ShapeToAccumulateProduct(in1, &in1_accum); @@ -194,9 +194,9 @@ Status UnifyShape(const std::vector& in1, const std::vector& i * expand_accum_reverse = [2 * 4 * 8, 4 * 8, 8] * out_accum_reverse = [2 * 4 * 2 * 4 * 8, 4 * 2 * 4 * 8, 2 * 4 * 8, 4 * 8, 8] */ -Status ExpandAccumulateProduct(const std::vector& in_accum_reverse, - const std::vector& expand_accum_reverse, - std::vector* out_accum_reverse) { +Status ExpandAccumulateProduct(const std::vector &in_accum_reverse, + const std::vector &expand_accum_reverse, + std::vector *out_accum_reverse) { MS_EXCEPTION_IF_NULL(out_accum_reverse); out_accum_reverse->clear(); auto in_riter = in_accum_reverse.rbegin(); @@ -236,7 +236,7 @@ Status ExpandAccumulateProduct(const std::vector& in_accum_reverse, * expand = [2, 4, 8] * out = [2, 4, 2, 4, 8] */ -Status ExpandShape(const std::vector& in, const std::vector& expand, std::vector* out) { +Status ExpandShape(const std::vector &in, const std::vector &expand, std::vector *out) { MS_EXCEPTION_IF_NULL(out); std::vector in_accum_reverse; Status status = ShapeToAccumulateProductReverse(in, &in_accum_reverse); diff --git a/mindspore/ccsrc/parallel/tensor_layout/shape_util.h b/mindspore/ccsrc/parallel/tensor_layout/shape_util.h index e83156500c..2ec21f3881 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/shape_util.h +++ b/mindspore/ccsrc/parallel/tensor_layout/shape_util.h @@ -39,7 +39,7 @@ namespace parallel { * shape_accum = [2, 2 * 8, 2 * 8 * 32] * */ -Status ShapeToAccumulateProduct(const std::vector& shape, std::vector* shape_accum); +Status ShapeToAccumulateProduct(const std::vector &shape, std::vector *shape_accum); /* * compute the accumulating product of all the values in shape from right to left, @@ -53,7 +53,7 @@ Status ShapeToAccumulateProduct(const std::vector& shape, std::vector& shape, std::vector* shape_accum); +Status ShapeToAccumulateProductReverse(const std::vector &shape, std::vector *shape_accum); /* * compute the original shape from the accumulating product shape_accum, @@ -68,7 +68,7 @@ Status ShapeToAccumulateProductReverse(const std::vector& shape, std::v * shape = [2, 8, 32] * */ -Status AccumulateProductToShape(const std::vector& shape_accum, std::vector* shape); +Status AccumulateProductToShape(const std::vector &shape_accum, std::vector *shape); /* * compute the original shape from the accumulating product shape_accum, @@ -83,7 +83,7 @@ Status AccumulateProductToShape(const std::vector& shape_accum, std::ve * shape = [2, 8, 32] * */ -Status AccumulateProductReverseToShape(const std::vector& shape_accum_reverse, std::vector* shape); +Status AccumulateProductReverseToShape(const std::vector &shape_accum_reverse, std::vector *shape); /* * given two accumulate product in1_accum and in2_accum, compute the union of in1_accum and in2_accum, @@ -101,8 +101,8 @@ Status AccumulateProductReverseToShape(const std::vector& shape_accum_r * in2_accum = [8, 16] * out_accum = [2, 4, 8, 16] */ -Status UnifyAccumulateProduct(const std::vector& in1_accum, const std::vector& in2_accum, - std::vector* out_accum); +Status UnifyAccumulateProduct(const std::vector &in1_accum, const std::vector &in2_accum, + std::vector *out_accum); /* * given two shape in1 = [din1_n-1, din1_n-2, ..., din1_0] and in2 = [din2_m-1, din2_m-2, ..., din2_m] @@ -117,7 +117,7 @@ Status UnifyAccumulateProduct(const std::vector& in1_accum, const std:: * in2 = [2, 16] * out = [2, 4, 4] */ -Status UnifyShape(const std::vector& in1, const std::vector& in2, std::vector* out); +Status UnifyShape(const std::vector &in1, const std::vector &in2, std::vector *out); /* * given two accumulate product in reverse order of in and expand, @@ -141,9 +141,9 @@ Status UnifyShape(const std::vector& in1, const std::vector& i * expand_accum_reverse = [2 * 4 * 8, 4 * 8, 8] * out_accum_reverse = [2 * 4 * 2 * 4 * 8, 4 * 2 * 4 * 8, 2 * 4 * 8, 4 * 8, 8] */ -Status ExpandAccumulateProduct(const std::vector& in_accum_reverse, - const std::vector& expand_accum_reverse, - std::vector* out_accum_reverse); +Status ExpandAccumulateProduct(const std::vector &in_accum_reverse, + const std::vector &expand_accum_reverse, + std::vector *out_accum_reverse); /* * given a shape in = [din_n-1, din_n-2, ..., d_0], and the expand shape expand= [dexp_m-1, dexp_m-2, ..., dexp_0], @@ -165,7 +165,7 @@ Status ExpandAccumulateProduct(const std::vector& in_accum_reverse, * expand = [2, 4, 8] * out = [2, 4, 2, 4, 8] */ -Status ExpandShape(const std::vector& in, const std::vector& expand, std::vector* out); +Status ExpandShape(const std::vector &in, const std::vector &expand, std::vector *out); } // namespace parallel } // namespace mindspore diff --git a/mindspore/ccsrc/parallel/tensor_layout/tensor_info.h b/mindspore/ccsrc/parallel/tensor_layout/tensor_info.h index 4a64ab472c..43286317c5 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/tensor_info.h +++ b/mindspore/ccsrc/parallel/tensor_layout/tensor_info.h @@ -32,9 +32,9 @@ using Shapes = std::vector; class TensorInfo { public: - TensorInfo(const TensorLayout& tensor_layout, Shape shape, Shape slice_shape) + TensorInfo(const TensorLayout &tensor_layout, Shape shape, Shape slice_shape) : tensor_layout_(tensor_layout), shape_(std::move(shape)), slice_shape_(std::move(slice_shape)) {} - explicit TensorInfo(const TensorLayout& tensor_layout) : tensor_layout_(tensor_layout) { + explicit TensorInfo(const TensorLayout &tensor_layout) : tensor_layout_(tensor_layout) { shape_ = tensor_layout.tensor_shape().array(); slice_shape_ = tensor_layout.slice_shape().array(); } @@ -44,7 +44,7 @@ class TensorInfo { TensorLayout tensor_layout() const { return tensor_layout_; } Shape slice_shape() const { return slice_shape_; } Shape shape() const { return shape_; } - void set_reduce_dim(const std::vector& dim) { reduce_dim_ = dim; } + void set_reduce_dim(const std::vector &dim) { reduce_dim_ = dim; } std::vector reduce_dim() const { return reduce_dim_; } private: diff --git a/mindspore/ccsrc/parallel/tensor_layout/tensor_layout.cc b/mindspore/ccsrc/parallel/tensor_layout/tensor_layout.cc index 5fbd04431c..f3498065f2 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/tensor_layout.cc +++ b/mindspore/ccsrc/parallel/tensor_layout/tensor_layout.cc @@ -45,8 +45,8 @@ std::string TensorLayout::OriginToString() const { return buffer.str(); } -Status TensorLayout::Init(const Arrangement& device_arrangement, const Map& tensor_map, - const Arrangement& tensor_shape) { +Status TensorLayout::Init(const Arrangement &device_arrangement, const Map &tensor_map, + const Arrangement &tensor_shape) { device_arrangement_origin_ = device_arrangement; tensor_map_origin_ = tensor_map; tensor_shape_origin_ = tensor_shape; @@ -64,8 +64,8 @@ Status TensorLayout::Init(const Arrangement& device_arrangement, const Map& tens } } -Status TensorLayout::InitFromVector(const std::vector& device_arrangement, - const std::vector& tensor_map, const std::vector& tensor_shape) { +Status TensorLayout::InitFromVector(const std::vector &device_arrangement, + const std::vector &tensor_map, const std::vector &tensor_shape) { if (device_arrangement_origin_.Init(device_arrangement) != SUCCESS) { return FAILED; } @@ -124,7 +124,7 @@ void TensorLayout::RemoveElementEqualToOneInDeviceArrangement() { if (idx != -1) { tensor_map_shape[static_cast(idx)] = -1; } - for (auto& value : tensor_map_shape) { + for (auto &value : tensor_map_shape) { if (value >= dev_num_left - 1 - static_cast(i)) { value--; } @@ -153,7 +153,7 @@ int32_t TensorLayout::GetSliceNumByTensorDimensionIndex(uint32_t idx) const { return device_arrangement_.GetDimByIdx(static_cast(GetSliceDeviceDimensionByTensorDimensionIndex(idx))); } -std::shared_ptr TensorLayout::ExpandTensorShape(const Arrangement& expanded_shape) const { +std::shared_ptr TensorLayout::ExpandTensorShape(const Arrangement &expanded_shape) const { std::shared_ptr expanded_arrangement_ptr = ComputeArrangementByExpandedShape(expanded_shape); if (expanded_arrangement_ptr == nullptr) { return nullptr; @@ -174,7 +174,7 @@ std::shared_ptr TensorLayout::ExpandTensorShape(const Arrangement& * => * out_device_arrangement = [8, 2, 2] */ -std::shared_ptr TensorLayout::ComputeArrangementByExpandedShape(const Arrangement& tensor_shape) const { +std::shared_ptr TensorLayout::ComputeArrangementByExpandedShape(const Arrangement &tensor_shape) const { std::shared_ptr> expand_list_ptr = tensor_shape_.GetExpandShapeList(tensor_shape); if (expand_list_ptr == nullptr) { return nullptr; @@ -204,7 +204,7 @@ std::shared_ptr TensorLayout::ComputeArrangementByExpandedShape(con * out_tensor_map = [1, -1, 0, -1], */ std::shared_ptr TensorLayout::ExpandTensorShapeWithoutExtendDeviceArrangement( - const Arrangement& expanded_shape) const { + const Arrangement &expanded_shape) const { std::shared_ptr, Arrangement>> expand_list_pair_ptr = tensor_shape_.GetExpandShapeListPair(expanded_shape); if (expand_list_pair_ptr == nullptr) { @@ -259,7 +259,7 @@ std::shared_ptr TensorLayout::ExpandTensorShapeWithoutExtendDevice * out_tensor_map = [0, 2, 1], * out_tensor_shape = [512, 4, 256] */ -std::shared_ptr TensorLayout::ExpandDeviceArrangement(const Arrangement& expanded_arrangement) const { +std::shared_ptr TensorLayout::ExpandDeviceArrangement(const Arrangement &expanded_arrangement) const { std::shared_ptr, Arrangement>> expand_list_pair_ptr = device_arrangement_.GetExpandShapeListPair(expanded_arrangement); if (expand_list_pair_ptr == nullptr) { @@ -287,7 +287,7 @@ std::shared_ptr TensorLayout::ExpandDeviceArrangement(const Arrang return std::make_shared(tensor_layout_new); } -bool TensorLayout::TensorShapeCanBeExpanded(const Arrangement& expand_shape) const { +bool TensorLayout::TensorShapeCanBeExpanded(const Arrangement &expand_shape) const { std::vector in_expand_shape_shape; Status status = ExpandShape(tensor_shape_.array(), expand_shape.array(), &in_expand_shape_shape); if (status != Status::SUCCESS) { @@ -296,7 +296,7 @@ bool TensorLayout::TensorShapeCanBeExpanded(const Arrangement& expand_shape) con return (in_expand_shape_shape == tensor_shape_.array()); } -std::shared_ptr TensorLayout::ComputeExpandedTensorShape(const Arrangement& expand_shape) const { +std::shared_ptr TensorLayout::ComputeExpandedTensorShape(const Arrangement &expand_shape) const { std::vector in_expand_shape_shape; Status status = ExpandShape(tensor_shape_.array(), expand_shape.array(), &in_expand_shape_shape); if (status != Status::SUCCESS) { @@ -345,7 +345,7 @@ Status TensorLayout::UpdateTensorMap(uint32_t index, int32_t value) { return Status::SUCCESS; } -bool TensorLayout::operator==(const TensorLayout& t1) const { +bool TensorLayout::operator==(const TensorLayout &t1) const { return (IsSameDeviceArrangement(t1) && IsSameTensorMap(t1) && IsSameTensorShape(t1)); } diff --git a/mindspore/ccsrc/parallel/tensor_layout/tensor_layout.h b/mindspore/ccsrc/parallel/tensor_layout/tensor_layout.h index e6ddc2a708..f51ed4e3e0 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/tensor_layout.h +++ b/mindspore/ccsrc/parallel/tensor_layout/tensor_layout.h @@ -37,9 +37,9 @@ class TensorLayout { std::string ToString() const; std::string StandardToString() const; std::string OriginToString() const; - Status Init(const Arrangement& device_arrangement, const Map& tensor_map, const Arrangement& tensor_shape); - Status InitFromVector(const std::vector& device_arrangement, const std::vector& tensor_map, - const std::vector& tensor_shape); + Status Init(const Arrangement &device_arrangement, const Map &tensor_map, const Arrangement &tensor_shape); + Status InitFromVector(const std::vector &device_arrangement, const std::vector &tensor_map, + const std::vector &tensor_shape); Arrangement device_arrangement() const { return device_arrangement_; } @@ -49,25 +49,25 @@ class TensorLayout { Map origin_tensor_map() const { return tensor_map_origin_; } - std::shared_ptr ExpandTensorShape(const Arrangement& expanded_shape) const; + std::shared_ptr ExpandTensorShape(const Arrangement &expanded_shape) const; - std::shared_ptr ExpandDeviceArrangement(const Arrangement& expanded_arrangement) const; + std::shared_ptr ExpandDeviceArrangement(const Arrangement &expanded_arrangement) const; - bool IsSameTensorShape(const TensorLayout& tensor_layout) const { + bool IsSameTensorShape(const TensorLayout &tensor_layout) const { return (tensor_shape_ == tensor_layout.tensor_shape()); } - bool IsSameDeviceArrangement(const TensorLayout& tensor_layout) const { + bool IsSameDeviceArrangement(const TensorLayout &tensor_layout) const { return (device_arrangement_ == tensor_layout.device_arrangement()); } - bool IsSameTensorMap(const TensorLayout& tensor_layout) const { return (tensor_map_ == tensor_layout.tensor_map()); } + bool IsSameTensorMap(const TensorLayout &tensor_layout) const { return (tensor_map_ == tensor_layout.tensor_map()); } - bool operator==(const TensorLayout& t1) const; + bool operator==(const TensorLayout &t1) const; - bool TensorShapeCanBeExpanded(const Arrangement& expanded_shape) const; + bool TensorShapeCanBeExpanded(const Arrangement &expanded_shape) const; - std::shared_ptr ComputeExpandedTensorShape(const Arrangement& expand_shape) const; + std::shared_ptr ComputeExpandedTensorShape(const Arrangement &expand_shape) const; Arrangement slice_shape() const; @@ -77,8 +77,8 @@ class TensorLayout { private: std::shared_ptr ExpandTensorShapeWithoutExtendDeviceArrangement( - const Arrangement& expanded_shape) const; - std::shared_ptr ComputeArrangementByExpandedShape(const Arrangement& tensor_shape) const; + const Arrangement &expanded_shape) const; + std::shared_ptr ComputeArrangementByExpandedShape(const Arrangement &tensor_shape) const; bool IsValidTensorLayout() const; void RemoveElementEqualToOneInDeviceArrangement(); int32_t GetSliceDeviceDimensionByTensorDimensionIndex(uint32_t idx) const; diff --git a/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.cc b/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.cc index 460cd9d1bd..7824c21f3d 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.cc +++ b/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.cc @@ -24,7 +24,7 @@ namespace mindspore { namespace parallel { -Status TensorRedistribution::Init(const TensorLayout& from, const TensorLayout& to, const RankList& dev_list) { +Status TensorRedistribution::Init(const TensorLayout &from, const TensorLayout &to, const RankList &dev_list) { from_origin_ = from; to_origin_ = to; if (from_origin_.tensor_shape().size() != to_origin_.tensor_shape().size()) { @@ -87,9 +87,9 @@ RedistributionOpListPtr TensorRedistribution::InferTensorRedistributionOperatorL std::make_pair(operator_vector, output_info_vector)); } -Status TensorRedistribution::InferReshape(const TensorLayout& from_layout, const TensorLayout& to_layout, - OperatorVector* const operator_vector, - OutPutInfoVector* const output_info_vector) { +Status TensorRedistribution::InferReshape(const TensorLayout &from_layout, const TensorLayout &to_layout, + OperatorVector *const operator_vector, + OutPutInfoVector *const output_info_vector) { MS_EXCEPTION_IF_NULL(operator_vector); MS_EXCEPTION_IF_NULL(output_info_vector); ConstructOperator constructor; @@ -144,7 +144,7 @@ Status TensorRedistribution::ComputeCost() { return Status::FAILED; } // Compute redistribution communication cost and computation cost - for (auto& op_cost : operator_list_) { + for (auto &op_cost : operator_list_) { OperatorR op = op_cost.first; Shape slice_shape = op_cost.second; double prod = diff --git a/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.h b/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.h index 71d4a02701..e7800909c5 100644 --- a/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.h +++ b/mindspore/ccsrc/parallel/tensor_layout/tensor_redistribution.h @@ -46,7 +46,7 @@ class TensorRedistribution { memory_cost_(0.0), construct_op_flag_(construct_op_flag), keep_reshape_(keep_reshape) {} - Status Init(const TensorLayout& from, const TensorLayout& to, const RankList& dev_list); + Status Init(const TensorLayout &from, const TensorLayout &to, const RankList &dev_list); ~TensorRedistribution() = default; RedistributionOpListPtr InferTensorRedistributionOperatorList(bool is_cost_model = false); OperatorList operator_list() const { return operator_list_; } @@ -59,8 +59,8 @@ class TensorRedistribution { double memory_cost() const { return memory_cost_; } private: - Status InferReshape(const TensorLayout& from_layout, const TensorLayout& to_layout, - OperatorVector* const operator_vector, OutPutInfoVector* const output_info_vector); + Status InferReshape(const TensorLayout &from_layout, const TensorLayout &to_layout, + OperatorVector *const operator_vector, OutPutInfoVector *const output_info_vector); TensorLayout from_origin_; TensorLayout to_origin_; diff --git a/mindspore/ccsrc/pipeline/action.cc b/mindspore/ccsrc/pipeline/action.cc index 3e0f8804e7..e8723e66a4 100644 --- a/mindspore/ccsrc/pipeline/action.cc +++ b/mindspore/ccsrc/pipeline/action.cc @@ -41,8 +41,8 @@ using CompileGraphs = compile::CompileGraphs; using abstract::AnalysisResult; using mindspore::abstract::AnalysisContextPtr; -abstract::AnalysisResult AbstractAnalyze(const ResourcePtr& res, const FuncGraphPtr& func_graph, - const abstract::AbstractBasePtrList& args_spec, bool clear) { +abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &res, const FuncGraphPtr &func_graph, + const abstract::AbstractBasePtrList &args_spec, bool clear) { MS_LOG(DEBUG) << "AbstractAnalyze start"; auto engine = res->engine(); MS_EXCEPTION_IF_NULL(engine); @@ -50,9 +50,9 @@ abstract::AnalysisResult AbstractAnalyze(const ResourcePtr& res, const FuncGraph auto manager = res->manager(); MS_EXCEPTION_IF_NULL(manager); engine->Clear(); - for (auto& node : manager->all_nodes()) { + for (auto &node : manager->all_nodes()) { MS_EXCEPTION_IF_NULL(node); - const AbstractBasePtr& prev_inferred = node->abstract(); + const AbstractBasePtr &prev_inferred = node->abstract(); // Keep previous inferred value for ValueNode if the inferred value is not AbstractFunction. if (!node->isa() || (prev_inferred != nullptr && prev_inferred->isa())) { node->set_abstract(nullptr); @@ -65,8 +65,8 @@ abstract::AnalysisResult AbstractAnalyze(const ResourcePtr& res, const FuncGraph return ret; } -FuncGraphPtr ProgramSpecialize(const ResourcePtr& res, const FuncGraphPtr& func_graph, - const abstract::AnalysisContextPtr& context) { +FuncGraphPtr ProgramSpecialize(const ResourcePtr &res, const FuncGraphPtr &func_graph, + const abstract::AnalysisContextPtr &context) { MS_LOG(DEBUG) << "ProgramSpecialize start"; abstract::ProgramSpecializer spc(res->engine()); FuncGraphPtr result = spc.Run(func_graph, context); @@ -77,8 +77,8 @@ FuncGraphPtr ProgramSpecialize(const ResourcePtr& res, const FuncGraphPtr& func_ return result; } -FuncGraphPtr Renormalize(const ResourcePtr& res, const FuncGraphPtr& func_graph, - const abstract::AbstractBasePtrList& args_spec) { +FuncGraphPtr Renormalize(const ResourcePtr &res, const FuncGraphPtr &func_graph, + const abstract::AbstractBasePtrList &args_spec) { MS_LOG(DEBUG) << "Renormalize start"; #ifdef ENABLE_PROFILE double t1 = GetTime(); @@ -98,7 +98,7 @@ FuncGraphPtr Renormalize(const ResourcePtr& res, const FuncGraphPtr& func_graph, return ret; } -bool ParseAction(const ResourcePtr& res) { +bool ParseAction(const ResourcePtr &res) { if (!res->input()) { MS_LOG(EXCEPTION) << "Parse error"; } @@ -129,11 +129,11 @@ bool ParseAction(const ResourcePtr& res) { // This step do this optimize: graph1(x){xx(fv1),xxx(fv2)}, graph2(x){xxx(fv3),xxx(fv4)}-> // graph1(x){base_graph(x, fv1, fv2)}, graph1(x){base_graph(x, fv3, fv4)}, base_graph(x, fv...){xxx,xxx} // all obj_map's graph shared base_graph -bool CombineLikeGraphs(const ResourcePtr&) { - auto& obj_map = parse::data_converter::GetObjGraphs(); +bool CombineLikeGraphs(const ResourcePtr &) { + auto &obj_map = parse::data_converter::GetObjGraphs(); for (auto it : obj_map) { - auto& graphs = it.second; + auto &graphs = it.second; MS_LOG(DEBUG) << "Start combine like graph:" << it.first << ", size:" << graphs.size(); auto fg = graphs[0]; FuncGraphPtrList func_graphs = {fg}; @@ -147,7 +147,7 @@ bool CombineLikeGraphs(const ResourcePtr&) { continue; } auto mng = Manage(base_graph, false); - for (auto& fv : fg->paramter_obj_nodes()) { + for (auto &fv : fg->paramter_obj_nodes()) { TraceManager::DebugTrace(std::make_shared(fv->debug_info())); auto param = base_graph->add_parameter(); TraceManager::EndTrace(); @@ -156,11 +156,11 @@ bool CombineLikeGraphs(const ResourcePtr&) { } MS_LOG(DEBUG) << "Fg0 paramter_obj_nodes size :" << fg->paramter_obj_nodes().size(); - for (auto& g : graphs) { + for (auto &g : graphs) { auto fvs = g->paramter_obj_nodes(); std::vector new_node_inputs; new_node_inputs.push_back(NewValueNode(base_graph)); - for (auto& p : g->parameters()) { + for (auto &p : g->parameters()) { AnfNodePtr para_after_cast = parse::GetMixedPrecisionCastHelp(g, p); new_node_inputs.push_back(para_after_cast); } @@ -174,7 +174,7 @@ bool CombineLikeGraphs(const ResourcePtr&) { return true; } -bool SymbolResolveAction(const ResourcePtr& res) { +bool SymbolResolveAction(const ResourcePtr &res) { if (res->manager() == nullptr) { MS_LOG(EXCEPTION) << "SymbolResolve error, manager is null"; } @@ -195,7 +195,7 @@ bool SymbolResolveAction(const ResourcePtr& res) { return succ; } -bool InferenceOptPrepareAction(const ResourcePtr& res) { +bool InferenceOptPrepareAction(const ResourcePtr &res) { if (res->manager() == nullptr) { MS_LOG(EXCEPTION) << "InferenceOptPrepare error, manager is null."; } @@ -205,7 +205,7 @@ bool InferenceOptPrepareAction(const ResourcePtr& res) { return InferenceOptPreparePass(res); } -bool AbstractSpecializeAction(const ResourcePtr& res) { +bool AbstractSpecializeAction(const ResourcePtr &res) { if (res->func_graph() == nullptr) { MS_LOG(EXCEPTION) << "AbstractSpecialize error"; } @@ -215,7 +215,7 @@ bool AbstractSpecializeAction(const ResourcePtr& res) { // suppose that there is not KeywordArgument for the top graph // get the hyper parameter - for (const auto& param : func_graph->parameters()) { + for (const auto ¶m : func_graph->parameters()) { auto param_node = std::static_pointer_cast(param); if (param_node->has_default()) { AbstractBasePtr ptr = @@ -236,8 +236,8 @@ bool AbstractSpecializeAction(const ResourcePtr& res) { return true; } -bool OptimizeAction(const ResourcePtr& res, const std::vector& passes) { - for (auto& pass : passes) { +bool OptimizeAction(const ResourcePtr &res, const std::vector &passes) { + for (auto &pass : passes) { WITH(MsProfile::GetProfile()->Step(pass.first))[&pass, &res]() { MS_LOG(DEBUG) << "Pass " << pass.first << " start ..."; auto result = pass.second(res); @@ -251,11 +251,11 @@ bool OptimizeAction(const ResourcePtr& res, const std::vector& passes) return true; } -bool GeOptimizeAction(const ResourcePtr& res) { return OptimizeAction(res, kGePasses); } +bool GeOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kGePasses); } -bool VmOptimizeAction(const ResourcePtr& res) { return OptimizeAction(res, kVmPasses); } +bool VmOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kVmPasses); } -bool TaskEmitAction(const ResourcePtr& res) { +bool TaskEmitAction(const ResourcePtr &res) { if (res->func_graph() == nullptr) { MS_LOG(EXCEPTION) << "TaskEmit args error"; } @@ -271,7 +271,7 @@ bool TaskEmitAction(const ResourcePtr& res) { return true; } -bool ExecuteAction(const ResourcePtr& res) { +bool ExecuteAction(const ResourcePtr &res) { if (res->results().count(kOutput) == 0 || !res->results()[kOutput].is()) { MS_LOG(EXCEPTION) << "Execute args error"; } @@ -291,11 +291,11 @@ bool ExecuteAction(const ResourcePtr& res) { // that will result in a syncronization error due to different executing order. // Here we temporarily avoid the problem by skipping valuenode merging used by parallel related primitive, // the final solution will be proposed later as a parallel feature. -bool KeepValueNodeDuplication(const AnfNodePtr& value_node, const ResourcePtr& res) { - auto& node_users = res->manager()->node_users(); - auto& users = node_users[value_node]; +bool KeepValueNodeDuplication(const AnfNodePtr &value_node, const ResourcePtr &res) { + auto &node_users = res->manager()->node_users(); + auto &users = node_users[value_node]; auto used_by_keep_value_prim = - std::any_of(users.begin(), users.end(), [](const std::pair& user) -> bool { + std::any_of(users.begin(), users.end(), [](const std::pair &user) -> bool { MS_EXCEPTION_IF_NULL(user.first); auto cnode = user.first->cast(); if (cnode == nullptr) { @@ -312,7 +312,7 @@ bool KeepValueNodeDuplication(const AnfNodePtr& value_node, const ResourcePtr& r return used_by_keep_value_prim; } -bool RemoveValueNodeDuplicationsAction(const ResourcePtr& res) { +bool RemoveValueNodeDuplicationsAction(const ResourcePtr &res) { if (res->func_graph() == nullptr) { MS_LOG(EXCEPTION) << "Remove value node duplications error."; } @@ -322,7 +322,7 @@ bool RemoveValueNodeDuplicationsAction(const ResourcePtr& res) { auto value_nodes = manager->valuenodes()[func_graph]; HashCache hash_cache; HashValue hashes; - for (const auto& value_pair : value_nodes) { + for (const auto &value_pair : value_nodes) { if (KeepValueNodeDuplication(value_pair.first, res)) { continue; } @@ -331,7 +331,7 @@ bool RemoveValueNodeDuplicationsAction(const ResourcePtr& res) { return true; } -bool ValidateAction(const ResourcePtr& res) { return ValidatePass(res); } +bool ValidateAction(const ResourcePtr &res) { return ValidatePass(res); } static std::vector CommonPipeline() { std::vector actions; diff --git a/mindspore/ccsrc/pipeline/action.h b/mindspore/ccsrc/pipeline/action.h index 159e494a96..8a651c0038 100644 --- a/mindspore/ccsrc/pipeline/action.h +++ b/mindspore/ccsrc/pipeline/action.h @@ -30,22 +30,22 @@ extern const char kMsConvert[]; namespace pipeline { using ActionItem = std::pair>; -bool ParseAction(const ResourcePtr& res); -bool SymbolResolveAction(const ResourcePtr& res); -bool AbstractSpecializeAction(const ResourcePtr& res); -bool GeOptimizeAction(const ResourcePtr& res); -bool VmOptimizeAction(const ResourcePtr& res); -bool TaskEmitAction(const ResourcePtr& res); -bool ExecuteAction(const ResourcePtr& res); +bool ParseAction(const ResourcePtr &res); +bool SymbolResolveAction(const ResourcePtr &res); +bool AbstractSpecializeAction(const ResourcePtr &res); +bool GeOptimizeAction(const ResourcePtr &res); +bool VmOptimizeAction(const ResourcePtr &res); +bool TaskEmitAction(const ResourcePtr &res); +bool ExecuteAction(const ResourcePtr &res); std::vector GePipeline(); std::vector VmPipeline(); -abstract::AnalysisResult AbstractAnalyze(const ResourcePtr& res, const FuncGraphPtr& func_graph, - const abstract::AbstractBasePtrList& args_spec, bool clear = false); -FuncGraphPtr ProgramSpecialize(const ResourcePtr& res, const FuncGraphPtr& func_graph, - const abstract::AnalysisContextPtr& context); -FuncGraphPtr Renormalize(const ResourcePtr& res, const FuncGraphPtr& func_graph, - const abstract::AbstractBasePtrList& args_spec); +abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &res, const FuncGraphPtr &func_graph, + const abstract::AbstractBasePtrList &args_spec, bool clear = false); +FuncGraphPtr ProgramSpecialize(const ResourcePtr &res, const FuncGraphPtr &func_graph, + const abstract::AnalysisContextPtr &context); +FuncGraphPtr Renormalize(const ResourcePtr &res, const FuncGraphPtr &func_graph, + const abstract::AbstractBasePtrList &args_spec); } // namespace pipeline } // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/base.h b/mindspore/ccsrc/pipeline/base.h index 30524e84f6..8ca153f45b 100644 --- a/mindspore/ccsrc/pipeline/base.h +++ b/mindspore/ccsrc/pipeline/base.h @@ -37,7 +37,7 @@ struct ExecutorInfo { using ExecutorInfoPtr = std::shared_ptr; -inline std::string GetPhasePrefix(const std::string& phase) { +inline std::string GetPhasePrefix(const std::string &phase) { auto pos = phase.find('.'); if (pos == std::string::npos) { MS_LOG(EXCEPTION) << "Phase has no . for prefix" << phase; @@ -45,7 +45,7 @@ inline std::string GetPhasePrefix(const std::string& phase) { return phase.substr(0, pos); } -inline std::string GetFilePathName(const std::string& file_name) { +inline std::string GetFilePathName(const std::string &file_name) { std::ostringstream oss; auto ms_context = MsContext::GetInstance(); if (ms_context == nullptr) { diff --git a/mindspore/ccsrc/pipeline/init.cc b/mindspore/ccsrc/pipeline/init.cc index b709199c87..86e6d436b7 100644 --- a/mindspore/ccsrc/pipeline/init.cc +++ b/mindspore/ccsrc/pipeline/init.cc @@ -53,10 +53,10 @@ PYBIND11_MODULE(_c_expression, m) { (void)py::class_>(*m, "MetaFuncGraph_") .def_readonly(mindspore::PYTHON_METAFUNCGRAPH_FLAG, &mindspore::MetaFuncGraph::parse_info_) - .def(py::init()); + .def(py::init()); auto fns = mindspore::PybindDefineRegister::AllFuncs(); - for (auto& item : fns) { + for (auto &item : fns) { item.second(&m); } @@ -288,7 +288,7 @@ PYBIND11_MODULE(_c_expression, m) { }}); (void)py::class_>(m, "EventWriter_") - .def(py::init()) + .def(py::init()) .def("GetFileName", &EventWriter::GetFileName, "Get the file name.") .def("Open", &EventWriter::Open, "Open the write file.") .def("Write", &EventWriter::Write, "Write the serialize event.") diff --git a/mindspore/ccsrc/pipeline/parse/data_converter.cc b/mindspore/ccsrc/pipeline/parse/data_converter.cc index d25a202afc..861fc0eda8 100644 --- a/mindspore/ccsrc/pipeline/parse/data_converter.cc +++ b/mindspore/ccsrc/pipeline/parse/data_converter.cc @@ -38,7 +38,7 @@ using Tensor = mindspore::tensor::Tensor; using TensorPtr = mindspore::tensor::TensorPtr; namespace { -bool ConvertTuple(const py::object& obj, ValuePtr* const data, bool use_signature) { +bool ConvertTuple(const py::object &obj, ValuePtr *const data, bool use_signature) { MS_LOG(DEBUG) << "Converting python tuple"; py::tuple tuple = obj.cast(); std::vector value_list; @@ -55,7 +55,7 @@ bool ConvertTuple(const py::object& obj, ValuePtr* const data, bool use_signatur return true; } -bool ConvertList(const py::object& obj, ValuePtr* const data, bool use_signature) { +bool ConvertList(const py::object &obj, ValuePtr *const data, bool use_signature) { MS_LOG(DEBUG) << "Converting python list"; py::list list = obj.cast(); @@ -72,7 +72,7 @@ bool ConvertList(const py::object& obj, ValuePtr* const data, bool use_signature return true; } -bool ConvertCellList(const py::object& obj, ValuePtr* const data, bool use_signature) { +bool ConvertCellList(const py::object &obj, ValuePtr *const data, bool use_signature) { MS_LOG(DEBUG) << "Converting cell list"; py::sequence list = obj; std::vector value_list; @@ -88,7 +88,7 @@ bool ConvertCellList(const py::object& obj, ValuePtr* const data, bool use_signa return true; } -bool ConvertDict(const py::object& obj, ValuePtr* data, bool use_signature) { +bool ConvertDict(const py::object &obj, ValuePtr *data, bool use_signature) { MS_LOG(DEBUG) << "Converting python dict"; py::dict dict_values = obj.cast(); @@ -109,14 +109,14 @@ bool ConvertDict(const py::object& obj, ValuePtr* data, bool use_signature) { return true; } -void ConvertNameSpace(const py::object& obj, ValuePtr* const data) { +void ConvertNameSpace(const py::object &obj, ValuePtr *const data) { MS_LOG(DEBUG) << "Converting python module"; py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); py::object module_namespace = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MODULE_NAMESPACE, obj); *data = std::make_shared(RESOLVE_NAMESPACE_NAME_MODULE, py::cast(module_namespace)); } -void ConvertDataClass(py::object obj, ValuePtr* const data) { +void ConvertDataClass(py::object obj, ValuePtr *const data) { MS_LOG(DEBUG) << "Converting dataclass"; // Maybe the obj is dataclass define auto desc = py::cast(python_adapter::CallPyObjMethod(obj, PYTHON_GET_OBJ_DESC, obj)); @@ -124,7 +124,7 @@ void ConvertDataClass(py::object obj, ValuePtr* const data) { *data = std::make_shared(obj, std::string(desc.begin() + 1, desc.end() - 1)); } -bool ConvertPrimitive(py::object obj, ValuePtr* const data, bool use_signature = false) { +bool ConvertPrimitive(py::object obj, ValuePtr *const data, bool use_signature = false) { MS_LOG(DEBUG) << "Converting primitive object"; // need check the primitive is class type or instance @@ -155,7 +155,7 @@ bool ConvertPrimitive(py::object obj, ValuePtr* const data, bool use_signature = return true; } -bool ConvertMetaFuncGraph(const py::object& obj, ValuePtr* const data, bool use_signature = false) { +bool ConvertMetaFuncGraph(const py::object &obj, ValuePtr *const data, bool use_signature = false) { MS_LOG(DEBUG) << "Converting MetaFuncGraph object"; auto meta = obj.cast(); if (meta == nullptr) { @@ -170,7 +170,7 @@ bool ConvertMetaFuncGraph(const py::object& obj, ValuePtr* const data, bool use_ return true; } -bool ConvertDataType(const py::object& obj, ValuePtr* const data) { +bool ConvertDataType(const py::object &obj, ValuePtr *const data) { MS_LOG(DEBUG) << "Converting type object"; auto typeptr = obj.cast(); if (typeptr == nullptr) { @@ -181,7 +181,7 @@ bool ConvertDataType(const py::object& obj, ValuePtr* const data) { return true; } -bool ConvertTensor(const py::object& obj, ValuePtr* const data) { +bool ConvertTensor(const py::object &obj, ValuePtr *const data) { MS_LOG(DEBUG) << "Converting tensor object"; auto m_tensor = obj.cast(); @@ -193,7 +193,7 @@ bool ConvertTensor(const py::object& obj, ValuePtr* const data) { return true; } -bool ConvertOtherObj(py::object obj, ValuePtr* const data) { +bool ConvertOtherObj(py::object obj, ValuePtr *const data) { auto obj_type = data_converter::GetObjType(obj); MS_LOG(DEBUG) << "Converting the object(" << ((std::string)py::str(obj)) << ") detail type: " << obj_type << " "; if (obj_type == RESOLVE_TYPE_CLASS_TYPE) { @@ -244,7 +244,7 @@ bool ConvertOtherObj(py::object obj, ValuePtr* const data) { } } // namespace -bool ConvertData(const py::object& obj, ValuePtr* const data, bool use_signature) { +bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature) { // check parameter valid if (data == nullptr) { MS_LOG(ERROR) << "Data is null pointer"; @@ -295,7 +295,7 @@ bool ConvertData(const py::object& obj, ValuePtr* const data, bool use_signature } // convert data to graph -FuncGraphPtr ConvertToFuncGraph(const py::object& obj, const std::string& python_mod_get_parse_method) { +FuncGraphPtr ConvertToFuncGraph(const py::object &obj, const std::string &python_mod_get_parse_method) { std::vector results = data_converter::GetObjKey(obj); std::string obj_id = results[0] + python_mod_get_parse_method; std::string obj_key = results[1]; @@ -331,25 +331,25 @@ static std::unordered_map object_map_ = std::unordered_map> object_graphs_map_ = std::unordered_map>(); -void SetObjGraphValue(const std::string& obj_key, const FuncGraphPtr& data) { +void SetObjGraphValue(const std::string &obj_key, const FuncGraphPtr &data) { object_graphs_map_[obj_key].push_back(data); MS_LOG(DEBUG) << "Set func graph size:" << object_graphs_map_.size(); } -const std::unordered_map>& GetObjGraphs() { +const std::unordered_map> &GetObjGraphs() { MS_LOG(DEBUG) << "Obj size:" << object_graphs_map_.size(); return object_graphs_map_; } -void CacheObjectValue(const std::string& obj_key, const Any& data) { object_map_[obj_key] = data; } -bool GetObjectValue(const std::string& obj_key, Any* const data) { +void CacheObjectValue(const std::string &obj_key, const Any &data) { object_map_[obj_key] = data; } +bool GetObjectValue(const std::string &obj_key, Any *const data) { if (object_map_.count(obj_key)) { *data = object_map_[obj_key]; return true; } return false; } -std::vector GetObjKey(const py::object& obj) { +std::vector GetObjKey(const py::object &obj) { py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); py::tuple obj_tuple = python_adapter::CallPyModFn(mod, PYTHON_MOD_RESOLVE_GET_OBJ_KEY, obj); if (obj_tuple.size() != 2) { @@ -359,7 +359,7 @@ std::vector GetObjKey(const py::object& obj) { } // get obj detail type -ResolveTypeDef GetObjType(const py::object& obj) { +ResolveTypeDef GetObjType(const py::object &obj) { py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); auto obj_type = ResolveTypeDef(python_adapter::CallPyModFn(mod, PYTHON_MOD_RESOLVE_GET_OBJ_TYPE, obj).cast()); @@ -367,7 +367,7 @@ ResolveTypeDef GetObjType(const py::object& obj) { } // get class instance detail type -ClassInstanceTypeDef GetClassInstanceType(const py::object& obj) { +ClassInstanceTypeDef GetClassInstanceType(const py::object &obj) { py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); auto class_type = ClassInstanceTypeDef(python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_CLASS_INSTANCE_TYPE, obj).cast()); @@ -375,14 +375,14 @@ ClassInstanceTypeDef GetClassInstanceType(const py::object& obj) { } // check the object is Cell Instance -bool IsCellInstance(const py::object& obj) { +bool IsCellInstance(const py::object &obj) { auto class_type = GetClassInstanceType(obj); bool isCell = (class_type == CLASS_INSTANCE_TYPE_CELL); return isCell; } // create the python class instance -py::object CreatePythonObject(const py::object& type, const py::tuple& params) { +py::object CreatePythonObject(const py::object &type, const py::tuple ¶ms) { py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); py::object obj; if (params.size() == 0) { @@ -395,7 +395,7 @@ py::object CreatePythonObject(const py::object& type, const py::tuple& params) { // Generate an appropriate name and set to graph debuginfo // character <> can not used in the dot file, so change to another symbol -void MakeProperNameToFuncGraph(const FuncGraphPtr& func_graph, std::string name) { +void MakeProperNameToFuncGraph(const FuncGraphPtr &func_graph, std::string name) { MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph->debug_info()); // set detail name info of function @@ -412,7 +412,7 @@ void MakeProperNameToFuncGraph(const FuncGraphPtr& func_graph, std::string name) func_graph->debug_info()->set_full_name(oss.str()); } -ValuePtr PyDataToValue(const py::object& obj) { +ValuePtr PyDataToValue(const py::object &obj) { py::object to_convert = obj; if (py::hasattr(obj, "__parameter__")) { to_convert = py::cast(python_adapter::GetPyObjAttr(obj, "default_input")); @@ -431,7 +431,7 @@ void ClearObjectCache() { static std::unordered_map g_dataClassToClass = {}; // parse dataclass to mindspore Class type -ClassPtr ParseDataClass(const py::object& cls_obj) { +ClassPtr ParseDataClass(const py::object &cls_obj) { std::string cls_name = py::cast(python_adapter::GetPyObjAttr(cls_obj, "__name__")); std::string cls_module = py::cast(python_adapter::GetPyObjAttr(cls_obj, "__module__")); std::string cls = cls_module + "." + cls_name; @@ -443,7 +443,7 @@ ClassPtr ParseDataClass(const py::object& cls_obj) { py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); ClassAttrVector attributes; py::dict names = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_DATACLASS_ATTRS, cls_obj); - for (auto& item : names) { + for (auto &item : names) { TypePtr type_value = item.second.cast(); MS_EXCEPTION_IF_NULL(type_value); MS_LOG(DEBUG) << "(Name: " << py::cast(item.first) << ", type: " << type_value->ToString() << ")"; @@ -452,7 +452,7 @@ ClassPtr ParseDataClass(const py::object& cls_obj) { std::unordered_map methods_map; py::dict methods = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_DATACLASS_METHODS, cls_obj); - for (auto& item : methods) { + for (auto &item : methods) { std::string fun_name = item.first.cast(); py::object obj = py::cast(item.second); std::shared_ptr method_obj = std::make_shared(obj, fun_name); diff --git a/mindspore/ccsrc/pipeline/parse/data_converter.h b/mindspore/ccsrc/pipeline/parse/data_converter.h index 658360bcee..a8918fa60c 100644 --- a/mindspore/ccsrc/pipeline/parse/data_converter.h +++ b/mindspore/ccsrc/pipeline/parse/data_converter.h @@ -32,25 +32,25 @@ namespace mindspore { namespace parse { // data convert for parse namespace data_converter { -void CacheObjectValue(const std::string& obj_key, const Any& data); -bool GetObjectValue(const std::string& obj_key, Any* const data); +void CacheObjectValue(const std::string &obj_key, const Any &data); +bool GetObjectValue(const std::string &obj_key, Any *const data); -void SetObjGraphValue(const std::string& obj_key, const FuncGraphPtr& data); +void SetObjGraphValue(const std::string &obj_key, const FuncGraphPtr &data); -const std::unordered_map>& GetObjGraphs(); +const std::unordered_map> &GetObjGraphs(); -std::vector GetObjKey(const py::object& obj); -ResolveTypeDef GetObjType(const py::object& obj); -ClassInstanceTypeDef GetClassInstanceType(const py::object& obj); +std::vector GetObjKey(const py::object &obj); +ResolveTypeDef GetObjType(const py::object &obj); +ClassInstanceTypeDef GetClassInstanceType(const py::object &obj); -bool IsCellInstance(const py::object& obj); -py::object CreatePythonObject(const py::object& type, const py::tuple& params); -void MakeProperNameToFuncGraph(const FuncGraphPtr& func_graph, std::string name); -ValuePtr PyDataToValue(const py::object& obj); +bool IsCellInstance(const py::object &obj); +py::object CreatePythonObject(const py::object &type, const py::tuple ¶ms); +void MakeProperNameToFuncGraph(const FuncGraphPtr &func_graph, std::string name); +ValuePtr PyDataToValue(const py::object &obj); void ClearObjectCache(); } // namespace data_converter -ClassPtr ParseDataClass(const py::object& cls_obj); +ClassPtr ParseDataClass(const py::object &cls_obj); void CleanDataClassToClassMap(); diff --git a/mindspore/ccsrc/pipeline/parse/function_block.cc b/mindspore/ccsrc/pipeline/parse/function_block.cc index 423e76c1d8..156f727b9e 100644 --- a/mindspore/ccsrc/pipeline/parse/function_block.cc +++ b/mindspore/ccsrc/pipeline/parse/function_block.cc @@ -28,21 +28,21 @@ namespace mindspore { namespace parse { -FunctionBlock::FunctionBlock(const Parser& parser) : parser_(parser) { +FunctionBlock::FunctionBlock(const Parser &parser) : parser_(parser) { func_graph_ = std::make_shared(); matured_ = false; } -void FunctionBlock::AddPrevBlock(const FunctionBlockPtr& block) { prev_blocks_.push_back(block.get()); } +void FunctionBlock::AddPrevBlock(const FunctionBlockPtr &block) { prev_blocks_.push_back(block.get()); } // write variable records the variable name to corresponding node -void FunctionBlock::WriteVariable(const std::string& var_name, const AnfNodePtr& node) { +void FunctionBlock::WriteVariable(const std::string &var_name, const AnfNodePtr &node) { MS_LOG(DEBUG) << "" << func_graph_->ToString() << " write var " << var_name << " with node " << node->DebugString(); vars_[var_name] = node; } // read variable from predecessors -AnfNodePtr FunctionBlock::ReadVariable(const std::string& var) { +AnfNodePtr FunctionBlock::ReadVariable(const std::string &var) { // get var node if it is found if (vars_.count(var)) { AnfNodePtr node = vars_[var]; @@ -82,7 +82,7 @@ AnfNodePtr FunctionBlock::ReadVariable(const std::string& var) { } // Resolve Ast operator node -AnfNodePtr FunctionBlock::MakeResolveAstOp(const py::object& op) { +AnfNodePtr FunctionBlock::MakeResolveAstOp(const py::object &op) { auto ast = parser_.ast(); MS_EXCEPTION_IF_NULL(ast); TraceGuard trace_guard(parser_.GetLocation(op)); @@ -105,7 +105,7 @@ AnfNodePtr FunctionBlock::MakeResolveClassMember(std::string attr) { } // Make a resolve node for symbol string -AnfNodePtr FunctionBlock::MakeResolveSymbol(const std::string& value) { +AnfNodePtr FunctionBlock::MakeResolveSymbol(const std::string &value) { if (value.compare(0, strlen("self."), "self.") == 0) { auto start = value.find_first_of('.') + 1; if (start >= value.size()) { @@ -122,14 +122,14 @@ AnfNodePtr FunctionBlock::MakeResolveSymbol(const std::string& value) { return MakeResolve(name_space, symbol); } -AnfNodePtr FunctionBlock::MakeResolveOperation(const std::string& value) { +AnfNodePtr FunctionBlock::MakeResolveOperation(const std::string &value) { py::tuple namespace_var = parser_.ast()->CallParserObjMethod(PYTHON_PARSE_GET_OPERATION_NAMESPACE_SYMBOL, value); NameSpacePtr name_space = std::make_shared(RESOLVE_NAMESPACE_NAME_COMMON_OPS, namespace_var[0]); SymbolPtr symbol = std::make_shared(namespace_var[1].cast()); return MakeResolve(name_space, symbol); } -AnfNodePtr FunctionBlock::MakeResolve(const NameSpacePtr& name_space, const SymbolPtr& resolve_symbol) { +AnfNodePtr FunctionBlock::MakeResolve(const NameSpacePtr &name_space, const SymbolPtr &resolve_symbol) { MS_LOG(DEBUG) << "MakeResolve for " << ((std::string)py::str(name_space->obj())) << " , " << ((std::string)resolve_symbol->symbol()); ValueNodePtr module_node = NewValueNode(name_space); @@ -139,10 +139,10 @@ AnfNodePtr FunctionBlock::MakeResolve(const NameSpacePtr& name_space, const Symb } // add input for the block's phi parameter -void FunctionBlock::SetPhiArgument(const ParameterPtr& phi) { +void FunctionBlock::SetPhiArgument(const ParameterPtr &phi) { std::string var = phi_nodes_[phi]; MS_LOG(DEBUG) << "graph " << func_graph_->ToString() << " set phi " << phi->ToString() << " for var " << var; - for (auto& pred : prev_blocks_) { + for (auto &pred : prev_blocks_) { MS_EXCEPTION_IF_NULL(pred); MS_LOG(DEBUG) << "graph " << func_graph_->ToString() << " pred_blocks_ " << pred->func_graph_->ToString(); AnfNodePtr arg_node = pred->ReadVariable(var); @@ -161,9 +161,9 @@ void FunctionBlock::SetPhiArgument(const ParameterPtr& phi) { } } -AnfNodePtr FunctionBlock::SearchReplaceNode(const std::string& var, const ParameterPtr& phi) { +AnfNodePtr FunctionBlock::SearchReplaceNode(const std::string &var, const ParameterPtr &phi) { AnfNodePtr arg_node = nullptr; - for (auto& prev : prev_blocks_) { + for (auto &prev : prev_blocks_) { MS_EXCEPTION_IF_NULL(prev); AnfNodePtr temp_node = prev->ReadVariable(var); MS_LOG(DEBUG) << "graph " << prev->func_graph_->ToString() << " phi " << phi->ToString() << " for var " << var @@ -204,7 +204,7 @@ AnfNodePtr FunctionBlock::SearchReplaceNode(const std::string& var, const Parame // 2. it's costly to iterate the graph to replace the phi for each phi. // Args : // phi : This parameter node is functioning as a phi node. -void FunctionBlock::CollectRemovablePhi(const ParameterPtr& phi) { +void FunctionBlock::CollectRemovablePhi(const ParameterPtr &phi) { MS_EXCEPTION_IF_NULL(phi); std::string var = phi_nodes_[phi]; MS_LOG(DEBUG) << "check phi " << phi->ToString() << " for " << var << " in graph " << func_graph_->ToString(); @@ -221,15 +221,15 @@ void FunctionBlock::CollectRemovablePhi(const ParameterPtr& phi) { removable_phis_[phi] = arg_node; // The following equal to statement "The φ-function defining v1, which now reads φ(v2, v1), is optimized // recursively". check if phi1 is assigned with this phi before, then phi1 can be replaced with arg_node. - for (auto& prev : prev_blocks_) { + for (auto &prev : prev_blocks_) { MS_EXCEPTION_IF_NULL(prev); if (!prev->matured_) { continue; } - for (auto& phi_iter : prev->removable_phis_) { + for (auto &phi_iter : prev->removable_phis_) { MS_EXCEPTION_IF_NULL(phi_iter.second); if (phi_iter.second->isa()) { - const auto& param = phi_iter.second->cast(); + const auto ¶m = phi_iter.second->cast(); if (param == phi) { MS_LOG(DEBUG) << "graph " << prev->func_graph_->ToString() << " var " << phi_iter.first->DebugString() << " can be replaced from " << param->DebugString() << " with " << arg_node->DebugString(); @@ -243,8 +243,8 @@ void FunctionBlock::CollectRemovablePhi(const ParameterPtr& phi) { // A block should be marked matured if its predecessor blocks have been processed void FunctionBlock::Mature() { - const auto& graphParamVec = func_graph_->parameters(); - for (auto& paramItr : graphParamVec) { + const auto &graphParamVec = func_graph_->parameters(); + for (auto ¶mItr : graphParamVec) { MS_EXCEPTION_IF_NULL(paramItr); ParameterPtr param = paramItr->cast(); if (phi_nodes_.find(param) != phi_nodes_.cend()) { @@ -255,7 +255,7 @@ void FunctionBlock::Mature() { } // Force the conditIon node to bool using bool operation -CNodePtr FunctionBlock::ForceToBoolNode(const AnfNodePtr& cond) { +CNodePtr FunctionBlock::ForceToBoolNode(const AnfNodePtr &cond) { TraceManager::DebugTrace(std::make_shared(cond->debug_info())); CNodePtr op_apply_node = func_graph()->NewCNode({MakeResolveOperation(NAMED_PRIMITIVE_BOOL), cond}); TraceManager::EndTrace(); @@ -263,7 +263,7 @@ CNodePtr FunctionBlock::ForceToBoolNode(const AnfNodePtr& cond) { } // Perform a jump from this block to target block -void FunctionBlock::Jump(const FunctionBlockPtr& target_block, AnfNodePtr node) { +void FunctionBlock::Jump(const FunctionBlockPtr &target_block, AnfNodePtr node) { if (func_graph()->get_return() != nullptr) { MS_LOG(EXCEPTION) << "Failure: have return node! NodeInfo: " << trace::GetDebugInfo(func_graph()->get_return()->debug_info()); @@ -283,8 +283,8 @@ void FunctionBlock::Jump(const FunctionBlockPtr& target_block, AnfNodePtr node) // Perform a conditional jump using switch operation. // The first CNode select graph with condition, and than execute this graph -void FunctionBlock::ConditionalJump(AnfNodePtr condNode, const FunctionBlockPtr& true_block, - const FunctionBlockPtr& false_block) { +void FunctionBlock::ConditionalJump(AnfNodePtr condNode, const FunctionBlockPtr &true_block, + const FunctionBlockPtr &false_block) { if (func_graph()->get_return() != nullptr) { MS_LOG(EXCEPTION) << "Failure: have return node! NodeInfo: " << trace::GetDebugInfo(func_graph()->get_return()->debug_info()); @@ -297,15 +297,15 @@ void FunctionBlock::ConditionalJump(AnfNodePtr condNode, const FunctionBlockPtr& InsertDependItemsBeforeReturn(); } -void FunctionBlock::SetStateAssgin(const AnfNodePtr& target, const std::string& readid) { +void FunctionBlock::SetStateAssgin(const AnfNodePtr &target, const std::string &readid) { state_assign_[target] = readid; } -void FunctionBlock::AddAutoDepend(const AnfNodePtr& target) { auto_depends_.push_back(target); } +void FunctionBlock::AddAutoDepend(const AnfNodePtr &target) { auto_depends_.push_back(target); } void FunctionBlock::InsertDependItemsBeforeReturn() { if (!prev_blocks_.empty()) { - for (auto& prev_block : prev_blocks_) { + for (auto &prev_block : prev_blocks_) { MS_LOG(DEBUG) << "Has prev_block " << prev_block->func_graph()->debug_info().get(); } } @@ -324,14 +324,14 @@ void FunctionBlock::InsertDependItemsBeforeReturn() { AnfNodePtr state = nullptr; std::vector vec_states; vec_states.emplace_back(make_tuple_op); - for (auto& item : state_assign_) { + for (auto &item : state_assign_) { auto source = ReadVariable(item.second); auto refkey = func_graph()->NewCNode({get_refkey_op, item.first}); auto assign = func_graph()->NewCNode({assign_op, refkey, source}); MS_LOG(INFO) << "SetState read " << item.first->ToString() << ", " << item.second; vec_states.emplace_back(assign); } - for (auto& item : auto_depends_) { + for (auto &item : auto_depends_) { MS_LOG(DEBUG) << "auto_depends " << item->ToString(); vec_states.emplace_back(item); } diff --git a/mindspore/ccsrc/pipeline/parse/function_block.h b/mindspore/ccsrc/pipeline/parse/function_block.h index 0be6e472f8..e7842903ee 100644 --- a/mindspore/ccsrc/pipeline/parse/function_block.h +++ b/mindspore/ccsrc/pipeline/parse/function_block.h @@ -43,47 +43,47 @@ using FunctionBlockPtr = std::shared_ptr; // the original source code. class FunctionBlock : public std::enable_shared_from_this { public: - explicit FunctionBlock(const Parser& parser); + explicit FunctionBlock(const Parser &parser); virtual ~FunctionBlock() {} FuncGraphPtr func_graph() { return func_graph_; } - void WriteVariable(const std::string& var_name, const AnfNodePtr& node); - AnfNodePtr ReadVariable(const std::string& var_name); - void AddPrevBlock(const FunctionBlockPtr& block); - void SetPhiArgument(const ParameterPtr& phi); - void CollectRemovablePhi(const ParameterPtr& phi); + void WriteVariable(const std::string &var_name, const AnfNodePtr &node); + AnfNodePtr ReadVariable(const std::string &var_name); + void AddPrevBlock(const FunctionBlockPtr &block); + void SetPhiArgument(const ParameterPtr &phi); + void CollectRemovablePhi(const ParameterPtr &phi); // A block is matured if all its predecessors is generated void Mature(); - CNodePtr ForceToBoolNode(const AnfNodePtr& cond); - void Jump(const FunctionBlockPtr& block, AnfNodePtr node); - AnfNodePtr SearchReplaceNode(const std::string& var, const ParameterPtr& phi); - void ConditionalJump(AnfNodePtr condNode, const FunctionBlockPtr& trueBlock, const FunctionBlockPtr& falseBlock); + CNodePtr ForceToBoolNode(const AnfNodePtr &cond); + void Jump(const FunctionBlockPtr &block, AnfNodePtr node); + AnfNodePtr SearchReplaceNode(const std::string &var, const ParameterPtr &phi); + void ConditionalJump(AnfNodePtr condNode, const FunctionBlockPtr &trueBlock, const FunctionBlockPtr &falseBlock); // record the assign statement of self.xx weight parameter ,which will use state_setitem op - void SetStateAssgin(const AnfNodePtr& target, const std::string& readid); - void AddAutoDepend(const AnfNodePtr& target); + void SetStateAssgin(const AnfNodePtr &target, const std::string &readid); + void AddAutoDepend(const AnfNodePtr &target); void InsertDependItemsBeforeReturn(); - void AddGlobalVar(const std::string& var_name) { (void)global_vars_.insert(var_name); } - bool IsGlobalVar(const std::string& var_name) { return global_vars_.find(var_name) != global_vars_.end(); } - AnfNodePtr MakeResolveAstOp(const py::object& op); + void AddGlobalVar(const std::string &var_name) { (void)global_vars_.insert(var_name); } + bool IsGlobalVar(const std::string &var_name) { return global_vars_.find(var_name) != global_vars_.end(); } + AnfNodePtr MakeResolveAstOp(const py::object &op); AnfNodePtr MakeResolveClassMember(std::string attr); - AnfNodePtr MakeResolveSymbol(const std::string& value); - AnfNodePtr MakeResolveOperation(const std::string& value); - AnfNodePtr MakeResolve(const std::shared_ptr& name_space, const std::shared_ptr& resolve_symbol); - const std::unordered_map& removable_phis() const { return removable_phis_; } + AnfNodePtr MakeResolveSymbol(const std::string &value); + AnfNodePtr MakeResolveOperation(const std::string &value); + AnfNodePtr MakeResolve(const std::shared_ptr &name_space, const std::shared_ptr &resolve_symbol); + const std::unordered_map &removable_phis() const { return removable_phis_; } private: // block graph FuncGraphPtr func_graph_; // the block's parser - const Parser& parser_; + const Parser &parser_; // A block is matured if all its prev_blocks is processed bool matured_; // store the nest-level block // refer to comments in Parser::func_block_list_; - std::vector prev_blocks_; + std::vector prev_blocks_; // store args and variable's node std::map vars_; @@ -93,7 +93,7 @@ class FunctionBlock : public std::enable_shared_from_this { // jumps map the successor block and the function call that perform jump // refer to comments in Parser::func_block_list_ that how to break the cyclic reference - std::map jumps_; + std::map jumps_; // keeps all removable phis which will be removed in one pass. std::unordered_map removable_phis_; diff --git a/mindspore/ccsrc/pipeline/parse/parse_base.h b/mindspore/ccsrc/pipeline/parse/parse_base.h index df2d1968a5..aad8be0d6e 100644 --- a/mindspore/ccsrc/pipeline/parse/parse_base.h +++ b/mindspore/ccsrc/pipeline/parse/parse_base.h @@ -128,15 +128,15 @@ enum ClassInstanceTypeDef { }; // Convert python object to ValuePtr -bool ConvertData(const py::object& obj, ValuePtr* data, bool use_signature = false); +bool ConvertData(const py::object &obj, ValuePtr *data, bool use_signature = false); // Convert python obj to graph -FuncGraphPtr ConvertToFuncGraph(const py::object& obj, - const std::string& python_mod_get_parse_method = PYTHON_MOD_GET_PARSE_METHOD); +FuncGraphPtr ConvertToFuncGraph(const py::object &obj, + const std::string &python_mod_get_parse_method = PYTHON_MOD_GET_PARSE_METHOD); // Parse the python object to graph -FuncGraphPtr ParsePythonCode(const py::object& obj, - const std::string& python_mod_get_parse_method = PYTHON_MOD_GET_PARSE_METHOD); +FuncGraphPtr ParsePythonCode(const py::object &obj, + const std::string &python_mod_get_parse_method = PYTHON_MOD_GET_PARSE_METHOD); } // namespace parse } // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/parse/python_adapter.cc b/mindspore/ccsrc/pipeline/parse/python_adapter.cc index e2c86164d4..df2f7d0d45 100644 --- a/mindspore/ccsrc/pipeline/parse/python_adapter.cc +++ b/mindspore/ccsrc/pipeline/parse/python_adapter.cc @@ -32,7 +32,7 @@ void set_use_signature_in_resolve(bool use_signature) noexcept { use_signature_i bool UseSignatureInResolve() { return use_signature_in_resolve_; } void set_python_env_flag(bool python_env) noexcept { python_env_ = python_env; } bool IsPythonEnv() { return python_env_; } -void SetPythonPath(const std::string& path) { +void SetPythonPath(const std::string &path) { // load the python module path (void)python_adapter::set_python_scoped(); py::module sys = py::module::import("sys"); @@ -62,7 +62,7 @@ std::shared_ptr set_python_scoped() { } // return the module of python -py::module GetPyModule(const std::string& module) { +py::module GetPyModule(const std::string &module) { if (!module.empty()) { return py::module::import(module.c_str()); } else { @@ -71,7 +71,7 @@ py::module GetPyModule(const std::string& module) { } // Get the obj of attr -py::object GetPyObjAttr(const py::object& obj, const std::string& attr) { +py::object GetPyObjAttr(const py::object &obj, const std::string &attr) { if (!attr.empty() && !py::isinstance(obj)) { if (py::hasattr(obj, attr.c_str())) { return obj.attr(attr.c_str()); @@ -81,7 +81,7 @@ py::object GetPyObjAttr(const py::object& obj, const std::string& attr) { return py::none(); } -py::object GetPyFn(const std::string& module, const std::string& name) { +py::object GetPyFn(const std::string &module, const std::string &name) { (void)python_adapter::set_python_scoped(); if (!module.empty() && !name.empty()) { py::module mod = py::module::import(module.c_str()); diff --git a/mindspore/ccsrc/pipeline/parse/python_adapter.h b/mindspore/ccsrc/pipeline/parse/python_adapter.h index 12cfc27186..98adcd4f73 100644 --- a/mindspore/ccsrc/pipeline/parse/python_adapter.h +++ b/mindspore/ccsrc/pipeline/parse/python_adapter.h @@ -31,10 +31,10 @@ namespace mindspore { namespace parse { // A utility to call python interface namespace python_adapter { -py::module GetPyModule(const std::string& module); -py::object GetPyObjAttr(const py::object& obj, const std::string& attr); +py::module GetPyModule(const std::string &module); +py::object GetPyObjAttr(const py::object &obj, const std::string &attr); template -py::object CallPyObjMethod(const py::object& obj, const std::string& method, T... args) { +py::object CallPyObjMethod(const py::object &obj, const std::string &method, T... args) { if (!method.empty() && !py::isinstance(obj)) { return obj.attr(method.c_str())(args...); } @@ -43,7 +43,7 @@ py::object CallPyObjMethod(const py::object& obj, const std::string& method, T.. // call python function of module template -py::object CallPyModFn(const py::module& mod, const std::string& function, T... args) { +py::object CallPyModFn(const py::module &mod, const std::string &function, T... args) { if (!function.empty() && !py::isinstance(mod)) { return mod.attr(function.c_str())(args...); } @@ -57,12 +57,12 @@ bool UseSignatureInResolve(); std::shared_ptr set_python_scoped(); void ResetPythonScope(); bool IsPythonEnv(); -void SetPythonPath(const std::string& path); +void SetPythonPath(const std::string &path); void set_python_env_flag(bool python_env) noexcept; -py::object GetPyFn(const std::string& module, const std::string& name); +py::object GetPyFn(const std::string &module, const std::string &name); // Call the python function template -py::object CallPyFn(const std::string& module, const std::string& name, T... args) { +py::object CallPyFn(const std::string &module, const std::string &name, T... args) { (void)set_python_scoped(); if (!module.empty() && !name.empty()) { py::module mod = py::module::import(module.c_str()); diff --git a/mindspore/ccsrc/pipeline/parse/resolve.cc b/mindspore/ccsrc/pipeline/parse/resolve.cc index f90fc5039c..284512c943 100644 --- a/mindspore/ccsrc/pipeline/parse/resolve.cc +++ b/mindspore/ccsrc/pipeline/parse/resolve.cc @@ -71,7 +71,7 @@ bool SymbolResolver::Resolve() { namespace { // argument obj should be python Parameter object // it will be converted to Parameter node here -AnfNodePtr ResolveParameterObj(const FuncGraphPtr& func_graph, const py::object& obj) { +AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object &obj) { MS_EXCEPTION_IF_NULL(func_graph); // parameter object should not be none @@ -128,7 +128,7 @@ AnfNodePtr ResolveParameterObj(const FuncGraphPtr& func_graph, const py::object& } } -bool ResolveObjectToNode(const FuncGraphPtr& func_graph, const py::object& obj, AnfNodePtr* const node) { +bool ResolveObjectToNode(const FuncGraphPtr &func_graph, const py::object &obj, AnfNodePtr *const node) { AnfNodePtr output = nullptr; if (py::hasattr(obj, "__parameter__")) { auto param = ResolveParameterObj(func_graph, obj); @@ -171,12 +171,12 @@ bool ResolveObjectToNode(const FuncGraphPtr& func_graph, const py::object& obj, } // transform the ValueTuple or ValueList of graph node to make tuple of const graph node -bool TransformVectorGraphValueNode(const FuncGraphManagerPtr& manager, const AnfNodePtr& node, - const ValueNodePtr& value_node, AnfNodePtr* const transformed) { +bool TransformVectorGraphValueNode(const FuncGraphManagerPtr &manager, const AnfNodePtr &node, + const ValueNodePtr &value_node, AnfNodePtr *const transformed) { MS_EXCEPTION_IF_NULL(value_node); - const auto& value_vec = GetValue>(value_node->value()); + const auto &value_vec = GetValue>(value_node->value()); bool has_graph_in_list = false; - for (auto& elemv : value_vec) { + for (auto &elemv : value_vec) { MS_EXCEPTION_IF_NULL(elemv); if (elemv->isa()) { FuncGraphPtr new_fg = elemv->cast(); @@ -196,10 +196,10 @@ bool TransformVectorGraphValueNode(const FuncGraphManagerPtr& manager, const Anf auto make_list_op = NewValueNode(prim::kPrimMakeTuple); list_vec.emplace_back(make_list_op); (void)std::transform(std::begin(value_vec), std::end(value_vec), std::back_inserter(list_vec), - [](const ValuePtr& value) { return NewValueNode(value); }); + [](const ValuePtr &value) { return NewValueNode(value); }); FuncGraphPtr cnode_graph = nullptr; auto users = manager->node_users()[node]; - for (auto& use : users) { + for (auto &use : users) { auto use_node = use.first; MS_EXCEPTION_IF_NULL(use_node); if (use_node->isa()) { @@ -220,8 +220,8 @@ bool TransformVectorGraphValueNode(const FuncGraphManagerPtr& manager, const Anf } } // namespace -AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr& manager, const NameSpacePtr& name_space, const SymbolPtr& symbol, - const AnfNodePtr& node) { +AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space, const SymbolPtr &symbol, + const AnfNodePtr &node) { if (node->func_graph() == nullptr || manager == nullptr) { MS_LOG(EXCEPTION) << "Node " << node->DebugString() << " graph or manager is nullptr"; } @@ -253,7 +253,7 @@ AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr& manager, const NameSpacePtr& } namespace { -opt::OptPassGroupMap GetOptResolvePasses(const opt::irpass::ResolveIRPassLib& irpass) { +opt::OptPassGroupMap GetOptResolvePasses(const opt::irpass::ResolveIRPassLib &irpass) { opt::OptPassGroupMap map({ {"resolve", { @@ -266,7 +266,7 @@ opt::OptPassGroupMap GetOptResolvePasses(const opt::irpass::ResolveIRPassLib& ir } } // namespace -bool ResolveFuncGraph(const FuncGraphPtr& func_graph, const pipeline::ResourceBasePtr& res, bool use_profile) { +bool ResolveFuncGraph(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &res, bool use_profile) { if (func_graph == nullptr || res == nullptr) { MS_LOG(ERROR) << "func_graph or resource is null"; return false; @@ -282,7 +282,7 @@ bool ResolveFuncGraph(const FuncGraphPtr& func_graph, const pipeline::ResourceBa return true; } -bool ResolveAll(const FuncGraphManagerPtr& manager) { +bool ResolveAll(const FuncGraphManagerPtr &manager) { if (manager == nullptr) { MS_LOG(ERROR) << "func graph manager is null"; return false; @@ -301,7 +301,7 @@ bool ResolveAll(const FuncGraphManagerPtr& manager) { res->set_manager(manager); auto roots = manager->roots(); - for (auto& fg : roots) { + for (auto &fg : roots) { bool ret = ResolveFuncGraph(fg, res, false); if (!ret) { MS_EXCEPTION_IF_NULL(fg); diff --git a/mindspore/ccsrc/pipeline/parse/resolve.h b/mindspore/ccsrc/pipeline/parse/resolve.h index ccc22c72dc..acabfaf54b 100644 --- a/mindspore/ccsrc/pipeline/parse/resolve.h +++ b/mindspore/ccsrc/pipeline/parse/resolve.h @@ -39,7 +39,7 @@ namespace parse { // NameSpace class for resolving python code. class NameSpace : public Named { public: - NameSpace(const std::string& module, const py::object& obj) : Named(module), module_(module), obj_(obj) {} + NameSpace(const std::string &module, const py::object &obj) : Named(module), module_(module), obj_(obj) {} ~NameSpace() override = default; MS_DECLARE_PARENT(NameSpace, Named); @@ -60,8 +60,8 @@ using NameSpacePtr = std::shared_ptr; // Symbol in NameSpace or Class which shall be resolved. class Symbol : public Named { public: - explicit Symbol(const std::string& symbol) : Named(symbol), symbol_(symbol) {} - explicit Symbol(const std::string& symbol, const std::string& name) : Named(name), symbol_(symbol) {} + explicit Symbol(const std::string &symbol) : Named(symbol), symbol_(symbol) {} + explicit Symbol(const std::string &symbol, const std::string &name) : Named(name), symbol_(symbol) {} ~Symbol() override = default; MS_DECLARE_PARENT(Symbol, Named); @@ -79,7 +79,7 @@ using SymbolPtr = std::shared_ptr; // PyObjectWrapper class wrappers resolved python object for further processing. class PyObjectWrapper : public Named { public: - explicit PyObjectWrapper(const py::object& obj, const std::string name = "Python object") : Named(name), obj_(obj) {} + explicit PyObjectWrapper(const py::object &obj, const std::string name = "Python object") : Named(name), obj_(obj) {} ~PyObjectWrapper() override = default; MS_DECLARE_PARENT(PyObjectWrapper, Named); py::object obj() { return obj_; } @@ -92,7 +92,7 @@ class PyObjectWrapper : public Named { // ClassObject class wrappers dataclass class ClassObject : public PyObjectWrapper { public: - explicit ClassObject(const py::object& obj, const std::string name = "Python dataclass") + explicit ClassObject(const py::object &obj, const std::string name = "Python dataclass") : PyObjectWrapper(obj, name) {} ~ClassObject() override = default; MS_DECLARE_PARENT(ClassObject, PyObjectWrapper); @@ -102,7 +102,7 @@ class ClassObject : public PyObjectWrapper { // ClassType class wrappers class name in python class ClassType : public PyObjectWrapper { public: - explicit ClassType(const py::object& obj, const std::string name = "Python class type") + explicit ClassType(const py::object &obj, const std::string name = "Python class type") : PyObjectWrapper(obj, name) {} ~ClassType() override = default; MS_DECLARE_PARENT(ClassType, PyObjectWrapper); @@ -112,7 +112,7 @@ class ClassType : public PyObjectWrapper { // SymbolResolver class for resolving symbol extracted from AnfNode. class SymbolResolver { public: - SymbolResolver(const NameSpacePtr& name_space, const SymbolPtr& symbol, const AnfNodePtr& node) + SymbolResolver(const NameSpacePtr &name_space, const SymbolPtr &symbol, const AnfNodePtr &node) : namespace_(name_space), symbol_(symbol), resolved_node_(node) {} ~SymbolResolver() = default; @@ -124,7 +124,7 @@ class SymbolResolver { SymbolPtr symbol() { return symbol_; } - py::object& result() { return result_; } + py::object &result() { return result_; } AnfNodePtr resolved_node() { return resolved_node_; } @@ -141,15 +141,15 @@ class SymbolResolver { }; using SymbolResolverPtr = std::shared_ptr; // Resolve symbol in namespace. -AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr& manager, const NameSpacePtr& name_space, const SymbolPtr& symbol, - const AnfNodePtr& node); +AnfNodePtr ResolveSymbol(const FuncGraphManagerPtr &manager, const NameSpacePtr &name_space, const SymbolPtr &symbol, + const AnfNodePtr &node); // Resolve one graph which normally is the root graph. FuncGraph shall be managed by res->manager(). -bool ResolveFuncGraph(const FuncGraphPtr& func_graph, const pipeline::ResourceBasePtr& res, bool use_profile = true); +bool ResolveFuncGraph(const FuncGraphPtr &func_graph, const pipeline::ResourceBasePtr &res, bool use_profile = true); // Resolve all graphs in manager which is defined outside of pipeline::Resource. // Mainly used for test cases or resolve graphs which will not be managed by manager. -bool ResolveAll(const FuncGraphManagerPtr& manager); +bool ResolveAll(const FuncGraphManagerPtr &manager); } // namespace parse } // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/pass.cc b/mindspore/ccsrc/pipeline/pass.cc index b3eda4c37b..6cdf641443 100644 --- a/mindspore/ccsrc/pipeline/pass.cc +++ b/mindspore/ccsrc/pipeline/pass.cc @@ -48,7 +48,7 @@ using abstract::AnalysisResult; using mindspore::abstract::AnalysisContextPtr; using mindspore::validator::Validate; -bool SimplifyDataStructuresPass(const ResourcePtr& res) { +bool SimplifyDataStructuresPass(const ResourcePtr &res) { MS_EXCEPTION_IF_NULL(res->func_graph()); FuncGraphPtr func_graph = res->func_graph(); @@ -57,7 +57,7 @@ bool SimplifyDataStructuresPass(const ResourcePtr& res) { abstract::AbstractBasePtrList args_spec; auto parameters = func_graph->parameters(); (void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_spec), - [](const AnfNodePtr& p) -> AbstractBasePtr { return p->abstract(); }); + [](const AnfNodePtr &p) -> AbstractBasePtr { return p->abstract(); }); FuncGraphPtr new_fg = Renormalize(res, func_graph, args_spec); res->set_func_graph(new_fg); res->set_args_spec(args_spec); @@ -65,7 +65,7 @@ bool SimplifyDataStructuresPass(const ResourcePtr& res) { } namespace { -OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib& irpass) { +OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { opt::OptPassConfig a_1 = opt::OptPassConfig({ irpass.switch_simplify_, @@ -133,7 +133,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib& irpass) { return map_a; } -OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib& irpass) { +OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) { opt::OptPassConfig b_1 = opt::OptPassConfig({ irpass.zero_like_fill_zero_, irpass.item_tuple_eliminate_, @@ -157,7 +157,7 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib& irpass) { return map; } -OptPassGroupMap GetControlPhases(const opt::irpass::OptimizeIRPassLib& irpass) { +OptPassGroupMap GetControlPhases(const opt::irpass::OptimizeIRPassLib &irpass) { opt::OptPassConfig control_group = opt::OptPassConfig({irpass.convert_switch_replacement_}); OptPassGroupMap map({ {"control_group", control_group}, @@ -173,7 +173,7 @@ OptPassGroupMap GetInferenceOptPreparePhases() { return prepare_map; } -OptPassGroupMap GetPreparePhases(const opt::irpass::OptimizeIRPassLib& irpass) { +OptPassGroupMap GetPreparePhases(const opt::irpass::OptimizeIRPassLib &irpass) { opt::OptPassConfig prepare_group = opt::OptPassConfig({irpass.print_tuple_wrapper_}); OptPassGroupMap map({{"prepare_group", prepare_group}}); return map; @@ -181,7 +181,7 @@ OptPassGroupMap GetPreparePhases(const opt::irpass::OptimizeIRPassLib& irpass) { static std::unordered_map> g_pass_opts = {}; -void InitOpt(const ResourcePtr& res) { +void InitOpt(const ResourcePtr &res) { if (g_pass_opts.size() == 0) { opt::irpass::OptimizeIRPassLib irpass; g_pass_opts["opt_a"] = Optimizer::MakeOptimizer("opt_a", res, GetOptPassesA(irpass)); @@ -193,13 +193,13 @@ void InitOpt(const ResourcePtr& res) { } // namespace void ReclaimOptimizer() { - for (auto& opt : g_pass_opts) { + for (auto &opt : g_pass_opts) { opt.second = nullptr; } g_pass_opts.clear(); } -bool OptPassGroup(const ResourcePtr& res, const std::string& name) { +bool OptPassGroup(const ResourcePtr &res, const std::string &name) { if (res->func_graph() == nullptr) { MS_LOG(ERROR) << "Opt passes int error"; return false; @@ -216,12 +216,12 @@ bool OptPassGroup(const ResourcePtr& res, const std::string& name) { return true; } -bool OptPassAGroup(const ResourcePtr& res) { return OptPassGroup(res, "opt_a"); } -bool OptPassBGroup(const ResourcePtr& res) { return OptPassGroup(res, "opt_b"); } -bool ControlGroup(const ResourcePtr& res) { return OptPassGroup(res, "opt_control"); } -bool PrepareGroup(const ResourcePtr& res) { return OptPassGroup(res, "opt_prepare"); } +bool OptPassAGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_a"); } +bool OptPassBGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_b"); } +bool ControlGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_control"); } +bool PrepareGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_prepare"); } -bool AddControlDependPass(const ResourcePtr& res) { +bool AddControlDependPass(const ResourcePtr &res) { FuncGraphPtr func_graph = res->func_graph(); MS_EXCEPTION_IF_NULL(func_graph); @@ -237,7 +237,7 @@ bool AddControlDependPass(const ResourcePtr& res) { return true; } -bool CconvPass(const ResourcePtr& res) { +bool CconvPass(const ResourcePtr &res) { MS_EXCEPTION_IF_NULL(res->func_graph()); FuncGraphPtr func_graph = res->func_graph(); FuncGraphPtr new_fg = LiftingClone(func_graph); @@ -245,14 +245,14 @@ bool CconvPass(const ResourcePtr& res) { return true; } -bool ValidatePass(const ResourcePtr& res) { +bool ValidatePass(const ResourcePtr &res) { MS_EXCEPTION_IF_NULL(res->func_graph()); FuncGraphPtr func_graph = res->func_graph(); Validate(func_graph); return true; } -bool InferenceOptPreparePass(const ResourcePtr& res) { +bool InferenceOptPreparePass(const ResourcePtr &res) { FuncGraphPtr func_graph = res->func_graph(); MS_EXCEPTION_IF_NULL(func_graph); abstract::AbstractBasePtrList args_spec = res->args_spec(); diff --git a/mindspore/ccsrc/pipeline/pass.h b/mindspore/ccsrc/pipeline/pass.h index 3731d7e524..2636879d01 100644 --- a/mindspore/ccsrc/pipeline/pass.h +++ b/mindspore/ccsrc/pipeline/pass.h @@ -30,11 +30,11 @@ using PassItem = std::pair>; extern std::vector kGePasses; extern std::vector kVmPasses; -bool CconvPass(const ResourcePtr& res); -bool ValidatePass(const ResourcePtr& res); -bool ConvertPrepareAdapt(const ResourcePtr& res); -bool AddControlDependPass(const ResourcePtr& res); -bool InferenceOptPreparePass(const ResourcePtr& res); +bool CconvPass(const ResourcePtr &res); +bool ValidatePass(const ResourcePtr &res); +bool ConvertPrepareAdapt(const ResourcePtr &res); +bool AddControlDependPass(const ResourcePtr &res); +bool InferenceOptPreparePass(const ResourcePtr &res); void ReclaimOptimizer(); } // namespace pipeline } // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/pipeline.cc b/mindspore/ccsrc/pipeline/pipeline.cc index cd4fe28db9..5b5cae4044 100644 --- a/mindspore/ccsrc/pipeline/pipeline.cc +++ b/mindspore/ccsrc/pipeline/pipeline.cc @@ -67,7 +67,7 @@ std::unordered_map& defaults) { +py::tuple GenerateKey(const std::string &name, const std::unordered_map &defaults) { MS_LOG(DEBUG) << "GenerateKey args size:" << defaults.size(); abstract::AbstractBasePtrList args_spec; @@ -147,7 +147,7 @@ py::bool_ VerifyInputSignature(const py::list input_signature, const py::tuple i ExecutorPy::ExecutorPy() {} -ResourcePtr ExecutorPy::GetResource(const std::string& phase) { +ResourcePtr ExecutorPy::GetResource(const std::string &phase) { MS_LOG(DEBUG) << "Phase size:" << info_.size(); if (info_.count(phase) == 0) { return nullptr; @@ -155,21 +155,21 @@ ResourcePtr ExecutorPy::GetResource(const std::string& phase) { return info_[phase]->resource; } -FuncGraphPtr ExecutorPy::GetFuncGraph(const std::string& phase) { +FuncGraphPtr ExecutorPy::GetFuncGraph(const std::string &phase) { if (info_.count(phase) == 0) { MS_LOG(EXCEPTION) << "No phase in executor:" << GetPhasePrefix(phase); } return info_[phase]->func_graph; } -std::size_t ExecutorPy::ArgListSize(const std::string& phase) { +std::size_t ExecutorPy::ArgListSize(const std::string &phase) { if (info_.count(phase) == 0) { MS_LOG(EXCEPTION) << "No phase in executor:" << GetPhasePrefix(phase); } return info_[phase]->arg_list_size; } -compile::VmEvalFuncPtr ExecutorPy::GetVmEvalFunc(const std::string& phase) { +compile::VmEvalFuncPtr ExecutorPy::GetVmEvalFunc(const std::string &phase) { ResourcePtr res = GetResource(phase); MS_EXCEPTION_IF_NULL(res); if (res->results().find(kOutput) != res->results().end() && res->results()[kOutput].is()) { @@ -179,17 +179,17 @@ compile::VmEvalFuncPtr ExecutorPy::GetVmEvalFunc(const std::string& phase) { return nullptr; } -bool ExecutorPy::HasCompiled(const std::string& phase) const { +bool ExecutorPy::HasCompiled(const std::string &phase) const { if (info_.count(phase) == 0) { return false; } return true; } -py::bytes ExecutorPy::GetFuncGraphProto(const std::string& phase, const std::string& ir_type) { +py::bytes ExecutorPy::GetFuncGraphProto(const std::string &phase, const std::string &ir_type) { FuncGraphPtr fg_ptr = GetFuncGraph(phase); if (fg_ptr == nullptr) { - for (auto& item : info_) { + for (auto &item : info_) { MS_LOG(DEBUG) << "Phase key is: " << item.first; } MS_LOG(EXCEPTION) << "Can not find func graph " << phase; @@ -214,34 +214,34 @@ py::bytes ExecutorPy::GetFuncGraphProto(const std::string& phase, const std::str MS_LOG(EXCEPTION) << "Unknown ir type: " << ir_type; } -py::dict ExecutorPy::GetParameterLayout(const std::string& phase) { +py::dict ExecutorPy::GetParameterLayout(const std::string &phase) { MS_LOG(DEBUG) << "GetParameterLayout!"; std::string layout_graph = phase + kStepParallelGraph; auto graph = GetFuncGraph(layout_graph); return mindspore::parallel::GetParameterLayout(graph); } -py::dict ExecutorPy::GetCNodeStrategy(const std::string& phase) { +py::dict ExecutorPy::GetCNodeStrategy(const std::string &phase) { MS_LOG(DEBUG) << "GetCNodeStrategy!"; std::string layout_graph = phase + kStepParallelGraph; auto graph = GetFuncGraph(layout_graph); return mindspore::parallel::GetCNodeStrategy(graph); } -py::dict ExecutorPy::GetAllreduceFusion(const std::string& phase) { +py::dict ExecutorPy::GetAllreduceFusion(const std::string &phase) { MS_LOG(INFO) << "GetAllreduceFusion!"; auto graph = GetFuncGraph(phase); return mindspore::parallel::GetAllreduceFusion(graph); } -void ExecutorPy::DelNetRes(const std::string& id) { +void ExecutorPy::DelNetRes(const std::string &id) { #ifdef ENABLE_GE FinalizeGe(); #endif if (executor_ != nullptr) { bool flag = false; auto tmp_info = info_; - for (auto& item : tmp_info) { + for (auto &item : tmp_info) { if (item.first.find(id) != string::npos) { MS_LOG(INFO) << "Delete network res:" << item.first; (void)info_.erase(item.first); @@ -271,7 +271,7 @@ ExecutorPy::~ExecutorPy() { ConfigManager::GetInstance().ResetConfig(); } -void ExecutorPy::SaveCompiledGraph(const std::string& phase_s) { +void ExecutorPy::SaveCompiledGraph(const std::string &phase_s) { // save the graph to ExecutorPy FuncGraphPtr func_graph = info_[phase_s]->resource->func_graph(); MS_EXCEPTION_IF_NULL(func_graph); @@ -294,7 +294,7 @@ void ExecutorPy::SaveCompiledGraph(const std::string& phase_s) { MS_LOG(INFO) << "End save compiled func graph!"; } -bool ExecutorPy::ChangeExportGeirUseVmFlag(bool use_vm, const std::string& phase_s) const { +bool ExecutorPy::ChangeExportGeirUseVmFlag(bool use_vm, const std::string &phase_s) const { std::string phase_prefix = GetPhasePrefix(phase_s); if (use_vm && phase_prefix == "export") { @@ -313,7 +313,7 @@ void ExecutorPy::GetGeBackendPolicy() const { } } -bool ExecutorPy::CompileInner(const py::object& obj, const py::tuple& args, const py::object& phase, bool use_vm) { +bool ExecutorPy::CompileInner(const py::object &obj, const py::tuple &args, const py::object &phase, bool use_vm) { MS_LOG(DEBUG) << "Start ExecutorPy compile!"; if ((!py::isinstance(phase))) { MS_LOG(ERROR) << "Arg phase must be string."; @@ -376,7 +376,7 @@ bool ExecutorPy::CompileInner(const py::object& obj, const py::tuple& args, cons return true; } -void ExecutorPy::ReleaseResource(const py::object& phase) { +void ExecutorPy::ReleaseResource(const py::object &phase) { ResourcePtr res = GetResource(py::cast(phase)); if (res != nullptr) { res->Clean(); @@ -385,18 +385,18 @@ void ExecutorPy::ReleaseResource(const py::object& phase) { ReclaimOptimizer(); } -static std::string PrintArgs(const py::tuple& args) { +static std::string PrintArgs(const py::tuple &args) { py::print(args); return ""; } -bool ExecutorPy::Compile(const py::object& obj, const py::tuple& args, const py::object& phase, bool use_vm) { +bool ExecutorPy::Compile(const py::object &obj, const py::tuple &args, const py::object &phase, bool use_vm) { bool ret_value = false; try { MS_LOG(DEBUG) << PrintArgs(args); ret_value = CompileInner(obj, args, phase, use_vm); - } catch (const py::error_already_set& ex) { + } catch (const py::error_already_set &ex) { // print function call stack info before release std::ostringstream oss; trace::TraceGraphInfer(); @@ -409,13 +409,13 @@ bool ExecutorPy::Compile(const py::object& obj, const py::tuple& args, const py: // re-throw this exception to Python interpreter to handle it throw(py::error_already_set(ex)); - } catch (const py::type_error& ex) { + } catch (const py::type_error &ex) { ReleaseResource(phase); throw py::type_error(ex); - } catch (const py::value_error& ex) { + } catch (const py::value_error &ex) { ReleaseResource(phase); throw py::value_error(ex); - } catch (const std::exception& ex) { + } catch (const std::exception &ex) { ReleaseResource(phase); // re-throw this exception to Python interpreter to handle it throw(std::runtime_error(ex.what())); @@ -432,7 +432,7 @@ bool ExecutorPy::Compile(const py::object& obj, const py::tuple& args, const py: // get MindSpore Intermediate Representation File std::string GetMsIrFile(void) { std::string file; - const char* path = getenv("MS_IR_FILE"); + const char *path = getenv("MS_IR_FILE"); if (path == nullptr) { return file; } @@ -446,7 +446,7 @@ std::string GetMsIrFile(void) { return file; } -void RunPipelineAction(const ActionItem& action, pipeline::ResourcePtr resource, bool* result) { +void RunPipelineAction(const ActionItem &action, pipeline::ResourcePtr resource, bool *result) { MS_EXCEPTION_IF_NULL(resource); MS_EXCEPTION_IF_NULL(result); @@ -472,7 +472,7 @@ void RunPipelineAction(const ActionItem& action, pipeline::ResourcePtr resource, } auto manager = resource->manager(); MS_EXCEPTION_IF_NULL(manager); - for (auto& graph : graphs) { + for (auto &graph : graphs) { manager->AddFuncGraph(graph); } resource->set_func_graph(graphs[0]); @@ -491,9 +491,9 @@ void Pipeline::Run() { WITH(MsProfile::GetProfile())[&user_graph, this]() { int i = 0; - for (auto& action : actions_) { + for (auto &action : actions_) { #ifdef ENABLE_TIMELINE - DumpTime& dump_time = DumpTime::GetInstance(); + DumpTime &dump_time = DumpTime::GetInstance(); dump_time.Record(action.first, GetTime(), true); #endif bool result = true; @@ -575,7 +575,7 @@ void Pipeline::Run() { MS_LOG(INFO) << "End"; } -void ExecutorPy::ProcessVmArg(const py::tuple& args, const std::string& phase, VectorRef* arg_list) { +void ExecutorPy::ProcessVmArg(const py::tuple &args, const std::string &phase, VectorRef *arg_list) { std::size_t size = args.size(); for (std::size_t i = 0; i < size; i++) { @@ -604,7 +604,7 @@ void ExecutorPy::ProcessVmArg(const py::tuple& args, const std::string& phase, V } } -py::object ExecutorPy::Run(const py::tuple& args, const py::object& phase) { +py::object ExecutorPy::Run(const py::tuple &args, const py::object &phase) { std::size_t size = args.size(); if (!py::isinstance(phase)) { MS_LOG(EXCEPTION) << "Run failed, phase input is not a str"; @@ -649,8 +649,8 @@ py::object ExecutorPy::Run(const py::tuple& args, const py::object& phase) { return BaseRefToPyData(value); } -FuncGraphPtr ExecutorPy::BuildGraph(const py::dict& init_params, const std::string& phase, - const py::object& broadcast_params) { +FuncGraphPtr ExecutorPy::BuildGraph(const py::dict &init_params, const std::string &phase, + const py::object &broadcast_params) { #if (ENABLE_GE || ENABLE_D) return BuildDFGraph(info_, init_params, phase, broadcast_params); #else @@ -658,15 +658,15 @@ FuncGraphPtr ExecutorPy::BuildGraph(const py::dict& init_params, const std::stri #endif } -void ExecutorPy::RunInitGraph(const py::dict& init_params, const std::string& phase) { +void ExecutorPy::RunInitGraph(const py::dict &init_params, const std::string &phase) { #if ENABLE_GE RunGEInitGraph(init_params, phase); #endif } -bool InitExecDataset(const std::string& queue_name, int64_t iter_num, int64_t batch_size, - const std::vector& types, const std::vector>& shapes, - const std::vector& input_indexes, const std::string& phase) { +bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t batch_size, + const std::vector &types, const std::vector> &shapes, + const std::vector &input_indexes, const std::string &phase) { std::string name = MsContext::GetInstance()->backend_policy(); if (name == kMsConvert || name == kMsVm) { return InitExecDatasetVm(queue_name, iter_num, batch_size, types, shapes, input_indexes); @@ -682,16 +682,16 @@ bool InitExecDataset(const std::string& queue_name, int64_t iter_num, int64_t ba return false; } -bool InitExecDatasetVm(const std::string& queue_name, int64_t size, int64_t batch_size, - const std::vector& types, const std::vector>& shapes, - const std::vector& input_indexes) { +bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batch_size, + const std::vector &types, const std::vector> &shapes, + const std::vector &input_indexes) { MS_LOG(INFO) << "Start InitDataSet Entry"; std::vector int_input_indexes; (void)std::transform(input_indexes.begin(), input_indexes.end(), std::back_inserter(int_input_indexes), [](int64_t item) { return static_cast(item); }); std::vector> int_shapes; (void)std::transform(shapes.begin(), shapes.end(), std::back_inserter(int_shapes), - [](const std::vector& item) { + [](const std::vector &item) { std::vector vector_item; (void)std::transform(item.begin(), item.end(), std::back_inserter(vector_item), [](int64_t inner_item) { return static_cast(inner_item); }); @@ -774,7 +774,7 @@ void FinalizeHccl() { #endif } -void ExportGraph(const std::string& file_name, const std::string&, const std::string& phase) { +void ExportGraph(const std::string &file_name, const std::string &, const std::string &phase) { #if (ENABLE_GE || ENABLE_D) ExportDFGraph(file_name, phase); #endif diff --git a/mindspore/ccsrc/pipeline/pipeline.h b/mindspore/ccsrc/pipeline/pipeline.h index a0d7a19198..865c961ac1 100644 --- a/mindspore/ccsrc/pipeline/pipeline.h +++ b/mindspore/ccsrc/pipeline/pipeline.h @@ -43,7 +43,7 @@ namespace py = pybind11; class Pipeline { public: - Pipeline(const ResourcePtr& res, const std::vector& actions) : resource_(res), actions_(actions) {} + Pipeline(const ResourcePtr &res, const std::vector &actions) : resource_(res), actions_(actions) {} ~Pipeline() = default; @@ -69,35 +69,35 @@ class ExecutorPy : public std::enable_shared_from_this { ~ExecutorPy(); - void SaveCompiledGraph(const std::string& phase_s); - bool CompileInner(const py::object& obj, const py::tuple& args, const py::object& phase, bool use_vm); - bool Compile(const py::object& obj, const py::tuple& args, const py::object& phase, bool use_vm); + void SaveCompiledGraph(const std::string &phase_s); + bool CompileInner(const py::object &obj, const py::tuple &args, const py::object &phase, bool use_vm); + bool Compile(const py::object &obj, const py::tuple &args, const py::object &phase, bool use_vm); - void ProcessVmArg(const py::tuple& args, const std::string& phase, VectorRef* arg_list); + void ProcessVmArg(const py::tuple &args, const std::string &phase, VectorRef *arg_list); // for pynative mode when use_vm is on - py::object Run(const py::tuple& args, const py::object& phase); - ResourcePtr GetResource(const std::string& phase); - FuncGraphPtr GetFuncGraph(const std::string& phase); - py::bytes GetFuncGraphProto(const std::string& phase, const std::string& type); - std::size_t ArgListSize(const std::string& phase); - compile::VmEvalFuncPtr GetVmEvalFunc(const std::string& phase); - bool HasCompiled(const std::string& phase) const; - - FuncGraphPtr BuildGraph(const py::dict& init_params, const std::string& phase, - const py::object& broadcast_params = {}); - void RunInitGraph(const py::dict& init_params, const std::string& phase); - py::dict GetParameterLayout(const std::string& phase); - py::dict GetCNodeStrategy(const std::string& phase); - py::dict GetAllreduceFusion(const std::string& phase); - void DelNetRes(const std::string& id); - void ReleaseResource(const py::object& phase); + py::object Run(const py::tuple &args, const py::object &phase); + ResourcePtr GetResource(const std::string &phase); + FuncGraphPtr GetFuncGraph(const std::string &phase); + py::bytes GetFuncGraphProto(const std::string &phase, const std::string &type); + std::size_t ArgListSize(const std::string &phase); + compile::VmEvalFuncPtr GetVmEvalFunc(const std::string &phase); + bool HasCompiled(const std::string &phase) const; + + FuncGraphPtr BuildGraph(const py::dict &init_params, const std::string &phase, + const py::object &broadcast_params = {}); + void RunInitGraph(const py::dict &init_params, const std::string &phase); + py::dict GetParameterLayout(const std::string &phase); + py::dict GetCNodeStrategy(const std::string &phase); + py::dict GetAllreduceFusion(const std::string &phase); + void DelNetRes(const std::string &id); + void ReleaseResource(const py::object &phase); static void ClearRes(); private: ExecutorPy(); - void ConvertObjectToTensors(const py::dict& dict, std::map* tensors); - bool ChangeExportGeirUseVmFlag(bool use_vm, const std::string& phase_s) const; + void ConvertObjectToTensors(const py::dict &dict, std::map *tensors); + bool ChangeExportGeirUseVmFlag(bool use_vm, const std::string &phase_s) const; void GetGeBackendPolicy() const; std::map info_; @@ -107,10 +107,10 @@ class ExecutorPy : public std::enable_shared_from_this { using ExecutorPyPtr = std::shared_ptr; // Generate a key for mapping function graph -py::tuple GenerateKey(const std::string& name, const std::unordered_map& defaults); +py::tuple GenerateKey(const std::string &name, const std::unordered_map &defaults); py::bool_ VerifyInputSignature(const py::list input_signature, const py::tuple inputs); -bool InitDistribute(const std::map& options); +bool InitDistribute(const std::map &options); void ResetOpId(); void InitHccl(); @@ -121,17 +121,17 @@ void FinalizeGe(); void ClearResAtexit(); void ReleaseGeTsd(); -void ExportGraph(const std::string& file_name, const std::string&, const std::string& phase); +void ExportGraph(const std::string &file_name, const std::string &, const std::string &phase); // init and exec dataset sub graph -bool InitExecDataset(const std::string& queue_name, int64_t iter_num, int64_t batch_size, - const std::vector& types, const std::vector>& shapes, - const std::vector& input_indexes, const std::string& phase); +bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t batch_size, + const std::vector &types, const std::vector> &shapes, + const std::vector &input_indexes, const std::string &phase); // Build and run dataset subgraph for ms backend -bool InitExecDatasetVm(const std::string& queue_name, int64_t size, int64_t batch_size, - const std::vector& types, const std::vector>& shapes, - const std::vector& input_indexes); +bool InitExecDatasetVm(const std::string &queue_name, int64_t size, int64_t batch_size, + const std::vector &types, const std::vector> &shapes, + const std::vector &input_indexes); } // namespace pipeline } // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/pipeline_ge.cc b/mindspore/ccsrc/pipeline/pipeline_ge.cc index 6ce0ea5316..e3b10b73b0 100644 --- a/mindspore/ccsrc/pipeline/pipeline_ge.cc +++ b/mindspore/ccsrc/pipeline/pipeline_ge.cc @@ -46,7 +46,7 @@ using mindspore::transform::MeTensorPtr; using mindspore::transform::Status; using mindspore::transform::TransformUtil; -void DoExecNonInputGraph(const std::string& phase) { +void DoExecNonInputGraph(const std::string &phase) { std::vector ge_tensors; std::vector ge_outputs; transform::RunOptions run_options; @@ -68,7 +68,7 @@ void DoExecNonInputGraph(const std::string& phase) { } } -void SetGeOption(const std::map& options) { +void SetGeOption(const std::map &options) { ConfigManager::GetInstance().set_ge_initialize_options(options); } @@ -108,11 +108,11 @@ Status CreateSessionAndGraphRunner(bool is_training = true) { return Status::SUCCESS; } -bool InitExecDatasetGe(const std::string& queue_name, int64_t size, int64_t batch_size, - const std::vector& types, const std::vector>& shapes, - const std::vector& input_indexes, const std::string& phase) { +bool InitExecDatasetGe(const std::string &queue_name, int64_t size, int64_t batch_size, + const std::vector &types, const std::vector> &shapes, + const std::vector &input_indexes, const std::string &phase) { std::vector ge_types; - (void)std::transform(types.begin(), types.end(), std::back_inserter(ge_types), [](const TypePtr& i) -> int64_t { + (void)std::transform(types.begin(), types.end(), std::back_inserter(ge_types), [](const TypePtr &i) -> int64_t { return transform::TransformUtil::ConvertDataType(i->type_id()); }); @@ -145,7 +145,7 @@ bool InitExecDatasetGe(const std::string& queue_name, int64_t size, int64_t batc return true; } -void ConvertObjectToTensors(const py::dict& dict, TensorOrderMap* const tensors) { +void ConvertObjectToTensors(const py::dict &dict, TensorOrderMap *const tensors) { for (auto item : dict) { if ((!py::isinstance(item.first))) { MS_LOG(WARNING) << "Type of key of py_dict is not string, ignore it."; @@ -156,11 +156,11 @@ void ConvertObjectToTensors(const py::dict& dict, TensorOrderMap* const tensors) if (py::isinstance(item.second.attr("default_input"))) { // convert float to tensor with shape([1]) tensor = std::make_shared(kNumberTypeFloat32, std::vector({1})); - *(static_cast(tensor->data_c(true))) = py::cast(item.second.attr("default_input")); + *(static_cast(tensor->data_c(true))) = py::cast(item.second.attr("default_input")); } else if (py::isinstance(item.second.attr("default_input"))) { // convert int to tensor with shape([1]) tensor = std::make_shared(kNumberTypeInt32, std::vector({1})); - *(static_cast(tensor->data_c(true))) = py::cast(item.second.attr("default_input")); + *(static_cast(tensor->data_c(true))) = py::cast(item.second.attr("default_input")); } else if (py::hasattr(item.second.attr("default_input"), PYTHON_TENSOR_FLAG)) { // cast tensor tensor = py::cast>(item.second.attr("default_input")); @@ -173,8 +173,8 @@ void ConvertObjectToTensors(const py::dict& dict, TensorOrderMap* const tensors) } } -bool AddDFGraph(const std::map& info, const py::dict& init_params, - const std::string& phase, const py::object& broadcast_params) { +bool AddDFGraph(const std::map &info, const py::dict &init_params, + const std::string &phase, const py::object &broadcast_params) { FuncGraphPtr anf_graph = info.at(phase)->func_graph; DfGraphConvertor convertor(anf_graph); @@ -237,8 +237,8 @@ bool AddDFGraph(const std::map& info, const py::di return true; } -FuncGraphPtr BuildDFGraph(const std::map& info, const py::dict& init_params, - const std::string& phase, const py::object& broadcast_params) { +FuncGraphPtr BuildDFGraph(const std::map &info, const py::dict &init_params, + const std::string &phase, const py::object &broadcast_params) { if (info.count(phase) == 0) { MS_LOG(EXCEPTION) << "No phase in executor:" << GetPhasePrefix(phase); } @@ -268,13 +268,13 @@ FuncGraphPtr BuildDFGraph(const std::map& info, co return anf_graph; } -void RunGEInitGraph(const py::dict& init_params, const std::string& phase) { +void RunGEInitGraph(const py::dict &init_params, const std::string &phase) { MS_LOG(DEBUG) << "ExecInitGraph start."; TensorOrderMap inputs_with_name{}; ConvertObjectToTensors(init_params, &inputs_with_name); std::vector inputs; (void)std::transform(inputs_with_name.begin(), inputs_with_name.end(), std::back_inserter(inputs), - [](const std::pair& item) { return item.second; }); + [](const std::pair &item) { return item.second; }); std::vector ge_tensors = TransformUtil::ConvertInputTensors(inputs, kOpFormat_NCHW); if (ge_tensors.size() != inputs.size()) { @@ -317,7 +317,7 @@ void RunGEInitGraph(const py::dict& init_params, const std::string& phase) { } } -py::object ExtractGeneralCnodeRet(const AbstractBasePtr& cnode_data, const py::tuple& data, size_t* count) { +py::object ExtractGeneralCnodeRet(const AbstractBasePtr &cnode_data, const py::tuple &data, size_t *count) { MS_EXCEPTION_IF_NULL(cnode_data); if (*count >= data.size()) { MS_LOG(EXCEPTION) << "The number of elements in the outputs : " << data.size() @@ -350,7 +350,7 @@ py::object ExtractGeneralCnodeRet(const AbstractBasePtr& cnode_data, const py::t return std::move(tp); } -py::object StructureOutput(const AnfNodePtr& output_node, const py::tuple& data, size_t* count) { +py::object StructureOutput(const AnfNodePtr &output_node, const py::tuple &data, size_t *count) { MS_EXCEPTION_IF_NULL(output_node); if (output_node->isa()) { @@ -387,8 +387,8 @@ py::object StructureOutput(const AnfNodePtr& output_node, const py::tuple& data, return ExtractGeneralCnodeRet(output_c->abstract(), data, count); } -std::shared_ptr DoExecGraph(const FuncGraphPtr& graph, const std::vector& inputs, - const std::string& phase) { +std::shared_ptr DoExecGraph(const FuncGraphPtr &graph, const std::vector &inputs, + const std::string &phase) { std::vector ge_tensors = TransformUtil::ConvertInputTensors(inputs, kOpFormat_NCHW); if (ge_tensors.size() != inputs.size()) { MS_LOG(EXCEPTION) << "Convert me args to ge tensor error."; @@ -438,8 +438,8 @@ std::shared_ptr DoExecGraph(const FuncGraphPtr& graph, const std::ve return ret; } -void ProcessGeArg(const std::map& info, const py::tuple& args, const std::string& phase, - std::vector* inputs) { +void ProcessGeArg(const std::map &info, const py::tuple &args, const std::string &phase, + std::vector *inputs) { // check the arg and use the ExecutorPy args std::size_t size = args.size(); @@ -470,8 +470,8 @@ void ProcessGeArg(const std::map& info, const py:: } } -py::object ExecDFGraph(const std::map& info, const py::tuple& args, - const std::string& phase) { +py::object ExecDFGraph(const std::map &info, const py::tuple &args, + const std::string &phase) { std::string phase_prefix = GetPhasePrefix(phase); if (phase_prefix == "save") { @@ -514,7 +514,7 @@ py::object ExecDFGraph(const std::map& info, const MS_LOG(EXCEPTION) << "Exec graph failed"; } } -void ExportDFGraph(const std::string& file_name, const std::string& phase) { +void ExportDFGraph(const std::string &file_name, const std::string &phase) { MS_LOG(DEBUG) << "ExportGraph Begin"; transform::DfGraphWrapperPtr wrap_ptr = DfGraphManager::GetInstance().GetGraphByName(phase); if (wrap_ptr == nullptr) { diff --git a/mindspore/ccsrc/pipeline/pipeline_ge.h b/mindspore/ccsrc/pipeline/pipeline_ge.h index c3779fd982..9dc1524682 100644 --- a/mindspore/ccsrc/pipeline/pipeline_ge.h +++ b/mindspore/ccsrc/pipeline/pipeline_ge.h @@ -34,22 +34,22 @@ namespace pipeline { namespace py = pybind11; -void SetGeOption(const std::map& options); +void SetGeOption(const std::map &options); -void RunGEInitGraph(const py::dict& init_params, const std::string& phase); +void RunGEInitGraph(const py::dict &init_params, const std::string &phase); -py::object ExecDFGraph(const std::map& info, const py::tuple& args, - const std::string& phase = "train"); +py::object ExecDFGraph(const std::map &info, const py::tuple &args, + const std::string &phase = "train"); -FuncGraphPtr BuildDFGraph(const std::map& info, const py::dict& init_params, - const std::string& phase, const py::object& broadcast_params = {}); +FuncGraphPtr BuildDFGraph(const std::map &info, const py::dict &init_params, + const std::string &phase, const py::object &broadcast_params = {}); // init and exec dataset sub graph for GE backend -bool InitExecDatasetGe(const std::string& queue_name, int64_t size, int64_t batch_size, - const std::vector& types, const std::vector>& shapes, - const std::vector& input_indexes, const std::string& phase); +bool InitExecDatasetGe(const std::string &queue_name, int64_t size, int64_t batch_size, + const std::vector &types, const std::vector> &shapes, + const std::vector &input_indexes, const std::string &phase); -void ExportDFGraph(const std::string& file_name, const std::string& phase); +void ExportDFGraph(const std::string &file_name, const std::string &phase); } // namespace pipeline } // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/remove_value_node_dup.cc b/mindspore/ccsrc/pipeline/remove_value_node_dup.cc index 7937c3e55f..0b7401345a 100644 --- a/mindspore/ccsrc/pipeline/remove_value_node_dup.cc +++ b/mindspore/ccsrc/pipeline/remove_value_node_dup.cc @@ -24,9 +24,9 @@ namespace mindspore { namespace pipeline { -void TryToDoReplace(FuncGraphManager* const manager, const AnfNodePtr& node, HashCache* const hash_cache, - HashValue* const hash_value) { - const auto& to_check_value = GetValueNode(node); +void TryToDoReplace(FuncGraphManager *const manager, const AnfNodePtr &node, HashCache *const hash_cache, + HashValue *const hash_value) { + const auto &to_check_value = GetValueNode(node); MS_EXCEPTION_IF_NULL(to_check_value); // Calculate hash value. @@ -46,14 +46,14 @@ void TryToDoReplace(FuncGraphManager* const manager, const AnfNodePtr& node, Has return; } - auto& bucket = bucket_iter->second; + auto &bucket = bucket_iter->second; // Check if need to replace node with value node already met. - for (const auto& v : bucket) { + for (const auto &v : bucket) { // Already met and cached. if (v == node) { return; } - const auto& existed_value = GetValueNode(v); + const auto &existed_value = GetValueNode(v); MS_EXCEPTION_IF_NULL(existed_value); auto equal = [&]() -> bool { if (existed_value->isa() && to_check_value->isa()) { diff --git a/mindspore/ccsrc/pipeline/remove_value_node_dup.h b/mindspore/ccsrc/pipeline/remove_value_node_dup.h index 8fbb3f2755..8f670c7dcf 100644 --- a/mindspore/ccsrc/pipeline/remove_value_node_dup.h +++ b/mindspore/ccsrc/pipeline/remove_value_node_dup.h @@ -27,7 +27,7 @@ namespace pipeline { using HashCache = std::unordered_map>; using HashValue = std::unordered_map; -void TryToDoReplace(FuncGraphManager* manager, const AnfNodePtr& node, HashCache* hash_cache, HashValue* hash_value); +void TryToDoReplace(FuncGraphManager *manager, const AnfNodePtr &node, HashCache *hash_cache, HashValue *hash_value); } // namespace pipeline } // namespace mindspore diff --git a/mindspore/ccsrc/pipeline/resource.cc b/mindspore/ccsrc/pipeline/resource.cc index 18695518be..50ccef2f44 100644 --- a/mindspore/ccsrc/pipeline/resource.cc +++ b/mindspore/ccsrc/pipeline/resource.cc @@ -32,7 +32,7 @@ namespace mindspore { // namespace to support opmap definition namespace pipeline { -MethodMap& GetMethodMap() { +MethodMap &GetMethodMap() { static MethodMap method_map = {{kObjectTypeString, { {"__bool__", std::string("str_bool")} // C.str_bool @@ -178,7 +178,7 @@ MethodMap& GetMethodMap() { return method_map; } -Resource::Resource(const py::object& obj) +Resource::Resource(const py::object &obj) : engine_(std::make_shared(abstract::GetPrimEvaluatorConstructors(), manager_)), input_(obj), is_cleaned_(false) {} @@ -197,7 +197,7 @@ Resource::~Resource() { if (!is_cleaned_) { try { Clean(); - } catch (const std::exception& e) { + } catch (const std::exception &e) { MS_LOG(ERROR) << "Exception when cleaning resource. Error info " << e.what(); } catch (...) { MS_LOG(ERROR) << "Exception when cleaning resource."; @@ -205,9 +205,9 @@ Resource::~Resource() { } } -bool Resource::IsTypeInMethodMap(const TypeId& type) { +bool Resource::IsTypeInMethodMap(const TypeId &type) { TypeId type_id = NormalizeTypeId(type); - const MethodMap& method_map = GetMethodMap(); + const MethodMap &method_map = GetMethodMap(); auto iter = method_map.find(static_cast(type_id)); if (iter != method_map.end()) { return true; @@ -215,9 +215,9 @@ bool Resource::IsTypeInMethodMap(const TypeId& type) { return false; } -Any Resource::GetMethodPtr(const TypeId& type, const std::string& name) { +Any Resource::GetMethodPtr(const TypeId &type, const std::string &name) { TypeId type_id = NormalizeTypeId(type); - const MethodMap& method_map = GetMethodMap(); + const MethodMap &method_map = GetMethodMap(); auto iter = method_map.find(static_cast(type_id)); if (iter == method_map.end()) { MS_LOG(WARNING) << "Object type: " << type_id << " not in the method_map"; diff --git a/mindspore/ccsrc/pipeline/resource.h b/mindspore/ccsrc/pipeline/resource.h index 15ab70db14..0c1348fd94 100644 --- a/mindspore/ccsrc/pipeline/resource.h +++ b/mindspore/ccsrc/pipeline/resource.h @@ -46,7 +46,7 @@ class InferenceResource; using MethodMap = std::unordered_map>; -MethodMap& GetMethodMap(); +MethodMap &GetMethodMap(); class ResourceBase { public: @@ -56,20 +56,20 @@ class ResourceBase { FuncGraphManagerPtr manager() { return manager_; } // set a manager defined outside which will not manage the graphs. - void set_manager(const FuncGraphManagerPtr& manager) { manager_ = manager; } + void set_manager(const FuncGraphManagerPtr &manager) { manager_ = manager; } - std::unordered_map& results() { return results_; } + std::unordered_map &results() { return results_; } - void SetResult(const std::string& key, const Any& value) { results_[key] = value; } + void SetResult(const std::string &key, const Any &value) { results_[key] = value; } - Any GetResult(const std::string& key) { + Any GetResult(const std::string &key) { if (results_.count(key) == 0) { MS_LOG(EXCEPTION) << "this key is not in resource list:" << key; } return results_[key]; } - bool HasResult(const std::string& key) const { return results_.count(key) != 0; } + bool HasResult(const std::string &key) const { return results_.count(key) != 0; } std::unordered_map results_; @@ -81,23 +81,23 @@ using ResourceBasePtr = std::shared_ptr; class Resource : public ResourceBase { public: - explicit Resource(const py::object& obj = py::none()); + explicit Resource(const py::object &obj = py::none()); ~Resource() override; abstract::AnalysisEnginePtr engine() { return engine_; } - static bool IsTypeInMethodMap(const TypeId& type); + static bool IsTypeInMethodMap(const TypeId &type); - static Any GetMethodPtr(const TypeId& type, const std::string& name); + static Any GetMethodPtr(const TypeId &type, const std::string &name); - const py::object& input() const { return input_; } + const py::object &input() const { return input_; } FuncGraphPtr func_graph() const { return func_graph_; } - void set_func_graph(const FuncGraphPtr& func_graph) { func_graph_ = func_graph; } + void set_func_graph(const FuncGraphPtr &func_graph) { func_graph_ = func_graph; } - const abstract::AbstractBasePtrList& args_spec() const { return args_spec_; } - void set_args_spec(const abstract::AbstractBasePtrList& args_spec) { args_spec_ = args_spec; } + const abstract::AbstractBasePtrList &args_spec() const { return args_spec_; } + void set_args_spec(const abstract::AbstractBasePtrList &args_spec) { args_spec_ = args_spec; } // Reclaim resource and clear the cache. // ExecutorPy::Compile() can be called multiple times, so cache diff --git a/mindspore/ccsrc/pipeline/static_analysis/dshape.cc b/mindspore/ccsrc/pipeline/static_analysis/dshape.cc index 15aa71ba1e..183ec772ff 100644 --- a/mindspore/ccsrc/pipeline/static_analysis/dshape.cc +++ b/mindspore/ccsrc/pipeline/static_analysis/dshape.cc @@ -26,31 +26,31 @@ namespace mindspore { namespace abstract { // used for print BaseShape content -std::ostream& operator<<(std::ostream& os, const BaseShape& bs) { +std::ostream &operator<<(std::ostream &os, const BaseShape &bs) { os << bs.ToString(); return os; } -std::ostream& operator<<(std::ostream& os, const std::shared_ptr bs) { +std::ostream &operator<<(std::ostream &os, const std::shared_ptr bs) { MS_EXCEPTION_IF_NULL(bs); os << bs->ToString(); return os; } -bool BaseShape::operator==(const BaseShape& other) const { +bool BaseShape::operator==(const BaseShape &other) const { if (tid() != other.tid()) { return false; } return true; } -bool BaseShape::operator!=(const BaseShape& other) const { return !(*this == other); } +bool BaseShape::operator!=(const BaseShape &other) const { return !(*this == other); } std::string Shape::ToString() const { std::ostringstream buffer; bool f_begin = true; buffer << "("; - for (auto& x : shape_) { + for (auto &x : shape_) { if (!f_begin) { buffer << ", "; } else { @@ -72,11 +72,11 @@ std::string Shape::DumpText() const { return buffer.str(); } -bool Shape::operator==(const BaseShape& other) const { +bool Shape::operator==(const BaseShape &other) const { if (tid() != other.tid()) { return false; } - return shape_ == static_cast(other).shape_; + return shape_ == static_cast(other).shape_; } const int Shape::SHP_ANY; @@ -111,11 +111,11 @@ BaseShapePtrList SequeueShape::ElementsClone() const { } template -bool SequeueShape::SequeueEqual(const BaseShape& other) const { +bool SequeueShape::SequeueEqual(const BaseShape &other) const { if (tid() != other.tid()) { return false; } - auto other_shapes = static_cast(other).p_shapes_; + auto other_shapes = static_cast(other).p_shapes_; if (other_shapes.size() != p_shapes_.size()) { return false; } @@ -126,8 +126,8 @@ bool SequeueShape::SequeueEqual(const BaseShape& other) const { } return true; } -template bool SequeueShape::SequeueEqual(const BaseShape&) const; -template bool SequeueShape::SequeueEqual(const BaseShape&) const; +template bool SequeueShape::SequeueEqual(const BaseShape &) const; +template bool SequeueShape::SequeueEqual(const BaseShape &) const; const std::shared_ptr kNoShape = std::make_shared(); } // namespace abstract diff --git a/mindspore/ccsrc/pipeline/static_analysis/dshape.h b/mindspore/ccsrc/pipeline/static_analysis/dshape.h index 6debe061c8..3e850e309b 100644 --- a/mindspore/ccsrc/pipeline/static_analysis/dshape.h +++ b/mindspore/ccsrc/pipeline/static_analysis/dshape.h @@ -41,8 +41,8 @@ class BaseShape : public Base { ~BaseShape() override = default; MS_DECLARE_PARENT(BaseShape, Base) - virtual bool operator==(const BaseShape& other) const; - bool operator!=(const BaseShape& other) const; + virtual bool operator==(const BaseShape &other) const; + bool operator!=(const BaseShape &other) const; std::size_t hash() const override { return tid(); } // return a deep copy @@ -62,16 +62,16 @@ class Shape : public BaseShape { public: static const int SHP_ANY = -1; Shape() : shape_() {} - Shape(const std::initializer_list& list) : shape_(list) {} - explicit Shape(const std::vector& list) : shape_(list) {} + Shape(const std::initializer_list &list) : shape_(list) {} + explicit Shape(const std::vector &list) : shape_(list) {} ~Shape() override = default; MS_DECLARE_PARENT(Shape, BaseShape) std::string ToString() const override; std::string DumpText() const override; - bool operator==(const BaseShape& other) const override; + bool operator==(const BaseShape &other) const override; BaseShapePtr Clone() const override { return std::make_shared(shape_); } void Broaden() override; - std::vector& shape() { return shape_; } + std::vector &shape() { return shape_; } std::vector shape_; // use SHP_ANY to implement the any shape in python }; @@ -81,7 +81,7 @@ using ShapePtrList = std::vector; class SequeueShape : public BaseShape { public: SequeueShape() : p_shapes_() {} - explicit SequeueShape(const BaseShapePtrList& shapes) : p_shapes_(shapes) {} + explicit SequeueShape(const BaseShapePtrList &shapes) : p_shapes_(shapes) {} ~SequeueShape() override = default; MS_DECLARE_PARENT(SequeueShape, BaseShape) @@ -89,9 +89,9 @@ class SequeueShape : public BaseShape { BaseShapePtrList ElementsClone() const; template - bool SequeueEqual(const BaseShape& other) const; + bool SequeueEqual(const BaseShape &other) const; - const BaseShapePtrList& shape() const { return p_shapes_; } + const BaseShapePtrList &shape() const { return p_shapes_; } size_t size() const { return p_shapes_.size(); } const BaseShapePtr operator[](std::size_t dim) const { return p_shapes_[dim]; } @@ -103,7 +103,7 @@ using SequeueShapePtr = std::shared_ptr; class TupleShape : public SequeueShape { public: TupleShape() : SequeueShape() {} - explicit TupleShape(const BaseShapePtrList& shapes) : SequeueShape(shapes) {} + explicit TupleShape(const BaseShapePtrList &shapes) : SequeueShape(shapes) {} ~TupleShape() override = default; MS_DECLARE_PARENT(TupleShape, SequeueShape) @@ -111,14 +111,14 @@ class TupleShape : public SequeueShape { BaseShapePtr Clone() const override { return std::make_shared(ElementsClone()); } - bool operator==(const BaseShape& other) const override { return SequeueEqual(other); } + bool operator==(const BaseShape &other) const override { return SequeueEqual(other); } }; using TupleShapePtr = std::shared_ptr; class ListShape : public SequeueShape { public: ListShape() : SequeueShape() {} - explicit ListShape(const BaseShapePtrList& shapes) : SequeueShape(shapes) {} + explicit ListShape(const BaseShapePtrList &shapes) : SequeueShape(shapes) {} ~ListShape() override = default; MS_DECLARE_PARENT(ListShape, SequeueShape) @@ -126,7 +126,7 @@ class ListShape : public SequeueShape { BaseShapePtr Clone() const override { return std::make_shared(SequeueShape::ElementsClone()); } - bool operator==(const BaseShape& other) const override { return SequeueEqual(other); } + bool operator==(const BaseShape &other) const override { return SequeueEqual(other); } }; using ListShapePtr = std::shared_ptr; } // namespace abstract diff --git a/mindspore/ccsrc/pipeline/validator.cc b/mindspore/ccsrc/pipeline/validator.cc index 0fe3218813..73a54bb180 100644 --- a/mindspore/ccsrc/pipeline/validator.cc +++ b/mindspore/ccsrc/pipeline/validator.cc @@ -39,7 +39,7 @@ using mindspore::abstract::AbstractTensor; using mindspore::abstract::AbstractTuple; using mindspore::abstract::AbstractType; -void ValidateOperation(const AnfNodePtr& node) { +void ValidateOperation(const AnfNodePtr &node) { if (!IsValueNode(node)) { return; } @@ -60,7 +60,7 @@ void ValidateOperation(const AnfNodePtr& node) { MS_LOG(EXCEPTION) << "Illegal primitive: " << prim->name(); } -void ValidateAbstract(const AnfNodePtr& node) { +void ValidateAbstract(const AnfNodePtr &node) { if (node == nullptr) { MS_LOG(WARNING) << "Node to validate is invalid"; return; @@ -105,11 +105,11 @@ void ValidateAbstract(const AnfNodePtr& node) { MS_LOG(EXCEPTION) << "Illegal type in the graph: " << ptrBase->ToString(); } -void Validate(const FuncGraphPtr& fg) { +void Validate(const FuncGraphPtr &fg) { FuncGraphManagerPtr mgr = Manage(fg, false); MS_EXCEPTION_IF_NULL(mgr); - AnfNodeSet& all_nodes = mgr->all_nodes(); - for (const auto& anf_node : all_nodes) { + AnfNodeSet &all_nodes = mgr->all_nodes(); + for (const auto &anf_node : all_nodes) { ValidateOperation(anf_node); ValidateAbstract(anf_node); } diff --git a/mindspore/ccsrc/pipeline/validator.h b/mindspore/ccsrc/pipeline/validator.h index 9944078e6c..61f7470349 100644 --- a/mindspore/ccsrc/pipeline/validator.h +++ b/mindspore/ccsrc/pipeline/validator.h @@ -29,9 +29,9 @@ namespace mindspore { namespace validator { -void Validate(const FuncGraphPtr& func_graph); -void ValidateAbstract(const AnfNodePtr& node); -void ValidateOperation(const AnfNodePtr& node); +void Validate(const FuncGraphPtr &func_graph); +void ValidateAbstract(const AnfNodePtr &node); +void ValidateOperation(const AnfNodePtr &node); } // namespace validator } // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/mem_dynamic_allocator.cc b/mindspore/ccsrc/pre_activate/mem_reuse/mem_dynamic_allocator.cc index f0077ef6cd..c9ef381f16 100644 --- a/mindspore/ccsrc/pre_activate/mem_reuse/mem_dynamic_allocator.cc +++ b/mindspore/ccsrc/pre_activate/mem_reuse/mem_dynamic_allocator.cc @@ -121,7 +121,7 @@ bool DynamicMemPoolBestFit::IsDivide(size_t tensor_size, size_t mem_buf_size) co return mem_buf_size - tensor_size >= DYNAMIC_MEM_ALIGN_SIZE; } -void DynamicMemPoolBestFit::DivideMemBuf(size_t size, const DynamicMemBufPtr& mem_buf) { +void DynamicMemPoolBestFit::DivideMemBuf(size_t size, const DynamicMemBufPtr &mem_buf) { MS_EXCEPTION_IF_NULL(mem_buf); auto mem_block = FindMemBlock(mem_buf->device_addr_); MS_EXCEPTION_IF_NULL(mem_block); @@ -160,7 +160,7 @@ void DynamicMemPoolBestFit::FreeTensorMem(const DeviceMemPtr device_addr) { CombineMemBuf(mem_block, device_addr); } -void DynamicMemPoolBestFit::CombineMemBuf(const DynamicMemBlockPtr& mem_block, const DeviceMemPtr device_addr) { +void DynamicMemPoolBestFit::CombineMemBuf(const DynamicMemBlockPtr &mem_block, const DeviceMemPtr device_addr) { MS_EXCEPTION_IF_NULL(mem_block); MS_EXCEPTION_IF_NULL(device_addr); auto iter = mem_block->block_all_mem_buf_map_.find(device_addr); diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/mem_dynamic_allocator.h b/mindspore/ccsrc/pre_activate/mem_reuse/mem_dynamic_allocator.h index dcf735814c..c628756070 100644 --- a/mindspore/ccsrc/pre_activate/mem_reuse/mem_dynamic_allocator.h +++ b/mindspore/ccsrc/pre_activate/mem_reuse/mem_dynamic_allocator.h @@ -61,7 +61,7 @@ class DynamicMemBlock { DynamicMemBlock() = default; DynamicMemBlock(DeviceMemPtr addr_base, size_t size) : device_addr_base_(addr_base), mem_block_size_(size) {} ~DynamicMemBlock() { block_all_mem_buf_map_.clear(); } - const DeviceMemPtr& device_addr() const { return device_addr_base_; } + const DeviceMemPtr &device_addr() const { return device_addr_base_; } size_t size() const { return mem_block_size_; } // The map of all memory buf in this memory block by device address. DeviceAddrMapMemBuf block_all_mem_buf_map_; @@ -92,8 +92,8 @@ class DynamicMemPoolBestFit { size_t used_mem_peak_statistics() const { return used_mem_peak_statistics_; } // The related interface of device memory real operation, needs override by device type. - virtual size_t AllocDeviceMem(size_t size, DeviceMemPtr* addr) = 0; - virtual bool FreeDeviceMem(const DeviceMemPtr& addr) = 0; + virtual size_t AllocDeviceMem(size_t size, DeviceMemPtr *addr) = 0; + virtual bool FreeDeviceMem(const DeviceMemPtr &addr) = 0; virtual size_t free_mem_size() = 0; virtual size_t total_mem_size() = 0; @@ -113,14 +113,14 @@ class DynamicMemPoolBestFit { // Judge whether need divide the memory buf by alloc size and memory buf size. bool IsDivide(size_t tensor_size, size_t mem_buf_size) const; // Divide the memory buf by alloc size. - void DivideMemBuf(size_t size, const DynamicMemBufPtr& mem_buf); + void DivideMemBuf(size_t size, const DynamicMemBufPtr &mem_buf); // Find the memory block by deivce address. DynamicMemBlockPtr FindMemBlock(const DeviceMemPtr device_addr); // The Comparator of memory block by device address, because memory blocks are arranged in order by device address. static bool CmpMemBlock(const DeviceMemPtr device_addr, const DynamicMemBlockPtr mem_block); // Combine the memory buf when memory free, to avoid the memory fragmentation. - void CombineMemBuf(const DynamicMemBlockPtr& mem_block, const DeviceMemPtr device_addr); + void CombineMemBuf(const DynamicMemBlockPtr &mem_block, const DeviceMemPtr device_addr); // Erase the idle memory buf by size and device address when idle memory buf is combined. void EraseIdleMemBuf(size_t size, const DeviceMemPtr device_addr); diff --git a/mindspore/ccsrc/predict/generator/ir/ir_model.h b/mindspore/ccsrc/predict/generator/ir/ir_model.h index bf1c057b5f..82bd2aad3f 100644 --- a/mindspore/ccsrc/predict/generator/ir/ir_model.h +++ b/mindspore/ccsrc/predict/generator/ir/ir_model.h @@ -23,7 +23,7 @@ namespace mindspore { namespace generator { class IRModel { public: - void SetIrTaskInfos(const std::vector& ir_tasks); + void SetIrTaskInfos(const std::vector &ir_tasks); IRModel() = default; ~IRModel(); diff --git a/mindspore/ccsrc/pybind_api/api_register.h b/mindspore/ccsrc/pybind_api/api_register.h index 2c1b622f31..8bab751267 100644 --- a/mindspore/ccsrc/pybind_api/api_register.h +++ b/mindspore/ccsrc/pybind_api/api_register.h @@ -29,19 +29,19 @@ namespace py = pybind11; namespace mindspore { -using PybindDefineFunc = std::function; +using PybindDefineFunc = std::function; class PybindDefineRegister { public: - static void Register(const std::string& name, const PybindDefineFunc& fn) { + static void Register(const std::string &name, const PybindDefineFunc &fn) { return GetSingleton().RegisterFn(name, fn); } - PybindDefineRegister(const PybindDefineRegister&) = delete; + PybindDefineRegister(const PybindDefineRegister &) = delete; - PybindDefineRegister& operator=(const PybindDefineRegister&) = delete; + PybindDefineRegister &operator=(const PybindDefineRegister &) = delete; - static std::map& AllFuncs() { return GetSingleton().fns_; } + static std::map &AllFuncs() { return GetSingleton().fns_; } std::map fns_; @@ -50,14 +50,14 @@ class PybindDefineRegister { virtual ~PybindDefineRegister() = default; - static PybindDefineRegister& GetSingleton(); + static PybindDefineRegister &GetSingleton(); - void RegisterFn(const std::string& name, const PybindDefineFunc& fn) { fns_[name] = fn; } + void RegisterFn(const std::string &name, const PybindDefineFunc &fn) { fns_[name] = fn; } }; class PybindDefineRegisterer { public: - PybindDefineRegisterer(const std::string& name, const PybindDefineFunc& fn) { + PybindDefineRegisterer(const std::string &name, const PybindDefineFunc &fn) { PybindDefineRegister::Register(name, fn); } ~PybindDefineRegisterer() = default; diff --git a/mindspore/ccsrc/pynative/base.h b/mindspore/ccsrc/pynative/base.h index d8675adc9c..37ff000b04 100644 --- a/mindspore/ccsrc/pynative/base.h +++ b/mindspore/ccsrc/pynative/base.h @@ -58,7 +58,7 @@ struct OpExecInfo { py::dict op_attrs; }; using OpExecInfoPtr = std::shared_ptr; -OpExecInfoPtr GenerateOpExecInfo(const py::args& args); +OpExecInfoPtr GenerateOpExecInfo(const py::args &args); const std::set ignore_infer_prim = {"partial", "make_ref"}; diff --git a/mindspore/ccsrc/pynative/pynative_execute.cc b/mindspore/ccsrc/pynative/pynative_execute.cc index 6a1ddf6a7e..0d18dfb577 100644 --- a/mindspore/ccsrc/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pynative/pynative_execute.cc @@ -43,7 +43,7 @@ const std::unordered_set vm_operators = {"partial", "depend", "make namespace mindspore { namespace pynative { -inline ValuePtr PyAttrValue(const py::object& obj) { +inline ValuePtr PyAttrValue(const py::object &obj) { ValuePtr converted_ret = nullptr; bool converted = parse::ConvertData(obj, &converted_ret); if (!converted) { @@ -52,11 +52,11 @@ inline ValuePtr PyAttrValue(const py::object& obj) { return converted_ret; } -py::tuple ConvertInputs(const PrimitivePyPtr& prim, const py::tuple& py_args) { +py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::tuple &py_args) { auto signature = prim->signatures(); std::vector dtypes; (void)std::transform(signature.begin(), signature.end(), std::back_inserter(dtypes), - [](const Signature& sig) { return sig.dtype; }); + [](const Signature &sig) { return sig.dtype; }); int empty_dtype_count = std::count(dtypes.begin(), dtypes.end(), SignatureEnumDType::kDTypeEmptyDefaultValue); if (dtypes.size() == 0 || static_cast(dtypes.size()) == empty_dtype_count) { return py_args; @@ -103,7 +103,7 @@ py::tuple ConvertInputs(const PrimitivePyPtr& prim, const py::tuple& py_args) { return py_inputs; } -void PynativeInfer(const PrimitivePyPtr& prim, const py::tuple& py_args, OpExecInfo* const op_exec_info) { +void PynativeInfer(const PrimitivePyPtr &prim, const py::tuple &py_args, OpExecInfo *const op_exec_info) { size_t size = py_args.size(); AbstractBasePtrList args_spec_list; for (size_t i = 0; i < size; i++) { @@ -118,7 +118,7 @@ void PynativeInfer(const PrimitivePyPtr& prim, const py::tuple& py_args, OpExecI op_exec_info->abstract = infer_res; } -OpExecInfoPtr GenerateOpExecInfo(const py::args& args) { +OpExecInfoPtr GenerateOpExecInfo(const py::args &args) { if (args.size() != PY_ARGS_NUM) { MS_LOG(ERROR) << "Four args are needed by RunOp"; return nullptr; @@ -147,7 +147,7 @@ OpExecInfoPtr GenerateOpExecInfo(const py::args& args) { return op_exec_info; } -std::string GetSingleOpGraphInfo(const OpExecInfoPtr& op_exec_info) { +std::string GetSingleOpGraphInfo(const OpExecInfoPtr &op_exec_info) { MS_EXCEPTION_IF_NULL(op_exec_info); std::string graph_info; MS_EXCEPTION_IF_NULL(op_exec_info->abstract); @@ -167,7 +167,7 @@ std::string GetSingleOpGraphInfo(const OpExecInfoPtr& op_exec_info) { return graph_info; } -py::object RunOpInVM(const OpExecInfoPtr& op_exec_info, PynativeStatusCode* status) { +py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) { MS_LOG(INFO) << "RunOpInVM start"; MS_EXCEPTION_IF_NULL(status); @@ -188,7 +188,7 @@ py::object RunOpInVM(const OpExecInfoPtr& op_exec_info, PynativeStatusCode* stat return std::move(result); } -py::object RunOpInMs(const OpExecInfoPtr& op_exec_info, PynativeStatusCode* status) { +py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) { MS_EXCEPTION_IF_NULL(op_exec_info); MS_LOG(INFO) << "Start run op[" << op_exec_info->op_name << "] with backend policy ms"; auto ms_context = MsContext::GetInstance(); @@ -212,7 +212,7 @@ py::object RunOpInMs(const OpExecInfoPtr& op_exec_info, PynativeStatusCode* stat } py::object RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecInfoPtr op_exec_info, - PynativeStatusCode* const status) { + PynativeStatusCode *const status) { MS_EXCEPTION_IF_NULL(status); py::object result; switch (backend_policy) { @@ -248,7 +248,7 @@ py::object RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecIn return result; } -py::tuple RunOp(const py::args& args) { +py::tuple RunOp(const py::args &args) { py::object result; // returns a null py::tuple on error py::tuple err_ret(0); diff --git a/mindspore/ccsrc/pynative/pynative_execute.h b/mindspore/ccsrc/pynative/pynative_execute.h index 17b5610bfd..c64c6b4b25 100644 --- a/mindspore/ccsrc/pynative/pynative_execute.h +++ b/mindspore/ccsrc/pynative/pynative_execute.h @@ -33,9 +33,9 @@ namespace pynative { namespace py = pybind11; -py::object RunOpInVM(const OpExecInfoPtr& op_exec_info, PynativeStatusCode* status); +py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status); -py::tuple RunOp(const py::args& args); +py::tuple RunOp(const py::args &args); } // namespace pynative } // namespace mindspore diff --git a/mindspore/ccsrc/pynative/pynative_execute_ge.cc b/mindspore/ccsrc/pynative/pynative_execute_ge.cc index 180b0006ff..0bf2a391f9 100644 --- a/mindspore/ccsrc/pynative/pynative_execute_ge.cc +++ b/mindspore/ccsrc/pynative/pynative_execute_ge.cc @@ -43,7 +43,7 @@ using transform::GraphRunner; using transform::GraphRunnerOptions; using transform::OperatorPtr; static std::shared_ptr session = nullptr; -inline ValuePtr PyAttrValue(const py::object& obj) { +inline ValuePtr PyAttrValue(const py::object &obj) { ValuePtr converted_ret = nullptr; bool converted = parse::ConvertData(obj, &converted_ret); if (!converted) { @@ -52,7 +52,7 @@ inline ValuePtr PyAttrValue(const py::object& obj) { return converted_ret; } -MeTensorPtr ConvertPyObjToTensor(const py::object& obj) { +MeTensorPtr ConvertPyObjToTensor(const py::object &obj) { MeTensorPtr me_tensor_ptr = nullptr; if (py::isinstance(obj)) { me_tensor_ptr = py::cast(obj); @@ -72,8 +72,8 @@ MeTensorPtr ConvertPyObjToTensor(const py::object& obj) { return me_tensor_ptr; } -bool SetInputsForSingleOpGraph(const OpExecInfoPtr& op_exec_info, const std::vector& inputs, - const OperatorPtr& op, std::vector* graph_input_nodes) { +bool SetInputsForSingleOpGraph(const OpExecInfoPtr &op_exec_info, const std::vector &inputs, + const OperatorPtr &op, std::vector *graph_input_nodes) { MS_EXCEPTION_IF_NULL(op_exec_info); MS_EXCEPTION_IF_NULL(graph_input_nodes); auto op_inputs = op_exec_info->op_inputs; @@ -103,7 +103,7 @@ bool SetInputsForSingleOpGraph(const OpExecInfoPtr& op_exec_info, const std::vec auto pointer_cast_const_op = std::static_pointer_cast(const_op); MS_EXCEPTION_IF_NULL(pointer_cast_const_op); (void)pointer_cast_const_op->update_output_desc_y(*const_op_desc); - auto& input_map = adapter->getInputMap(); + auto &input_map = adapter->getInputMap(); if (input_map.find(op_input_idx) == input_map.end()) { continue; } @@ -116,8 +116,8 @@ bool SetInputsForSingleOpGraph(const OpExecInfoPtr& op_exec_info, const std::vec return true; } -bool BuildSingleOpGraph(const OpExecInfoPtr& op_exec_info, const std::vector& inputs, - const std::unordered_map& attrs, const GeGraphPtr& graph) { +bool BuildSingleOpGraph(const OpExecInfoPtr &op_exec_info, const std::vector &inputs, + const std::unordered_map &attrs, const GeGraphPtr &graph) { MS_EXCEPTION_IF_NULL(op_exec_info); std::string op_name = op_exec_info->op_name; auto op_inputs = op_exec_info->op_inputs; @@ -145,8 +145,8 @@ bool BuildSingleOpGraph(const OpExecInfoPtr& op_exec_info, const std::vectorsetAttr(op, attr.first, attr.second); } // set input attributes - auto& input_attr_map = adapter->getInputAttrMap(); - for (auto& it : input_attr_map) { + auto &input_attr_map = adapter->getInputAttrMap(); + for (auto &it : input_attr_map) { if (op_inputs.size() < it.first) { continue; } @@ -165,7 +165,7 @@ bool BuildSingleOpGraph(const OpExecInfoPtr& op_exec_info, const std::vector* const inputs) { +void ToTensorPtr(const OpExecInfoPtr op_exec_info, std::vector *const inputs) { MS_EXCEPTION_IF_NULL(inputs); MS_EXCEPTION_IF_NULL(op_exec_info); auto op_inputs = op_exec_info->op_inputs; @@ -185,12 +185,12 @@ void ToTensorPtr(const OpExecInfoPtr op_exec_info, std::vector* con } } -PynativeStatusCode ConvertAttributes(const OpExecInfoPtr& op_exec_info, const std::vector& inputs) { +PynativeStatusCode ConvertAttributes(const OpExecInfoPtr &op_exec_info, const std::vector &inputs) { MS_EXCEPTION_IF_NULL(op_exec_info); auto op_attrs = op_exec_info->op_attrs; std::unordered_map attrs{}; - for (auto& item : op_attrs) { + for (auto &item : op_attrs) { if (!py::isinstance(item.first)) { MS_LOG(ERROR) << "Type error in py dict convert"; return PYNATIVE_OP_ATTRS_ERR; @@ -218,8 +218,8 @@ PynativeStatusCode ConvertAttributes(const OpExecInfoPtr& op_exec_info, const st return PYNATIVE_SUCCESS; } -std::vector ConvertOutputTensors(const OpExecInfoPtr& op_exec_info, - const std::vector& ge_tensors) { +std::vector ConvertOutputTensors(const OpExecInfoPtr &op_exec_info, + const std::vector &ge_tensors) { std::vector outputs; AbstractBasePtr abs_base = op_exec_info->abstract; std::vector> shapes; @@ -242,7 +242,7 @@ std::vector ConvertOutputTensors(const OpExecInfoPtr& op_exec_info, outputs = transform::TransformUtil::ConvertGeTensors(ge_tensors, shapes); return outputs; } - for (auto& it : ge_tensors) { + for (auto &it : ge_tensors) { auto tensor = transform::TransformUtil::ConvertGeTensor(it); if (tensor != nullptr) { outputs.emplace_back(tensor); @@ -251,7 +251,7 @@ std::vector ConvertOutputTensors(const OpExecInfoPtr& op_exec_info, return outputs; } -py::object RunOpInGE(const OpExecInfoPtr& op_exec_info, PynativeStatusCode* status) { +py::object RunOpInGE(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) { MS_LOG(INFO) << "RunOpInGe start"; MS_EXCEPTION_IF_NULL(op_exec_info); MS_EXCEPTION_IF_NULL(status); diff --git a/mindspore/ccsrc/pynative/pynative_execute_ge.h b/mindspore/ccsrc/pynative/pynative_execute_ge.h index af0efec3e3..2dca3df018 100644 --- a/mindspore/ccsrc/pynative/pynative_execute_ge.h +++ b/mindspore/ccsrc/pynative/pynative_execute_ge.h @@ -36,10 +36,10 @@ using GeGraphPtr = std::shared_ptr; namespace mindspore { namespace pynative { -bool BuildSingleOpGraph(const OpExecInfoPtr& op_exec_info, const std::vector& inputs, - const std::unordered_map& attrs, const GeGraphPtr& graph); +bool BuildSingleOpGraph(const OpExecInfoPtr &op_exec_info, const std::vector &inputs, + const std::unordered_map &attrs, const GeGraphPtr &graph); -py::object RunOpInGE(const OpExecInfoPtr& op_exec_info, PynativeStatusCode* status); +py::object RunOpInGE(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status); } // namespace pynative } // namespace mindspore diff --git a/mindspore/ccsrc/transform/convert.h b/mindspore/ccsrc/transform/convert.h index 556db5acee..5596e20f19 100644 --- a/mindspore/ccsrc/transform/convert.h +++ b/mindspore/ccsrc/transform/convert.h @@ -51,16 +51,16 @@ class OpAdapterDesc { public: OpAdapterDesc() : train_(nullptr), infer_(nullptr) {} - OpAdapterDesc(const OpAdapterPtr& train, const OpAdapterPtr& infer) : train_(train), infer_(infer) {} + OpAdapterDesc(const OpAdapterPtr &train, const OpAdapterPtr &infer) : train_(train), infer_(infer) {} - explicit OpAdapterDesc(const OpAdapterPtr& common) : train_(common), infer_(common) {} + explicit OpAdapterDesc(const OpAdapterPtr &common) : train_(common), infer_(common) {} - OpAdapterDesc(const OpAdapterDesc& desc) { + OpAdapterDesc(const OpAdapterDesc &desc) { this->train_ = desc.train_; this->infer_ = desc.infer_; } - OpAdapterDesc(OpAdapterDesc&& desc) { + OpAdapterDesc(OpAdapterDesc &&desc) { this->train_ = desc.train_; this->infer_ = desc.infer_; desc.train_ = nullptr; @@ -71,7 +71,7 @@ class OpAdapterDesc { OpAdapterPtr Get(bool train) const { return train ? train_ : infer_; } - OpAdapterDesc& operator=(const OpAdapterDesc& desc) { + OpAdapterDesc &operator=(const OpAdapterDesc &desc) { if (this != &desc) { this->train_ = desc.train_; this->infer_ = desc.infer_; @@ -79,7 +79,7 @@ class OpAdapterDesc { return *this; } - OpAdapterDesc& operator=(OpAdapterDesc&& desc) { + OpAdapterDesc &operator=(OpAdapterDesc &&desc) { if (this != &desc) { this->train_ = desc.train_; this->infer_ = desc.infer_; @@ -99,7 +99,7 @@ using TensorOrderMap = std::map>; class DfGraphConvertor { public: - explicit DfGraphConvertor(const AnfGraphPtr& anf_graph) + explicit DfGraphConvertor(const AnfGraphPtr &anf_graph) : anf_graph_(anf_graph), df_graph_(std::make_shared(anf_graph_->ToString())) { #if (!defined ENABLE_GE) || (defined ENABLE_INFER) auto it_training = anf_graph->flags().find("training"); @@ -125,14 +125,14 @@ class DfGraphConvertor { ~DfGraphConvertor() {} - static void RegisterAdapter(const std::string& name, OpAdapterPtr adpt) { + static void RegisterAdapter(const std::string &name, OpAdapterPtr adpt) { get_adpt_map()[name] = std::make_shared(adpt); } - static void RegisterAdapter(const std::string& name, OpAdapterPtr train_adpt, OpAdapterPtr infer_adpt) { + static void RegisterAdapter(const std::string &name, OpAdapterPtr train_adpt, OpAdapterPtr infer_adpt) { get_adpt_map()[name] = std::make_shared(train_adpt, infer_adpt); } - void DrawComputeGraph(const std::string& name) { + void DrawComputeGraph(const std::string &name) { std::ofstream fout(name); if (!fout.is_open()) { MS_LOG(ERROR) << "Open file '" << name << "' failed!"; @@ -141,7 +141,7 @@ class DfGraphConvertor { fout << compute_sout_.str(); fout.close(); } - void DrawInitGraph(const std::string& name) { + void DrawInitGraph(const std::string &name) { std::ofstream fout(name); if (!fout.is_open()) { MS_LOG(ERROR) << "Open file '" << name << "' failed!"; @@ -150,7 +150,7 @@ class DfGraphConvertor { fout << init_sout_.str(); fout.close(); } - void DrawSaveCheckpointGraph(const std::string& name) { + void DrawSaveCheckpointGraph(const std::string &name) { std::ofstream fout(name); if (!fout.is_open()) { MS_LOG(ERROR) << "Open file '" << name << "' failed!"; @@ -160,74 +160,74 @@ class DfGraphConvertor { fout.close(); } - DfGraphConvertor& ConvertAllNode(); - DfGraphConvertor& BuildGraph(); - DfGraphConvertor& InitParam(const TensorOrderMap& tensors); - DfGraphConvertor& GenerateCheckpointGraph(); - DfGraphConvertor& GenerateBroadcastGraph(const TensorOrderMap& tensors); - void InitParamWithData(const TensorOrderMap& tensors); - void SetOpInput(const OpAdapterPtr& adpt, const CNodePtr& node); - void SetupBroadcast(const std::shared_ptr& broadcast, const std::vector& broadcast_desc, - const DfGraphPtr& broadcast_graph, std::vector broadcast_input); - void MakeDatasetHandler(const std::string& name, const size_t& input_idx, const AnfNodePtr& it); - void SetupParamInitSubGraph(const TensorOrderMap& tensors, std::vector* init_input); - void DrawParamInitSubGraph(const std::string& name, const AnfNodePtr& it); + DfGraphConvertor &ConvertAllNode(); + DfGraphConvertor &BuildGraph(); + DfGraphConvertor &InitParam(const TensorOrderMap &tensors); + DfGraphConvertor &GenerateCheckpointGraph(); + DfGraphConvertor &GenerateBroadcastGraph(const TensorOrderMap &tensors); + void InitParamWithData(const TensorOrderMap &tensors); + void SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node); + void SetupBroadcast(const std::shared_ptr &broadcast, const std::vector &broadcast_desc, + const DfGraphPtr &broadcast_graph, std::vector broadcast_input); + void MakeDatasetHandler(const std::string &name, const size_t &input_idx, const AnfNodePtr &it); + void SetupParamInitSubGraph(const TensorOrderMap &tensors, std::vector *init_input); + void DrawParamInitSubGraph(const std::string &name, const AnfNodePtr &it); DfGraphPtr GetComputeGraph(); DfGraphPtr GetInitGraph(); DfGraphPtr GetSaveCheckpointGraph(); DfGraphPtr GetBroadcastGraph(); - static OpAdapterPtr FindAdapter(const std::string& op_name, bool train = false); + static OpAdapterPtr FindAdapter(const std::string &op_name, bool train = false); static OpAdapterPtr FindAdapter(AnfNodePtr node, bool train = false); int ErrCode() const { return static_cast(error_); } - static std::unordered_map& get_adpt_map(); + static std::unordered_map &get_adpt_map(); bool is_training() const { return training_; } void set_training(bool is_training) { training_ = is_training; } protected: - void InitLoopVar(std::vector* init_input); + void InitLoopVar(std::vector *init_input); private: std::ostringstream compute_sout_; std::ostringstream init_sout_; std::ostringstream checkpoint_sout_; std::ostringstream restore_checkpoint_sout_; - std::unordered_map op_draw_name_; + std::unordered_map op_draw_name_; - AnfNodePtr TraceTupleGetItem(const CNodePtr& node, unsigned int* index); - AnfNodePtr TraceMakeTuple(const CNodePtr& node, unsigned int index); - AnfNodePtr TraceDepend(const CNodePtr& node); + AnfNodePtr TraceTupleGetItem(const CNodePtr &node, unsigned int *index); + AnfNodePtr TraceMakeTuple(const CNodePtr &node, unsigned int index); + AnfNodePtr TraceDepend(const CNodePtr &node); OutHandler TraceRealOp(AnfNodePtr node); - OutHandler GetHandler(const AnfNodePtr& node, const std::stack& index_stack, AnfNode* const draw_index); + OutHandler GetHandler(const AnfNodePtr &node, const std::stack &index_stack, AnfNode *const draw_index); OperatorPtr Convert(AnfNodePtr node); OperatorPtr ConvertCNode(CNodePtr node); std::vector ConvertDependNode(AnfNodePtr node); AnfNodePtr GetRealOpNode(AnfNodePtr node); - std::vector GetDependNodes(const AnfNodePtr& node); + std::vector GetDependNodes(const AnfNodePtr &node); OperatorPtr ConvertParameter(AnfNodePtr node); Status TryConvertValueNodeToMultiConst(const ValueNodePtr node); OperatorPtr ConvertValueNode(ValueNodePtr node); void ConvertTupleGetItem(const CNodePtr node); - void GetDependOnParameterUse(const CNodePtr& node, const AnfNodePtr& src_node, const AnfNodePtr& dest_node, - const std::shared_ptr>& src_ops_list, - const std::shared_ptr>& dst_ops_list); - bool GetControlDependList(const CNodePtr& node, const std::shared_ptr>& src_ops_list, - const std::shared_ptr>& dst_ops_list); - void DrawControlDepend(const AnfNodePtr& src_node, const AnfNodePtr& dest_node); + void GetDependOnParameterUse(const CNodePtr &node, const AnfNodePtr &src_node, const AnfNodePtr &dest_node, + const std::shared_ptr> &src_ops_list, + const std::shared_ptr> &dst_ops_list); + bool GetControlDependList(const CNodePtr &node, const std::shared_ptr> &src_ops_list, + const std::shared_ptr> &dst_ops_list); + void DrawControlDepend(const AnfNodePtr &src_node, const AnfNodePtr &dest_node); void ConvertControlDependNode(const CNodePtr node); void ConvertMakeTuple(const CNodePtr node); - bool CheckCNode(const std::string& name, const CNodePtr node); + bool CheckCNode(const std::string &name, const CNodePtr node); void TraceOutput(AnfNodePtr node); - void TraceOutputFromParameter(const AnfNodePtr& anf_out); - void TraceOutputFromTupleGetItem(const AnfNodePtr& anf_out); + void TraceOutputFromParameter(const AnfNodePtr &anf_out); + void TraceOutputFromTupleGetItem(const AnfNodePtr &anf_out); void SetNodeInput(AnfNodePtr node); void SetOpControlInput(const AnfNodePtr node); void UpdateOpDesc(AnfNodePtr node); void BuildSaveCheckpointGraph(); void DrawCNode(const CNodePtr node, const OpAdapterPtr adpt); - void UpdateDataOpDesc(const AnfNodePtr& it, const OperatorPtr& op) const; - void AddGraphConstInput(const OperatorPtr& op); + void UpdateDataOpDesc(const AnfNodePtr &it, const OperatorPtr &op) const; + void AddGraphConstInput(const OperatorPtr &op); std::shared_ptr anf_graph_{nullptr}; std::shared_ptr df_graph_{nullptr}; @@ -235,12 +235,12 @@ class DfGraphConvertor { std::shared_ptr save_ckp_graph_{nullptr}; std::shared_ptr restore_ckp_graph_{nullptr}; std::shared_ptr broadcast_graph_{nullptr}; - std::unordered_map op_cache_; - std::unordered_map> control_depend_cache_; + std::unordered_map op_cache_; + std::unordered_map> control_depend_cache_; /* record "tuple_getitem"<->"out_handler" mapping */ - std::unordered_map out_handle_cache_; + std::unordered_map out_handle_cache_; /* record "make_tuple"<->"out_handler vector" mapping */ - std::unordered_map>> tuple_out_handle_cache_; + std::unordered_map>> tuple_out_handle_cache_; std::unordered_map params_; std::unordered_map vars_; std::vector> graph_outputs_; diff --git a/mindspore/ccsrc/transform/df_graph_manager.cc b/mindspore/ccsrc/transform/df_graph_manager.cc index bfe4d9f5d2..f62c386587 100644 --- a/mindspore/ccsrc/transform/df_graph_manager.cc +++ b/mindspore/ccsrc/transform/df_graph_manager.cc @@ -31,8 +31,8 @@ namespace mindspore { namespace transform { -DfGraphWrapper::DfGraphWrapper(const std::string& name, const int& id, const DfGraphPtr& graph_ptr, - const OptionMap& options) +DfGraphWrapper::DfGraphWrapper(const std::string &name, const int &id, const DfGraphPtr &graph_ptr, + const OptionMap &options) : name_(name), id_(id), graph_ptr_(graph_ptr), options_(options) {} DfGraphManager::DfGraphManager() { @@ -49,7 +49,7 @@ DfGraphManager::~DfGraphManager() { parse::python_adapter::set_python_env_flag(false); } -DfGraphManager& DfGraphManager::GetInstance() { +DfGraphManager &DfGraphManager::GetInstance() { static DfGraphManager instance; return instance; } @@ -63,7 +63,7 @@ int DfGraphManager::GenerateId() { return graph_id_; } -Status DfGraphManager::AddGraph(const std::string& name, const DfGraphPtr& graph_ptr, const OptionMap& options) { +Status DfGraphManager::AddGraph(const std::string &name, const DfGraphPtr &graph_ptr, const OptionMap &options) { std::lock_guard lg(lock_); if (name.empty()) { MS_LOG(ERROR) << "The graph name is null, add graph failed"; @@ -101,9 +101,9 @@ std::vector DfGraphManager::GetAllGraphs() { } std::set DfGraphManager::GetSavedGraphs() { return saved_graphs_; } -void DfGraphManager::AddSavedGraphs(const std::string& id) { saved_graphs_.insert(id); } +void DfGraphManager::AddSavedGraphs(const std::string &id) { saved_graphs_.insert(id); } -DfGraphWrapperPtr DfGraphManager::GetGraphByName(const std::string& name) { +DfGraphWrapperPtr DfGraphManager::GetGraphByName(const std::string &name) { std::lock_guard lg(lock_); if (name.empty()) { MS_LOG(ERROR) << "The graph name is null"; @@ -126,7 +126,7 @@ void DfGraphManager::ClearGraph() noexcept { MS_LOG(INFO) << "Remove all graphs in GraphManager"; } -void DfGraphManager::SetAnfGraph(const std::string& name, const AnfGraphPtr& anf_graph_ptr) { +void DfGraphManager::SetAnfGraph(const std::string &name, const AnfGraphPtr &anf_graph_ptr) { DfGraphWrapperPtr df_graph = GetGraphByName(name); if (df_graph == nullptr) { MS_LOG(ERROR) << "Can't found graph name: " << name; @@ -152,7 +152,7 @@ void DfGraphManager::EraseAnfGraph() { anf_graphs_.clear(); } -void DfGraphManager::SetGeSession(const std::shared_ptr& sess_ptr) { +void DfGraphManager::SetGeSession(const std::shared_ptr &sess_ptr) { std::lock_guard lg(lock_); if (sess_ptr == nullptr) { MS_LOG(WARNING) << "You are adding a empty Ge Session"; @@ -182,7 +182,7 @@ void DfGraphManager::DeleteGeSession() noexcept { } } -void DfGraphManager::SetGraphRunner(const std::shared_ptr& graph_runner_ptr) noexcept { +void DfGraphManager::SetGraphRunner(const std::shared_ptr &graph_runner_ptr) noexcept { std::lock_guard lg(lock_); if (graph_runner_ptr == nullptr) { MS_LOG(WARNING) << "You are adding a empty GraphRunner"; diff --git a/mindspore/ccsrc/transform/df_graph_manager.h b/mindspore/ccsrc/transform/df_graph_manager.h index 97137ae94b..2ca43d1f07 100644 --- a/mindspore/ccsrc/transform/df_graph_manager.h +++ b/mindspore/ccsrc/transform/df_graph_manager.h @@ -35,7 +35,7 @@ using OptionMap = std::map; struct DfGraphWrapper { public: - DfGraphWrapper(const std::string& name, const int& id, const DfGraphPtr& graph_ptr, const OptionMap& options); + DfGraphWrapper(const std::string &name, const int &id, const DfGraphPtr &graph_ptr, const OptionMap &options); ~DfGraphWrapper() {} std::string name_; @@ -51,19 +51,19 @@ class DfGraphManager { ~DfGraphManager(); void ClearGraph() noexcept; - static DfGraphManager& GetInstance(); - Status AddGraph(const std::string& name, const DfGraphPtr& graph, const OptionMap& options = {}); + static DfGraphManager &GetInstance(); + Status AddGraph(const std::string &name, const DfGraphPtr &graph, const OptionMap &options = {}); std::vector GetAllGraphs(); std::set GetSavedGraphs(); - void AddSavedGraphs(const std::string& id); - DfGraphWrapperPtr GetGraphByName(const std::string& name); - DfGraphManager(const DfGraphManager&) = delete; - void SetAnfGraph(const std::string& name, const AnfGraphPtr& anf_graph_ptr); + void AddSavedGraphs(const std::string &id); + DfGraphWrapperPtr GetGraphByName(const std::string &name); + DfGraphManager(const DfGraphManager &) = delete; + void SetAnfGraph(const std::string &name, const AnfGraphPtr &anf_graph_ptr); AnfGraphPtr GetAnfGraph(uint32_t graph_id); std::shared_ptr GetGraphRunner(); - void SetGraphRunner(const std::shared_ptr& graph_runner_ptr) noexcept; + void SetGraphRunner(const std::shared_ptr &graph_runner_ptr) noexcept; void DeleteGraphRunner() noexcept; - void SetGeSession(const std::shared_ptr& sess_ptr); + void SetGeSession(const std::shared_ptr &sess_ptr); std::shared_ptr GetGeSession(); void DeleteGeSession() noexcept; void EraseAnfGraph(); diff --git a/mindspore/ccsrc/transform/graph_builder.cc b/mindspore/ccsrc/transform/graph_builder.cc index 9c05969fb0..785c5c7f3a 100644 --- a/mindspore/ccsrc/transform/graph_builder.cc +++ b/mindspore/ccsrc/transform/graph_builder.cc @@ -21,7 +21,7 @@ namespace mindspore { namespace transform { -DfGraphPtr BuildMDDatasetGraph(const DatasetGraphParam& param) { +DfGraphPtr BuildMDDatasetGraph(const DatasetGraphParam ¶m) { MS_LOG(INFO) << "BuildMDDatasetGraph."; // InitData @@ -37,7 +37,7 @@ DfGraphPtr BuildMDDatasetGraph(const DatasetGraphParam& param) { return dataset_graph; } -Status BuildDatasetGraph(const DatasetGraphParam& param, const std::string& phase) { +Status BuildDatasetGraph(const DatasetGraphParam ¶m, const std::string &phase) { Status ret; std::string graph_name = phase; diff --git a/mindspore/ccsrc/transform/graph_builder.h b/mindspore/ccsrc/transform/graph_builder.h index 30b891460b..3d959f5a85 100644 --- a/mindspore/ccsrc/transform/graph_builder.h +++ b/mindspore/ccsrc/transform/graph_builder.h @@ -27,7 +27,7 @@ namespace mindspore { namespace transform { -Status BuildDatasetGraph(const DatasetGraphParam& param, const std::string& phase = "dataset"); +Status BuildDatasetGraph(const DatasetGraphParam ¶m, const std::string &phase = "dataset"); } // namespace transform } // namespace mindspore diff --git a/mindspore/ccsrc/transform/graph_runner.cc b/mindspore/ccsrc/transform/graph_runner.cc index 8b0ddfd18d..52d0d8e17f 100644 --- a/mindspore/ccsrc/transform/graph_runner.cc +++ b/mindspore/ccsrc/transform/graph_runner.cc @@ -30,7 +30,7 @@ #ifdef NO_GE_CLIENT namespace ge { -Session::Session(const std::map& options) { +Session::Session(const std::map &options) { if (options.empty()) { MS_LOG(ERROR) << "session input options is empty"; } @@ -42,7 +42,7 @@ Session::~Session() {} namespace mindspore { namespace transform { -std::shared_ptr GraphRunner::NewSession(const SessionOptions& sess_options) { +std::shared_ptr GraphRunner::NewSession(const SessionOptions &sess_options) { std::shared_ptr ret = std::make_shared(sess_options); if (ret == nullptr) { MS_LOG(ERROR) << "Create GE session failed"; @@ -52,7 +52,7 @@ std::shared_ptr GraphRunner::NewSession(const SessionOptions& sess_ return ret; } -GraphRunner::GraphRunner(const GraphRunnerOptions& options) +GraphRunner::GraphRunner(const GraphRunnerOptions &options) : options_(options), graph_manager_(DfGraphManager::GetInstance()) { if (ConfigManager::GetInstance().parallel_strategy() == ParallelStrategy::ONE_DEVICE) { MS_LOG(INFO) << "ME run in ONE_DEVICE strategy mode"; @@ -88,7 +88,7 @@ GraphRunner::GraphRunner(const GraphRunnerOptions& options) } #ifdef ENABLE_GE - for (auto& it : wrappers) { + for (auto &it : wrappers) { std::set saved_graph = graph_manager_.GetSavedGraphs(); auto iter_find = saved_graph.find(std::to_string(it->id_)); if (iter_find != saved_graph.end()) { @@ -101,8 +101,8 @@ GraphRunner::GraphRunner(const GraphRunnerOptions& options) #endif } -Status GraphRunner::RunGraph(const RunOptions& options, const std::vector& inputs, - std::vector* outputs) { +Status GraphRunner::RunGraph(const RunOptions &options, const std::vector &inputs, + std::vector *outputs) { std::string name = options.name; if (name.empty()) { MS_LOG(ERROR) << "The graph name is null"; @@ -125,7 +125,7 @@ Status GraphRunner::RunGraph(const RunOptions& options, const std::vector ge_outputs; (void)std::transform(inputs.begin(), inputs.end(), std::back_inserter(ge_inputs), - [](const GeTensorPtr& i) { return *i; }); + [](const GeTensorPtr &i) { return *i; }); MS_LOG(INFO) << "Run the graph in GE with " << ge_inputs.size() << " inputs"; @@ -161,19 +161,19 @@ Status GraphRunner::RunGraph(const RunOptions& options, const std::vector(ge_tensor); }); + [](const GeTensor &ge_tensor) { return std::make_shared(ge_tensor); }); return Status::SUCCESS; } -Status GraphRunner::RunGraph(const RunOptions& options, const std::vector& inputs, - std::vector* const outputs) { +Status GraphRunner::RunGraph(const RunOptions &options, const std::vector &inputs, + std::vector *const outputs) { std::vector ge_inputs; for (auto it : inputs) { MS_LOG(INFO) << "inputs tensor's data size is: " << (*it).DataSize(); auto shape = (*it).shape(); std::string shape_str; - for (const auto& elem : shape) { + for (const auto &elem : shape) { shape_str += std::to_string(elem); shape_str += " "; } @@ -199,7 +199,7 @@ Status GraphRunner::RunGraph(const RunOptions& options, const std::vectoremplace_back(tensor); diff --git a/mindspore/ccsrc/transform/graph_runner.h b/mindspore/ccsrc/transform/graph_runner.h index a9aa9fbc59..728a1a25a2 100644 --- a/mindspore/ccsrc/transform/graph_runner.h +++ b/mindspore/ccsrc/transform/graph_runner.h @@ -46,16 +46,16 @@ struct RunOptions { class GraphRunner { public: - explicit GraphRunner(const GraphRunnerOptions& options); + explicit GraphRunner(const GraphRunnerOptions &options); ~GraphRunner() { sess_ = nullptr; } - Status RunGraph(const RunOptions& options, const std::vector& inputs, std::vector* outputs); - Status RunGraph(const RunOptions& options, const std::vector& inputs, std::vector* outputs); - static std::shared_ptr NewSession(const SessionOptions& sess_options); + Status RunGraph(const RunOptions &options, const std::vector &inputs, std::vector *outputs); + Status RunGraph(const RunOptions &options, const std::vector &inputs, std::vector *outputs); + static std::shared_ptr NewSession(const SessionOptions &sess_options); private: std::shared_ptr sess_; transform::GraphRunnerOptions options_; - DfGraphManager& graph_manager_; + DfGraphManager &graph_manager_; }; } // namespace transform } // namespace mindspore diff --git a/mindspore/ccsrc/transform/op_adapter.h b/mindspore/ccsrc/transform/op_adapter.h index 421e4c4569..2039dfa7d6 100644 --- a/mindspore/ccsrc/transform/op_adapter.h +++ b/mindspore/ccsrc/transform/op_adapter.h @@ -26,17 +26,17 @@ #include "utils/utils.h" namespace mindspore { namespace transform { -static uint32_t CustomInferFunc(const Operator&) { return 0; } +static uint32_t CustomInferFunc(const Operator &) { return 0; } template class OpAdapter : public BaseOpAdapter { public: using OpType = T; OpAdapter() {} - explicit OpAdapter(const ExtraAttr& extra_attr) : extra_attr_(extra_attr) {} + explicit OpAdapter(const ExtraAttr &extra_attr) : extra_attr_(extra_attr) {} ~OpAdapter() override {} - bool IsCustomOp(const OperatorPtr& op) { + bool IsCustomOp(const OperatorPtr &op) { MS_EXCEPTION_IF_NULL(op); auto it = cus_input_map_.find(op->GetOpType()); if (it == cus_input_map_.end()) { @@ -45,7 +45,7 @@ class OpAdapter : public BaseOpAdapter { return true; } - Status GenerateCustomOpInputMap(const CusOperatorPtr& op, const PrimitivePtr& prim) { + Status GenerateCustomOpInputMap(const CusOperatorPtr &op, const PrimitivePtr &prim) { MS_EXCEPTION_IF_NULL(op); MS_EXCEPTION_IF_NULL(prim); // Create the map of custom op from input index to input name. @@ -69,7 +69,7 @@ class OpAdapter : public BaseOpAdapter { return SUCCESS; } - Status GenerateCustomOpOutputMap(const CusOperatorPtr& op, const PrimitivePtr& prim) { + Status GenerateCustomOpOutputMap(const CusOperatorPtr &op, const PrimitivePtr &prim) { MS_EXCEPTION_IF_NULL(op); MS_EXCEPTION_IF_NULL(prim); // Create the map of custom op from output index to output name. @@ -122,7 +122,7 @@ class OpAdapter : public BaseOpAdapter { return op; } - OperatorPtr GenerateNormalOp(const AnfNodePtr& anf) { + OperatorPtr GenerateNormalOp(const AnfNodePtr &anf) { OperatorPtr op = nullptr; // There are duplicate names in ANF graph, do not assign ANF node name to GE // GE will generate unique name automatically @@ -148,7 +148,7 @@ class OpAdapter : public BaseOpAdapter { return op; } - OperatorPtr generate(const AnfNodePtr& anf) override { + OperatorPtr generate(const AnfNodePtr &anf) override { OperatorPtr op = nullptr; if (IsCustomCNode(anf)) { op = GenerateCustomOp(anf); @@ -158,21 +158,21 @@ class OpAdapter : public BaseOpAdapter { return op; } - OperatorPtr generate(const std::string& op_name) override { return std::make_shared(op_name); } + OperatorPtr generate(const std::string &op_name) override { return std::make_shared(op_name); } - const std::unordered_map& getInputMap() override { return input_map_; } - const std::unordered_map& getInputAttrMap() override { return input_attr_map_; } - const std::unordered_map& getDynInputMap() override { return dyn_input_map_; } - const std::unordered_map& getOutputMap() override { return output_map_; } + const std::unordered_map &getInputMap() override { return input_map_; } + const std::unordered_map &getInputAttrMap() override { return input_attr_map_; } + const std::unordered_map &getDynInputMap() override { return dyn_input_map_; } + const std::unordered_map &getOutputMap() override { return output_map_; } - Status SetCustomOpInput(const CusOperatorPtr& op, int index, const OperatorPtr& input) { + Status SetCustomOpInput(const CusOperatorPtr &op, int index, const OperatorPtr &input) { MS_EXCEPTION_IF_NULL(op); MS_EXCEPTION_IF_NULL(input); auto it = cus_input_map_.find(op->GetOpType()); if (it == cus_input_map_.end()) { return NOT_FOUND; } - std::unordered_map& input_map = it->second; + std::unordered_map &input_map = it->second; if ((input_map.find(index) != input_map.end())) { MS_LOG(DEBUG) << "Link op " << input->GetName() << " to " << op->GetName() << ":" << input_map[index]; @@ -182,7 +182,7 @@ class OpAdapter : public BaseOpAdapter { return NOT_FOUND; } - Status SetNormalOpInput(const OperatorPtr& op, int index, const OperatorPtr& input) { + Status SetNormalOpInput(const OperatorPtr &op, int index, const OperatorPtr &input) { MS_EXCEPTION_IF_NULL(op); auto it = input_map_.find(index); if (it != input_map_.end()) { @@ -194,7 +194,7 @@ class OpAdapter : public BaseOpAdapter { return NOT_FOUND; } - int setInput(const OperatorPtr& op, int index, const OperatorPtr& input) override { + int setInput(const OperatorPtr &op, int index, const OperatorPtr &input) override { if (IsCustomOp(op)) { auto cus_op = std::dynamic_pointer_cast(op); return static_cast(SetCustomOpInput(cus_op, index, input)); @@ -203,14 +203,14 @@ class OpAdapter : public BaseOpAdapter { } } - Status SetCustomOpInput(const CusOperatorPtr& op, int index, const OutHandler& handle) { + Status SetCustomOpInput(const CusOperatorPtr &op, int index, const OutHandler &handle) { MS_EXCEPTION_IF_NULL(op); auto it = cus_input_map_.find(op->GetOpType()); if (it == cus_input_map_.end()) { return NOT_FOUND; } - std::unordered_map& input_map = it->second; + std::unordered_map &input_map = it->second; if ((handle.op != nullptr) && (input_map.find(index) != input_map.end())) { if (handle.out.empty()) { MS_LOG(DEBUG) << "Link op " << handle.op->GetName() << " to " << op->GetName() << ":" << input_map[index]; @@ -225,7 +225,7 @@ class OpAdapter : public BaseOpAdapter { return NOT_FOUND; } - Status SetNormalOpInput(const OperatorPtr& op, int index, const OutHandler& handle) { + Status SetNormalOpInput(const OperatorPtr &op, int index, const OutHandler &handle) { MS_EXCEPTION_IF_NULL(op); auto it = input_map_.find(index); if ((handle.op != nullptr) && (it != input_map_.end())) { @@ -242,7 +242,7 @@ class OpAdapter : public BaseOpAdapter { return NOT_FOUND; } - int setInput(const OperatorPtr& op, int index, const OutHandler& handle) override { + int setInput(const OperatorPtr &op, int index, const OutHandler &handle) override { if (IsCustomOp(op)) { auto cus_op = std::dynamic_pointer_cast(op); return static_cast(SetCustomOpInput(cus_op, index, handle)); @@ -251,7 +251,7 @@ class OpAdapter : public BaseOpAdapter { } } - int setInput(const OperatorPtr& op, int index, const std::shared_ptr>& handler_vec) override { + int setInput(const OperatorPtr &op, int index, const std::shared_ptr> &handler_vec) override { MS_EXCEPTION_IF_NULL(handler_vec); if (IsCustomOp(op)) { MS_LOG(ERROR) << "Custom Op do not support dynamic input"; @@ -278,7 +278,7 @@ class OpAdapter : public BaseOpAdapter { return static_cast(NOT_FOUND); } - OutHandler getOutput(const OperatorPtr& op, int index) override { + OutHandler getOutput(const OperatorPtr &op, int index) override { MS_EXCEPTION_IF_NULL(op); if (IsCustomOp(op)) { return getCustomOutput(op, index); @@ -286,7 +286,7 @@ class OpAdapter : public BaseOpAdapter { return getNormalOutput(op, index); } - OutHandler getCustomOutput(const OperatorPtr& op, int index) { + OutHandler getCustomOutput(const OperatorPtr &op, int index) { MS_EXCEPTION_IF_NULL(op); auto it = cus_output_map_.find(op->GetOpType()); if (it == cus_output_map_.end()) { @@ -294,7 +294,7 @@ class OpAdapter : public BaseOpAdapter { return OutHandler(); } - std::unordered_map& output_map = it->second; + std::unordered_map &output_map = it->second; if ((output_map.find(index) != output_map.end())) { return OutHandler(op, output_map[index]); @@ -303,7 +303,7 @@ class OpAdapter : public BaseOpAdapter { return OutHandler(); } - OutHandler getNormalOutput(const OperatorPtr& op, int index) { + OutHandler getNormalOutput(const OperatorPtr &op, int index) { MS_EXCEPTION_IF_NULL(op); if (!dyn_output_map_.empty() && !output_map_.empty()) { MS_LOG(ERROR) << "OpAdpator(" << op->GetName() << ") has both OUTPUT and DYN_OUTPUT is not supported!"; @@ -320,7 +320,7 @@ class OpAdapter : public BaseOpAdapter { } } - Status UpdateSingleOutputDesc(const OperatorPtr& op, const abstract::BaseShapePtr& shp, const TypePtr& type) { + Status UpdateSingleOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type) { MS_EXCEPTION_IF_NULL(type); std::string format = "NCHW"; if (op->GetOpType() == kExtractImagePatchesOpName) { @@ -353,7 +353,7 @@ class OpAdapter : public BaseOpAdapter { return SUCCESS; } - size_t GetCustomOpOutputSize(const CusOperatorPtr& cus_op) { + size_t GetCustomOpOutputSize(const CusOperatorPtr &cus_op) { MS_EXCEPTION_IF_NULL(cus_op); if (cus_output_map_.find(cus_op->GetOpType()) == cus_output_map_.end()) { MS_LOG(ERROR) << "This op does not create custom output map"; @@ -363,8 +363,8 @@ class OpAdapter : public BaseOpAdapter { return output_size; } - std::shared_ptr CreateOutputDesc(const abstract::ShapePtr& shape_ptr, const TypePtr& type, - const std::string& format) { + std::shared_ptr CreateOutputDesc(const abstract::ShapePtr &shape_ptr, const TypePtr &type, + const std::string &format) { if (shape_ptr == nullptr) { MS_LOG(ERROR) << "Shape ptr is nullptr"; return nullptr; @@ -383,7 +383,7 @@ class OpAdapter : public BaseOpAdapter { return desc; } - Status UpdateMultiOutputDesc(const OperatorPtr& op, const abstract::BaseShapePtr& shp, const TypePtr& type) { + Status UpdateMultiOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type) { auto tuple_shp = dyn_cast(shp); MS_EXCEPTION_IF_NULL(tuple_shp); @@ -432,7 +432,7 @@ class OpAdapter : public BaseOpAdapter { return SUCCESS; } - std::shared_ptr CreateNodeDesc(const AnfNodePtr& node) { + std::shared_ptr CreateNodeDesc(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); TypeId me_type = node->Type()->type_id(); if (kObjectTypeTensorType == me_type) { @@ -456,7 +456,7 @@ class OpAdapter : public BaseOpAdapter { return desc; } - void UpdateNormalOpInputDesc(const OperatorPtr& op, const AnfNodePtr node) { + void UpdateNormalOpInputDesc(const OperatorPtr &op, const AnfNodePtr node) { if (op == nullptr) { MS_LOG(ERROR) << "op is nullptr"; return; @@ -479,7 +479,7 @@ class OpAdapter : public BaseOpAdapter { } } - void UpdateCustomOpInputDesc(const CusOperatorPtr& op, const AnfNodePtr& node) { + void UpdateCustomOpInputDesc(const CusOperatorPtr &op, const AnfNodePtr &node) { if (op == nullptr) { MS_LOG(ERROR) << "op is nullptr"; return; @@ -491,7 +491,7 @@ class OpAdapter : public BaseOpAdapter { return; } - std::unordered_map& input_map = cus_input_map_[op->GetOpType()]; + std::unordered_map &input_map = cus_input_map_[op->GetOpType()]; auto inputs = node->cast()->inputs(); for (size_t i = 1; i < inputs.size(); ++i) { if (input_map.find(i) != input_map.end()) { @@ -504,7 +504,7 @@ class OpAdapter : public BaseOpAdapter { } } - void updateInputDesc(const OperatorPtr& op, const AnfNodePtr& node) { + void updateInputDesc(const OperatorPtr &op, const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(op); MS_EXCEPTION_IF_NULL(node); if (IsCustomOp(op)) { @@ -515,8 +515,8 @@ class OpAdapter : public BaseOpAdapter { } } - void updateOutputDesc(const OperatorPtr& op, const abstract::BaseShapePtr& shp, const TypePtr& type, - const AnfNodePtr& node) override { + void updateOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type, + const AnfNodePtr &node) override { if (op == nullptr) { MS_LOG(ERROR) << "op is nullptr"; return; @@ -548,7 +548,7 @@ class OpAdapter : public BaseOpAdapter { updateInputDesc(op, node); } - int setAttr(const OperatorPtr& op, const std::string& attrKey, const ValuePtr& attrValue) override { + int setAttr(const OperatorPtr &op, const std::string &attrKey, const ValuePtr &attrValue) override { auto it = attr_map_.find(attrKey); if (it != attr_map_.end()) { // switch case for each avalilable attribute type @@ -560,7 +560,7 @@ class OpAdapter : public BaseOpAdapter { return static_cast(NOT_FOUND); } - int SetCustomOpAttr(const CusOperatorPtr& op, const PrimitivePtr& prim) { + int SetCustomOpAttr(const CusOperatorPtr &op, const PrimitivePtr &prim) { enum ValueType { SINGLE_VALUE = 0, SEQUEUE_VALUE, @@ -611,11 +611,11 @@ class OpAdapter : public BaseOpAdapter { return 0; } - int SetNormalOpAttr(const OperatorPtr& op, const PrimitivePtr& prim) { + int SetNormalOpAttr(const OperatorPtr &op, const PrimitivePtr &prim) { int ret = 0; MS_EXCEPTION_IF_NULL(prim); MS_EXCEPTION_IF_NULL(op); - for (auto& it : attr_map_) { + for (auto &it : attr_map_) { auto value = prim->GetAttr(it.first); if (value != nullptr) { // set attr from primitive @@ -637,7 +637,7 @@ class OpAdapter : public BaseOpAdapter { return 0; } - int setAttr(const OperatorPtr& op, const PrimitivePtr& prim) override { + int setAttr(const OperatorPtr &op, const PrimitivePtr &prim) override { int ret = 0; if (IsCustomPrim(prim)) { auto cus_op = std::dynamic_pointer_cast(op); @@ -648,7 +648,7 @@ class OpAdapter : public BaseOpAdapter { return ret; } - int setAttr(const OperatorPtr& op, const AnfNodePtr& node) override { + int setAttr(const OperatorPtr &op, const AnfNodePtr &node) override { // no attribute for lonely node MS_EXCEPTION_IF_NULL(node); if (!node->isa()) { @@ -660,7 +660,7 @@ class OpAdapter : public BaseOpAdapter { return 0; } - auto& inputs = cnode->inputs(); + auto &inputs = cnode->inputs(); if (inputs.empty()) { return 0; } @@ -691,7 +691,7 @@ class OpAdapter : public BaseOpAdapter { } // set attr from const input - for (auto& it : input_attr_map_) { + for (auto &it : input_attr_map_) { if (inputs.size() <= it.first || !inputs[it.first]->isa()) { continue; } @@ -711,38 +711,38 @@ class OpAdapter : public BaseOpAdapter { private: template - static S ConvertAny(const ValuePtr& value, const AnyTraits&) { + static S ConvertAny(const ValuePtr &value, const AnyTraits &) { return GetValue(value); } // specialization for reverse bool - static bool ConvertAny(const ValuePtr& value, const AnyTraits&, bool reverse) { + static bool ConvertAny(const ValuePtr &value, const AnyTraits &, bool reverse) { return reverse != GetValue(value); } template - static Q ConvertAny(const ValuePtr& value, const AnyTraits

& traits_from, const AnyTraits& traits_to) { + static Q ConvertAny(const ValuePtr &value, const AnyTraits

&traits_from, const AnyTraits &traits_to) { return ConvertAnyUtil(value, traits_from, traits_to); } // specialization for tensor - static GeTensor ConvertAny(const ValuePtr& value, const AnyTraits& traits) { + static GeTensor ConvertAny(const ValuePtr &value, const AnyTraits &traits) { // To-DO the format may read from ME tensor return ConvertAnyUtil(value, traits); } // specialization for int - static int64_t ConvertAny(const ValuePtr& value, const AnyTraits) { + static int64_t ConvertAny(const ValuePtr &value, const AnyTraits) { return static_cast(GetValue(value)); } // specialization for int to Vector - static std::vector ConvertAny(const ValuePtr& value, const std::string& name, + static std::vector ConvertAny(const ValuePtr &value, const std::string &name, const AnyTraits> anyTraitsInt) { return ConvertAnyUtil(value, name, anyTraitsInt); } - static std::vector> ConvertAny(const ValuePtr& value, + static std::vector> ConvertAny(const ValuePtr &value, const AnyTraits>>) { MS_EXCEPTION_IF_NULL(value); MS_LOG(INFO) << "Value: " << value->type_name(); @@ -752,14 +752,14 @@ class OpAdapter : public BaseOpAdapter { } auto vec = value->cast(); MS_EXCEPTION_IF_NULL(vec); - for (auto& it : vec->value()) { + for (auto &it : vec->value()) { MS_EXCEPTION_IF_NULL(it); if (!it->isa()) { MS_LOG(EXCEPTION) << "It should be ValueTuple, but got " << it->type_name(); } auto sub_vector = it->cast(); std::vector sublist; - for (auto& item : sub_vector->value()) { + for (auto &item : sub_vector->value()) { sublist.push_back(static_cast(GetValue(item))); } list.push_back(sublist); @@ -767,7 +767,7 @@ class OpAdapter : public BaseOpAdapter { return list; } - static std::vector ConvertAny(const ValuePtr& value, const AnyTraits>>, + static std::vector ConvertAny(const ValuePtr &value, const AnyTraits>>, const AnyTraits>) { MS_EXCEPTION_IF_NULL(value); MS_LOG(DEBUG) << "Value: " << value->type_name(); @@ -776,20 +776,20 @@ class OpAdapter : public BaseOpAdapter { } auto vec = value->cast(); std::vector list; - for (auto& it : vec->value()) { + for (auto &it : vec->value()) { MS_EXCEPTION_IF_NULL(it); if (!it->isa()) { MS_LOG(EXCEPTION) << "It should be ValueList, but got " << it->type_name(); } auto sub_vector = it->cast(); - for (auto& item : sub_vector->value()) { + for (auto &item : sub_vector->value()) { list.push_back(static_cast(GetValue(item))); } } return list; } - static std::vector ConvertAny(const ValuePtr& value, const AnyTraits>, + static std::vector ConvertAny(const ValuePtr &value, const AnyTraits>, const AnyTraits>) { MS_EXCEPTION_IF_NULL(value); MS_LOG(INFO) << "Value: " << value->type_name(); @@ -797,7 +797,7 @@ class OpAdapter : public BaseOpAdapter { if (value->isa()) { auto vec = value->cast(); MS_EXCEPTION_IF_NULL(vec); - for (auto& it : vec->value()) { + for (auto &it : vec->value()) { list.push_back(static_cast(GetValue(it))); } return list; @@ -809,17 +809,17 @@ class OpAdapter : public BaseOpAdapter { MS_LOG(EXCEPTION) << "Value should be ValueTuple or Scalar, but got " << value->type_name(); } - static std::string ConvertAny(const ValuePtr& value, const AnyTraits> anyTraitsVec, + static std::string ConvertAny(const ValuePtr &value, const AnyTraits> anyTraitsVec, const AnyTraits anyTraitsStr) { return ConvertAnyUtil(value, anyTraitsVec, anyTraitsStr); } - static std::vector ConvertAny(const ValuePtr& value, const AnyTraits> anyTraitsVec, + static std::vector ConvertAny(const ValuePtr &value, const AnyTraits> anyTraitsVec, const AnyTraits anyTraitsFlo) { return ConvertAnyUtil(value, anyTraitsVec, anyTraitsFlo); } - static std::vector ConvertAny(const ValuePtr& value, const std::string& format, + static std::vector ConvertAny(const ValuePtr &value, const std::string &format, const AnyTraits> anyTraitsVec, const AnyTraits anyTraitsInt) { return ConvertAnyUtil(value, format, anyTraitsVec, anyTraitsInt); @@ -827,12 +827,12 @@ class OpAdapter : public BaseOpAdapter { // convert value list for value tuple to vector template - static std::vector ConvertAny(const ValuePtr& value, const AnyTraits

& anyTraitsP, + static std::vector ConvertAny(const ValuePtr &value, const AnyTraits

&anyTraitsP, const AnyTraits> anyTraitsQ) { return ConvertAnyUtil(value, anyTraitsP, anyTraitsQ); } - static int64_t ConvertAny(const ValuePtr& value, const AnyTraits) { + static int64_t ConvertAny(const ValuePtr &value, const AnyTraits) { auto name = GetValue(value); auto it = enum_map_.find(name); int v = 0; @@ -842,12 +842,12 @@ class OpAdapter : public BaseOpAdapter { return v; } - static GeDataType ConvertAny(const ValuePtr& value, const AnyTraits anyTraitsGE) { + static GeDataType ConvertAny(const ValuePtr &value, const AnyTraits anyTraitsGE) { return ConvertAnyUtil(value, anyTraitsGE); } // convert any value to tensor - static GeTensor ConvertAny(const ValuePtr& value, const AnyTraits anyTraitsValue) { + static GeTensor ConvertAny(const ValuePtr &value, const AnyTraits anyTraitsValue) { return ConvertAnyUtil(value, anyTraitsValue); } diff --git a/mindspore/ccsrc/transform/op_adapter_base.h b/mindspore/ccsrc/transform/op_adapter_base.h index 99106b8761..01f96e251d 100644 --- a/mindspore/ccsrc/transform/op_adapter_base.h +++ b/mindspore/ccsrc/transform/op_adapter_base.h @@ -48,15 +48,17 @@ namespace ge { class CustomOperator : public Operator { public: - CustomOperator(const string& name, const string& type) : Operator(name, type) {} + CustomOperator(const string &name, const string &type) : Operator(name, type) {} ~CustomOperator() override{}; - void CustomInputRegister(const string& name) { Operator::InputRegister(name); } + void CustomInputRegister(const string &name) { Operator::InputRegister(name); } - void CustomOutputRegister(const string& name) { Operator::OutputRegister(name); } + void CustomOutputRegister(const string &name) { Operator::OutputRegister(name); } - void CustomInferFuncRegister(const std::function& func) { Operator::InferFuncRegister(func); } + void CustomInferFuncRegister(const std::function &func) { + Operator::InferFuncRegister(func); + } }; } // namespace ge @@ -69,7 +71,7 @@ struct OutHandler { OperatorPtr op; std::string out; OutHandler() : op(nullptr), out("") {} - OutHandler(const OperatorPtr& op, const std::string out) : op(op), out(out) {} + OutHandler(const OperatorPtr &op, const std::string out) : op(op), out(out) {} }; struct ControlEdge { @@ -119,33 +121,33 @@ struct DynOutputDesc { class BaseOpAdapter { public: virtual ~BaseOpAdapter() {} - virtual OperatorPtr generate(const AnfNodePtr& anf) = 0; - virtual OperatorPtr generate(const std::string& type) { return std::make_shared(type); } - virtual int setInput(const OperatorPtr& op, int index, const OperatorPtr& input) = 0; - virtual int setInput(const OperatorPtr& op, int index, const OutHandler& handle) = 0; - virtual int setInput(const OperatorPtr& op, int index, - const std::shared_ptr>& handler_vec) = 0; - virtual int setAttr(const OperatorPtr& op, const std::string& attrKey, const ValuePtr& attrValue) = 0; - virtual int setAttr(const OperatorPtr& op, const PrimitivePtr& prim) = 0; - virtual int setAttr(const OperatorPtr& op, const AnfNodePtr& node) = 0; + virtual OperatorPtr generate(const AnfNodePtr &anf) = 0; + virtual OperatorPtr generate(const std::string &type) { return std::make_shared(type); } + virtual int setInput(const OperatorPtr &op, int index, const OperatorPtr &input) = 0; + virtual int setInput(const OperatorPtr &op, int index, const OutHandler &handle) = 0; + virtual int setInput(const OperatorPtr &op, int index, + const std::shared_ptr> &handler_vec) = 0; + virtual int setAttr(const OperatorPtr &op, const std::string &attrKey, const ValuePtr &attrValue) = 0; + virtual int setAttr(const OperatorPtr &op, const PrimitivePtr &prim) = 0; + virtual int setAttr(const OperatorPtr &op, const AnfNodePtr &node) = 0; virtual std::unordered_map GetExtraAttr() = 0; template ::value>::type> - int setAttr(const OperatorPtr& op, const std::string& attrKey, const std::shared_ptr& attrValue) { + int setAttr(const OperatorPtr &op, const std::string &attrKey, const std::shared_ptr &attrValue) { return setAttr(op, attrKey, MakeValue(attrValue)); } template ::value>::type> - int setAttr(const OperatorPtr& op, const std::string& attrKey, const T& attrValue) { + int setAttr(const OperatorPtr &op, const std::string &attrKey, const T &attrValue) { return setAttr(op, attrKey, MakeValue(attrValue)); } - virtual OutHandler getOutput(const OperatorPtr& op, int index) = 0; - virtual void updateOutputDesc(const OperatorPtr& op, const abstract::BaseShapePtr& shp, const TypePtr& type, - const AnfNodePtr& node) = 0; - virtual const std::unordered_map& getInputMap() = 0; - virtual const std::unordered_map& getInputAttrMap() = 0; - virtual const std::unordered_map& getDynInputMap() = 0; - virtual const std::unordered_map& getOutputMap() = 0; - void AddAttrToDrawGraph(const std::string& attr_str) { attrs_vec_.push_back(attr_str); } - const std::vector& GetAttrsFromDrawGraph() const { return attrs_vec_; } + virtual OutHandler getOutput(const OperatorPtr &op, int index) = 0; + virtual void updateOutputDesc(const OperatorPtr &op, const abstract::BaseShapePtr &shp, const TypePtr &type, + const AnfNodePtr &node) = 0; + virtual const std::unordered_map &getInputMap() = 0; + virtual const std::unordered_map &getInputAttrMap() = 0; + virtual const std::unordered_map &getDynInputMap() = 0; + virtual const std::unordered_map &getOutputMap() = 0; + void AddAttrToDrawGraph(const std::string &attr_str) { attrs_vec_.push_back(attr_str); } + const std::vector &GetAttrsFromDrawGraph() const { return attrs_vec_; } void clearAttrVect() { attrs_vec_.clear(); } private: diff --git a/mindspore/ccsrc/transform/op_adapter_util.cc b/mindspore/ccsrc/transform/op_adapter_util.cc index d52699fa8f..0163b80f08 100644 --- a/mindspore/ccsrc/transform/op_adapter_util.cc +++ b/mindspore/ccsrc/transform/op_adapter_util.cc @@ -25,7 +25,7 @@ namespace mindspore { namespace transform { -GeTensor ConvertAnyUtil(const ValuePtr& value, const AnyTraits&) { +GeTensor ConvertAnyUtil(const ValuePtr &value, const AnyTraits &) { // To-DO the format may read from ME tensor MS_EXCEPTION_IF_NULL(value); auto me_tensor = value->cast(); @@ -33,7 +33,7 @@ GeTensor ConvertAnyUtil(const ValuePtr& value, const AnyTraits ConvertAnyUtil(const ValuePtr& value, const std::string& name, +std::vector ConvertAnyUtil(const ValuePtr &value, const std::string &name, const AnyTraits>) { int64_t data = GetValue(value); std::vector list; @@ -50,7 +50,7 @@ std::vector ConvertAnyUtil(const ValuePtr& value, const std::string& na return list; } -std::string ConvertAnyUtil(const ValuePtr& value, const AnyTraits>, const AnyTraits) { +std::string ConvertAnyUtil(const ValuePtr &value, const AnyTraits>, const AnyTraits) { MS_EXCEPTION_IF_NULL(value); auto vec = value->cast(); if (nullptr == vec) { @@ -58,7 +58,7 @@ std::string ConvertAnyUtil(const ValuePtr& value, const AnyTraitsvalue()) { + for (auto &it : vec->value()) { if (i != 0) { buffer << ","; } @@ -68,7 +68,7 @@ std::string ConvertAnyUtil(const ValuePtr& value, const AnyTraits ConvertAnyUtil(const ValuePtr& value, const AnyTraits>, const AnyTraits) { +std::vector ConvertAnyUtil(const ValuePtr &value, const AnyTraits>, const AnyTraits) { MS_EXCEPTION_IF_NULL(value); auto vec = value->cast(); if (nullptr == vec) { @@ -77,11 +77,11 @@ std::vector ConvertAnyUtil(const ValuePtr& value, const AnyTraits list; list.resize(vec->value().size()); (void)std::transform(vec->value().begin(), vec->value().end(), list.begin(), - [](const ValuePtr& val) { return static_cast(GetValue(val)); }); + [](const ValuePtr &val) { return static_cast(GetValue(val)); }); return list; } -std::vector ConvertAnyUtil(const ValuePtr& value, const std::string& format, +std::vector ConvertAnyUtil(const ValuePtr &value, const std::string &format, const AnyTraits>, const AnyTraits) { MS_EXCEPTION_IF_NULL(value); auto vec = value->cast(); @@ -91,7 +91,7 @@ std::vector ConvertAnyUtil(const ValuePtr& value, const std::string& fo std::vector list; list.resize(vec->value().size()); (void)std::transform(vec->value().begin(), vec->value().end(), list.begin(), - [](const ValuePtr& val) { return static_cast(GetValue(val)); }); + [](const ValuePtr &val) { return static_cast(GetValue(val)); }); if (format == kOpFormat_NHWC) { if (list.size() < 4) { MS_LOG(EXCEPTION) << "The size of list is less than 4"; @@ -105,7 +105,7 @@ std::vector ConvertAnyUtil(const ValuePtr& value, const std::string& fo return list; } -GeDataType ConvertAnyUtil(const ValuePtr& value, const AnyTraits) { +GeDataType ConvertAnyUtil(const ValuePtr &value, const AnyTraits) { MS_EXCEPTION_IF_NULL(value); if (!value->isa()) { MS_LOG(EXCEPTION) << "error convert Value to TypePtr for value: " << value->ToString() @@ -120,7 +120,7 @@ GeDataType ConvertAnyUtil(const ValuePtr& value, const AnyTraits) { return TransformUtil::ConvertDataType(me_type); } -GeTensor VectorToTensorUtil(const ValuePtr& value) { +GeTensor VectorToTensorUtil(const ValuePtr &value) { // convert tuple or list to ge tensor, only supported one dim for now MS_EXCEPTION_IF_NULL(value); auto vec = value->isa() ? value->cast()->value() : value->cast()->value(); @@ -136,7 +136,7 @@ GeTensor VectorToTensorUtil(const ValuePtr& value) { if (desc == nullptr) { MS_LOG(EXCEPTION) << "Update conversion descriptor failed!"; } - return GeTensor(*desc, reinterpret_cast(data.data()), data.size() * sizeof(int32_t)); + return GeTensor(*desc, reinterpret_cast(data.data()), data.size() * sizeof(int32_t)); } else if (vec[0]->isa()) { MS_LOG(INFO) << "convert value to tensor with data type = Float32"; auto data = ConvertAnyUtil(value, AnyTraits(), AnyTraits>()); @@ -144,7 +144,7 @@ GeTensor VectorToTensorUtil(const ValuePtr& value) { if (desc == nullptr) { MS_LOG(EXCEPTION) << "Update conversion descriptor failed!"; } - return GeTensor(*desc, reinterpret_cast(data.data()), data.size() * sizeof(float)); + return GeTensor(*desc, reinterpret_cast(data.data()), data.size() * sizeof(float)); } else if (vec[0]->isa()) { MS_LOG(INFO) << "convert value to tensor with data type = Bool"; // We use uint8_t to save bool type data @@ -153,7 +153,7 @@ GeTensor VectorToTensorUtil(const ValuePtr& value) { if (desc == nullptr) { MS_LOG(EXCEPTION) << "Update conversion descriptor failed!"; } - return GeTensor(*desc, static_cast(data.data()), data.size() * sizeof(uint8_t)); + return GeTensor(*desc, static_cast(data.data()), data.size() * sizeof(uint8_t)); } else { MS_LOG(EXCEPTION) << "Unsupported data type of tuple or list elements: " << vec[0]->type_name(); } @@ -161,7 +161,7 @@ GeTensor VectorToTensorUtil(const ValuePtr& value) { return GeTensor(); } -GeTensor ConvertAnyUtil(const ValuePtr& value, const AnyTraits) { +GeTensor ConvertAnyUtil(const ValuePtr &value, const AnyTraits) { MS_EXCEPTION_IF_NULL(value); if (value->isa()) { // convert me tensor to ge tensor @@ -174,28 +174,28 @@ GeTensor ConvertAnyUtil(const ValuePtr& value, const AnyTraits) { GeTensorDesc desc(GeShape(), ge::FORMAT_NCHW, ge::DT_INT32); auto v = GetValue(value); desc.SetRealDimCnt(0); - return GeTensor(desc, reinterpret_cast(&v), sizeof(int32_t)); + return GeTensor(desc, reinterpret_cast(&v), sizeof(int32_t)); } else if (value->isa()) { // convert scalar Int64 to GeTensor MS_LOG(INFO) << "convert scalar to tensor with data type = Int64"; GeTensorDesc desc(GeShape(), ge::FORMAT_NCHW, ge::DT_INT64); auto v = GetValue(value); desc.SetRealDimCnt(0); - return GeTensor(desc, reinterpret_cast(&v), sizeof(int64_t)); + return GeTensor(desc, reinterpret_cast(&v), sizeof(int64_t)); } else if (value->isa()) { // convert scalar FP32 to GeTensor MS_LOG(INFO) << "convert scalar to tensor with data type = FP32"; GeTensorDesc desc(GeShape(), ge::FORMAT_NCHW, ge::DT_FLOAT); auto v = GetValue(value); desc.SetRealDimCnt(0); - return GeTensor(desc, reinterpret_cast(&v), sizeof(float)); + return GeTensor(desc, reinterpret_cast(&v), sizeof(float)); } else if (value->isa()) { // convert scalar FP32 to GeTensor MS_LOG(INFO) << "convert scalar to tensor with data type = Bool"; GeTensorDesc desc(GeShape(), ge::FORMAT_NCHW, ge::DT_BOOL); auto v = GetValue(value); desc.SetRealDimCnt(0); - return GeTensor(desc, reinterpret_cast(&v), sizeof(bool)); + return GeTensor(desc, reinterpret_cast(&v), sizeof(bool)); } else if (value->isa()) { // convert String to GeTensor MS_LOG(INFO) << "convert string to tensor with data type = String"; @@ -213,7 +213,7 @@ GeTensor ConvertAnyUtil(const ValuePtr& value, const AnyTraits) { return GeTensor(); } -bool IsCustomPrim(const PrimitivePtr& prim) { +bool IsCustomPrim(const PrimitivePtr &prim) { if (prim == nullptr) { return false; } @@ -232,7 +232,7 @@ bool IsCustomPrim(const PrimitivePtr& prim) { return is_custom_op; } -bool IsCustomCNode(const AnfNodePtr& anf) { +bool IsCustomCNode(const AnfNodePtr &anf) { if (anf == nullptr) { return false; } diff --git a/mindspore/ccsrc/transform/op_adapter_util.h b/mindspore/ccsrc/transform/op_adapter_util.h index 0cb6c763b2..fcabc732d5 100644 --- a/mindspore/ccsrc/transform/op_adapter_util.h +++ b/mindspore/ccsrc/transform/op_adapter_util.h @@ -25,42 +25,42 @@ namespace mindspore { namespace transform { template -static Q ConvertAnyUtil(const ValuePtr& value, const AnyTraits

&, const AnyTraits&) { +static Q ConvertAnyUtil(const ValuePtr &value, const AnyTraits

&, const AnyTraits &) { return static_cast(GetValue

(value)); } -GeTensor ConvertAnyUtil(const ValuePtr& value, const AnyTraits& traits); +GeTensor ConvertAnyUtil(const ValuePtr &value, const AnyTraits &traits); -std::vector ConvertAnyUtil(const ValuePtr& value, const std::string& name, +std::vector ConvertAnyUtil(const ValuePtr &value, const std::string &name, const AnyTraits>); -std::string ConvertAnyUtil(const ValuePtr& value, const AnyTraits>, const AnyTraits); +std::string ConvertAnyUtil(const ValuePtr &value, const AnyTraits>, const AnyTraits); -std::vector ConvertAnyUtil(const ValuePtr& value, const AnyTraits>, const AnyTraits); +std::vector ConvertAnyUtil(const ValuePtr &value, const AnyTraits>, const AnyTraits); -std::vector ConvertAnyUtil(const ValuePtr& value, const std::string& format, +std::vector ConvertAnyUtil(const ValuePtr &value, const std::string &format, const AnyTraits>, const AnyTraits); -GeDataType ConvertAnyUtil(const ValuePtr& value, const AnyTraits); +GeDataType ConvertAnyUtil(const ValuePtr &value, const AnyTraits); template -std::vector ConvertAnyUtil(const ValuePtr& value, AnyTraits

, const AnyTraits>) { +std::vector ConvertAnyUtil(const ValuePtr &value, AnyTraits

, const AnyTraits>) { if (!value->isa() && !value->isa()) { MS_LOG(EXCEPTION) << "error convert Value to vector for value: " << value->ToString() << ", type: " << value->type_name() << ", value should be a tuple or list"; } auto vec = value->isa() ? value->cast()->value() : value->cast()->value(); std::vector data; - for (auto& it : vec) { + for (auto &it : vec) { data.push_back(ConvertAnyUtil(it, AnyTraits

(), AnyTraits())); } return data; } -GeTensor ConvertAnyUtil(const ValuePtr& value, const AnyTraits); +GeTensor ConvertAnyUtil(const ValuePtr &value, const AnyTraits); -bool IsCustomPrim(const PrimitivePtr& prim); -bool IsCustomCNode(const AnfNodePtr& node); +bool IsCustomPrim(const PrimitivePtr &prim); +bool IsCustomCNode(const AnfNodePtr &node); } // namespace transform } // namespace mindspore #endif // TRANSFORM_OP_ADAPTER_UTIL_H_ diff --git a/mindspore/ccsrc/transform/util.cc b/mindspore/ccsrc/transform/util.cc index 0a18763d12..b1120ade6d 100644 --- a/mindspore/ccsrc/transform/util.cc +++ b/mindspore/ccsrc/transform/util.cc @@ -53,7 +53,7 @@ static std::map datatype_trans_map = { {MeDataType::kNumberTypeUInt16, GeDataType::DT_UINT16}, {MeDataType::kNumberTypeUInt32, GeDataType::DT_UINT32}, {MeDataType::kNumberTypeUInt64, GeDataType::DT_UINT64}, {MeDataType::kNumberTypeBool, GeDataType::DT_BOOL}}; -GeDataType TransformUtil::ConvertDataType(const MeDataType& type) { +GeDataType TransformUtil::ConvertDataType(const MeDataType &type) { MS_LOG(DEBUG) << "Convert me data type: " << TypeIdLabel(type) << " to ge data type"; if (datatype_trans_map.find(type) != datatype_trans_map.end()) { return datatype_trans_map[type]; @@ -70,7 +70,7 @@ static std::map datatype_size_map = { {MeDataType::kNumberTypeUInt16, sizeof(uint16_t)}, {MeDataType::kNumberTypeUInt32, sizeof(uint32_t)}, {MeDataType::kNumberTypeUInt64, sizeof(uint64_t)}, {MeDataType::kNumberTypeBool, sizeof(bool)}}; -size_t TransformUtil::GetDataTypeSize(const MeDataType& type) { +size_t TransformUtil::GetDataTypeSize(const MeDataType &type) { if (datatype_size_map.find(type) != datatype_size_map.end()) { return datatype_size_map[type]; } else { @@ -79,7 +79,7 @@ size_t TransformUtil::GetDataTypeSize(const MeDataType& type) { } } -GeFormat TransformUtil::ConvertFormat(const string& format) { +GeFormat TransformUtil::ConvertFormat(const string &format) { if (format == kOpFormat_NCHW) { return GeFormat::FORMAT_NCHW; } else if (format == kOpFormat_NC1HWC0) { @@ -95,8 +95,8 @@ GeFormat TransformUtil::ConvertFormat(const string& format) { static int64_t IntegerCastFunc(size_t temp) { return static_cast(temp); } -std::shared_ptr TransformUtil::GetGeTensorDesc(const std::vector& me_shape, - const MeDataType& me_type, const std::string& format) { +std::shared_ptr TransformUtil::GetGeTensorDesc(const std::vector &me_shape, + const MeDataType &me_type, const std::string &format) { // convert me shape to ge shape std::vector ge_shape; @@ -135,8 +135,8 @@ std::shared_ptr TransformUtil::GetGeTensorDesc(const std::vector TransformUtil::ConvertInputTensors(const std::vector& me_tensors, - const std::string& format) { +std::vector TransformUtil::ConvertInputTensors(const std::vector &me_tensors, + const std::string &format) { std::vector ge_tensors; for (size_t index = 0; index < me_tensors.size(); index++) { @@ -163,7 +163,7 @@ std::vector TransformUtil::ConvertInputTensors(const std::vectordata_type()); @@ -192,15 +192,15 @@ GeTensorPtr TransformUtil::ConvertTensor(const MeTensorPtr& tensor, const std::s MS_LOG(ERROR) << "Failed to get Tensor Desc"; return nullptr; } - GeTensorPtr tensor_ptr = make_shared(*desc, static_cast(tensor->data_c()), data_buff_size); + GeTensorPtr tensor_ptr = make_shared(*desc, static_cast(tensor->data_c()), data_buff_size); if (tensor_ptr != nullptr) { MS_LOG(INFO) << "Convert Me Tensor to Ge Tensor success!"; } return tensor_ptr; } -std::vector TransformUtil::ConvertGeTensors(const std::vector& ge_tensors, - const std::vector>& request_dims) { +std::vector TransformUtil::ConvertGeTensors(const std::vector &ge_tensors, + const std::vector> &request_dims) { std::vector outputs; for (size_t index = 0; index < ge_tensors.size(); index++) { @@ -222,7 +222,7 @@ std::vector TransformUtil::ConvertGeTensors(const std::vector TransformUtil::ConvertGeTensors(const std::vector& ge_tensors) { +std::vector TransformUtil::ConvertGeTensors(const std::vector &ge_tensors) { std::vector outputs; for (size_t index = 0; index < ge_tensors.size(); index++) { @@ -237,7 +237,7 @@ std::vector TransformUtil::ConvertGeTensors(const std::vector& request_dims) { +bool IsGeShapeCompatible(const GeShape &ge_shape, const std::vector &request_dims) { MS_LOG(INFO) << "GeTensor's shape is " << TransformUtil::PrintVector(ge_shape.GetDims()); MS_LOG(INFO) << "Me request shape is " << TransformUtil::PrintVector(request_dims); @@ -311,20 +311,20 @@ bool IsGeShapeCompatible(const GeShape& ge_shape, const std::vector& reques } } // namespace -GeShape TransformUtil::ConvertMeShape(const std::vector& me_dims) { +GeShape TransformUtil::ConvertMeShape(const std::vector &me_dims) { std::vector ge_dims; (void)std::copy(me_dims.begin(), me_dims.end(), std::back_inserter(ge_dims)); return GeShape(ge_dims); } -std::vector TransformUtil::ConvertGeShape(const GeShape& ge_shape) { +std::vector TransformUtil::ConvertGeShape(const GeShape &ge_shape) { std::vector me_dims; std::vector ge_dims = ge_shape.GetDims(); (void)std::copy(ge_dims.begin(), ge_dims.end(), std::back_inserter(me_dims)); return me_dims; } -std::vector TransformUtil::ConvertGeShape(const GeShape& ge_shape, const std::vector& request_dims) { +std::vector TransformUtil::ConvertGeShape(const GeShape &ge_shape, const std::vector &request_dims) { vector ret; if (ge_shape.GetDimNum() == 0) { MS_LOG(DEBUG) << "GeTensor's shape is scalar"; @@ -340,12 +340,12 @@ std::vector TransformUtil::ConvertGeShape(const GeShape& ge_shape, const st return ret; } -MeTensorPtr TransformUtil::GenerateMeTensor(const GeTensorPtr& ge_tensor, const std::vector& me_dims, - const TypeId& me_type) { +MeTensorPtr TransformUtil::GenerateMeTensor(const GeTensorPtr &ge_tensor, const std::vector &me_dims, + const TypeId &me_type) { MeTensor me_tensor(me_type, me_dims); // Get the writable data pointer of the tensor and cast it to its data type - auto me_data_ptr = reinterpret_cast(me_tensor.data_c(true)); + auto me_data_ptr = reinterpret_cast(me_tensor.data_c(true)); size_t me_data_size = static_cast(me_tensor.data().nbytes()); MS_EXCEPTION_IF_NULL(me_data_ptr); MS_EXCEPTION_IF_NULL(ge_tensor); @@ -369,7 +369,7 @@ MeTensorPtr TransformUtil::GenerateMeTensor(const GeTensorPtr& ge_tensor, const return make_shared(me_tensor); } -MeTensorPtr TransformUtil::ConvertGeTensor(const GeTensorPtr& ge_tensor) { +MeTensorPtr TransformUtil::ConvertGeTensor(const GeTensorPtr &ge_tensor) { MS_EXCEPTION_IF_NULL(ge_tensor); GeShape ge_shape = ge_tensor->GetTensorDesc().GetShape(); vector me_dims = ConvertGeShape(ge_shape); @@ -384,7 +384,7 @@ MeTensorPtr TransformUtil::ConvertGeTensor(const GeTensorPtr& ge_tensor) { } // if request_dims is empty, use ge tensor's shape,otherwise convert to request shape -MeTensorPtr TransformUtil::ConvertGeTensor(const GeTensorPtr ge_tensor, const std::vector& request_dims) { +MeTensorPtr TransformUtil::ConvertGeTensor(const GeTensorPtr ge_tensor, const std::vector &request_dims) { MS_EXCEPTION_IF_NULL(ge_tensor); GeShape ge_shape = ge_tensor->GetTensorDesc().GetShape(); vector me_dims = ConvertGeShape(ge_shape, request_dims); diff --git a/mindspore/ccsrc/transform/util.h b/mindspore/ccsrc/transform/util.h index 9bcd8dc115..0f5d79f6a1 100644 --- a/mindspore/ccsrc/transform/util.h +++ b/mindspore/ccsrc/transform/util.h @@ -47,7 +47,7 @@ class TransformUtil { * Return: * [GeDataType] the data type for ge tensor * */ - static GeDataType ConvertDataType(const MeDataType& type); + static GeDataType ConvertDataType(const MeDataType &type); /* * Parameters: @@ -55,7 +55,7 @@ class TransformUtil { * Return: * [GeFormat] the data format for ge tensor * */ - static GeFormat ConvertFormat(const std::string& format); + static GeFormat ConvertFormat(const std::string &format); /* * Parameters: @@ -63,7 +63,7 @@ class TransformUtil { * Return: * [size_t] the buff size for the type in ME * */ - static size_t GetDataTypeSize(const MeDataType& type); + static size_t GetDataTypeSize(const MeDataType &type); /* * Parameters: @@ -73,8 +73,8 @@ class TransformUtil { * Return: * [shared_ptr] the shared pointer of ge tensor description * */ - static std::shared_ptr GetGeTensorDesc(const std::vector& shape, const MeDataType& me_type, - const std::string& format); + static std::shared_ptr GetGeTensorDesc(const std::vector &shape, const MeDataType &me_type, + const std::string &format); /* * Parameters: @@ -84,7 +84,7 @@ class TransformUtil { * Return: * [GeTensor] the data tensor in GE * */ - static GeTensorPtr ConvertTensor(const MeTensorPtr& tensor, const std::string& format); + static GeTensorPtr ConvertTensor(const MeTensorPtr &tensor, const std::string &format); /* * Parameters: @@ -93,8 +93,8 @@ class TransformUtil { * Return: * [std::vector] the data tensors in GE * */ - static std::vector ConvertInputTensors(const std::vector& me_tensors, - const std::string& format); + static std::vector ConvertInputTensors(const std::vector &me_tensors, + const std::string &format); /* * Parameters: @@ -102,7 +102,7 @@ class TransformUtil { * Return: * [MeTensor] the data tensor in ME * */ - static MeTensorPtr ConvertGeTensor(const GeTensorPtr& tensor); + static MeTensorPtr ConvertGeTensor(const GeTensorPtr &tensor); /* * Parameters: @@ -111,7 +111,7 @@ class TransformUtil { * Return: * [MeTensor] the data tensor in ME * */ - static MeTensorPtr ConvertGeTensor(GeTensorPtr ge_tensor, const std::vector& request_dims); + static MeTensorPtr ConvertGeTensor(GeTensorPtr ge_tensor, const std::vector &request_dims); /* * Parameters: * ge_tensors: [std::vector] the data tensor in GE @@ -119,15 +119,15 @@ class TransformUtil { * Return: * [std::vector] the data tensor in ME * */ - static std::vector ConvertGeTensors(const std::vector& ge_tensors, - const std::vector>& request_dims); + static std::vector ConvertGeTensors(const std::vector &ge_tensors, + const std::vector> &request_dims); /* * Parameters: * ge_tensors: [std::vector] the data tensor in GE * Return: * [std::vector] the data tensor in ME * */ - static std::vector ConvertGeTensors(const std::vector& ge_tensors); + static std::vector ConvertGeTensors(const std::vector &ge_tensors); /* * Parameters: * ge_tensor: [GeTensor] the data tensor in GE @@ -136,15 +136,15 @@ class TransformUtil { * Return: * [MeTensor] the data tensor in ME * */ - static MeTensorPtr GenerateMeTensor(const GeTensorPtr& ge_tensor, const std::vector& me_dims, - const TypeId& me_type); + static MeTensorPtr GenerateMeTensor(const GeTensorPtr &ge_tensor, const std::vector &me_dims, + const TypeId &me_type); /* * Parameters: * type: [GeDataType] the ge tensor data type * Return: * [MeDataType] the me tensor data type * */ - static MeDataType ConvertGeDataType(const GeDataType& type); + static MeDataType ConvertGeDataType(const GeDataType &type); /* * Parameters: @@ -152,7 +152,7 @@ class TransformUtil { * Return: * [GeShape] the ge shape * */ - static GeShape ConvertMeShape(const std::vector& me_dims); + static GeShape ConvertMeShape(const std::vector &me_dims); /* * Parameters: @@ -160,7 +160,7 @@ class TransformUtil { * Return: * [vector] the me shape * */ - static std::vector ConvertGeShape(const GeShape& ge_shape); + static std::vector ConvertGeShape(const GeShape &ge_shape); /* Function: * Convert GeShape to Me request shape, Support pattern: @@ -176,7 +176,7 @@ class TransformUtil { * Return: * [vector] the me shape * */ - static std::vector ConvertGeShape(const GeShape& ge_shape, const std::vector& request_dims); + static std::vector ConvertGeShape(const GeShape &ge_shape, const std::vector &request_dims); /* * Parameters: @@ -185,7 +185,7 @@ class TransformUtil { * [string] value string * */ template ::value>::type> - static std::string PrintVector(const std::vector& vec) { + static std::string PrintVector(const std::vector &vec) { const int MAX_PRINT_NUM = 100; std::stringstream ss; ss << "{ "; @@ -222,7 +222,7 @@ class TransformUtil { * [shared_ptr] vector pointer * */ template ::value>::type> - static std::vector MakeVector(const uint8_t* const data, size_t size) { + static std::vector MakeVector(const uint8_t *const data, size_t size) { auto dest = std::vector(size / sizeof(T)); if (data == nullptr) { return dest; diff --git a/mindspore/ccsrc/utils/any.cc b/mindspore/ccsrc/utils/any.cc index 31ee1fd302..3cb89f5dd7 100644 --- a/mindspore/ccsrc/utils/any.cc +++ b/mindspore/ccsrc/utils/any.cc @@ -21,7 +21,7 @@ namespace mindspore { // only support (int, float, bool) as Literal -bool AnyIsLiteral(const Any& any) { +bool AnyIsLiteral(const Any &any) { static const std::type_index typeid_int = std::type_index(typeid(int)); static const std::type_index typeid_float = std::type_index(typeid(float)); static const std::type_index typeid_bool = std::type_index(typeid(bool)); @@ -30,12 +30,12 @@ bool AnyIsLiteral(const Any& any) { return typeid_int == typeid_any || typeid_float == typeid_any || typeid_bool == typeid_any; } -std::ostream& operator<<(std::ostream& os, const pybind11::object& obj) { +std::ostream &operator<<(std::ostream &os, const pybind11::object &obj) { os << "[py::object]"; return os; } -Any& Any::operator=(const Any& other) { +Any &Any::operator=(const Any &other) { if (m_ptr == other.m_ptr || &other == this) { return *this; } @@ -44,9 +44,9 @@ Any& Any::operator=(const Any& other) { return *this; } -bool Any::operator<(const Any& other) const { return this < &other; } +bool Any::operator<(const Any &other) const { return this < &other; } -Any& Any::operator=(Any&& other) { +Any &Any::operator=(Any &&other) { if (this != &other) { if (m_ptr == other.m_ptr || &other == this) { return *this; diff --git a/mindspore/ccsrc/utils/any.h b/mindspore/ccsrc/utils/any.h index ce691f1c12..b4edf602ac 100644 --- a/mindspore/ccsrc/utils/any.h +++ b/mindspore/ccsrc/utils/any.h @@ -35,23 +35,23 @@ namespace mindspore { // usage:AnyPtr sp = std::make_shared(aname); template -std::string type(const T& t) { +std::string type(const T &t) { return demangle(typeid(t).name()); } -std::ostream& operator<<(std::ostream& os, const pybind11::object& obj); +std::ostream &operator<<(std::ostream &os, const pybind11::object &obj); class Any { public: // constructors Any() : m_ptr(nullptr), m_tpIndex(std::type_index(typeid(void))) {} - Any(const Any& other) : m_ptr(other.clone()), m_tpIndex(other.m_tpIndex) {} - Any(Any&& other) : m_ptr(std::move(other.m_ptr)), m_tpIndex(std::move(other.m_tpIndex)) {} + Any(const Any &other) : m_ptr(other.clone()), m_tpIndex(other.m_tpIndex) {} + Any(Any &&other) : m_ptr(std::move(other.m_ptr)), m_tpIndex(std::move(other.m_tpIndex)) {} - Any& operator=(Any&& other); + Any &operator=(Any &&other); // right reference constructor template ::type, Any>::value, T>::type> - Any(T&& t) : m_tpIndex(typeid(typename std::decay::type)) { // NOLINT + Any(T &&t) : m_tpIndex(typeid(typename std::decay::type)) { // NOLINT BasePtr new_val(new Derived::type>(std::forward(t))); std::swap(m_ptr, new_val); } @@ -67,7 +67,7 @@ class Any { return m_tpIndex == std::type_index(typeid(T)); } - const std::type_info& type() const { return m_ptr ? m_ptr->type() : typeid(void); } + const std::type_info &type() const { return m_ptr ? m_ptr->type() : typeid(void); } std::size_t Hash() const { std::stringstream buffer; @@ -79,7 +79,7 @@ class Any { } template - bool Apply(const std::function& fn) { + bool Apply(const std::function &fn) { if (type() == typeid(T)) { T x = cast(); fn(x); @@ -96,23 +96,23 @@ class Any { } } - friend std::ostream& operator<<(std::ostream& os, const Any& any) { + friend std::ostream &operator<<(std::ostream &os, const Any &any) { os << any.GetString(); return os; } // type cast template - T& cast() const { + T &cast() const { if (!is() || !m_ptr) { // Use MS_LOGFATAL replace throw std::bad_cast() MS_LOG(EXCEPTION) << "can not cast " << m_tpIndex.name() << " to " << typeid(T).name(); } - auto ptr = static_cast*>(m_ptr.get()); + auto ptr = static_cast *>(m_ptr.get()); return ptr->m_value; } - bool operator==(const Any& other) const { + bool operator==(const Any &other) const { if (m_tpIndex != other.m_tpIndex) { return false; } @@ -125,11 +125,11 @@ class Any { return *m_ptr == *other.m_ptr; } - bool operator!=(const Any& other) const { return !(operator==(other)); } + bool operator!=(const Any &other) const { return !(operator==(other)); } - Any& operator=(const Any& other); + Any &operator=(const Any &other); - bool operator<(const Any& other) const; + bool operator<(const Any &other) const; std::string ToString() const { std::ostringstream buffer; @@ -154,26 +154,26 @@ class Any { // type base definition struct Base { - virtual const std::type_info& type() const = 0; + virtual const std::type_info &type() const = 0; virtual BasePtr clone() const = 0; virtual ~Base() = default; - virtual bool operator==(const Base& other) const = 0; + virtual bool operator==(const Base &other) const = 0; virtual std::string GetString() = 0; }; template struct Derived : public Base { template - explicit Derived(Args&&... args) : m_value(std::forward(args)...), serialize_cache_("") {} + explicit Derived(Args &&... args) : m_value(std::forward(args)...), serialize_cache_("") {} - bool operator==(const Base& other) const override { + bool operator==(const Base &other) const override { if (typeid(*this) != typeid(other)) { return false; } - return m_value == static_cast&>(other).m_value; + return m_value == static_cast &>(other).m_value; } - const std::type_info& type() const override { return typeid(T); } + const std::type_info &type() const override { return typeid(T); } BasePtr clone() const override { return BasePtr(new Derived(m_value)); } @@ -204,14 +204,14 @@ class Any { using AnyPtr = std::shared_ptr; struct AnyHash { - std::size_t operator()(const Any& c) const { return c.Hash(); } + std::size_t operator()(const Any &c) const { return c.Hash(); } }; struct AnyLess { - bool operator()(const Any& a, const Any& b) const { return a.Hash() < b.Hash(); } + bool operator()(const Any &a, const Any &b) const { return a.Hash() < b.Hash(); } }; -bool AnyIsLiteral(const Any& any); +bool AnyIsLiteral(const Any &any); } // namespace mindspore diff --git a/mindspore/ccsrc/utils/base_ref.cc b/mindspore/ccsrc/utils/base_ref.cc index e50f0003b8..aa38c8a6a0 100644 --- a/mindspore/ccsrc/utils/base_ref.cc +++ b/mindspore/ccsrc/utils/base_ref.cc @@ -17,17 +17,17 @@ #include "utils/base_ref.h" namespace mindspore { -iterator ConstIteratorCast(std::vector* v, const const_iterator iter) { +iterator ConstIteratorCast(std::vector *v, const const_iterator iter) { return std::next(v->begin(), std::distance(v->cbegin(), iter)); } -BaseRef::BaseRef(const BaseRef& other) : Base(other), m_ptr(other.m_ptr) { +BaseRef::BaseRef(const BaseRef &other) : Base(other), m_ptr(other.m_ptr) { if (!m_ptr) { m_ptr = other.copy(); } } -bool BaseRef::operator==(const BaseRef& other) const { +bool BaseRef::operator==(const BaseRef &other) const { if (m_ptr == other.m_ptr) { return true; } @@ -55,7 +55,7 @@ bool BaseRef::operator==(const BaseRef& other) const { } // left reference -BaseRef& BaseRef::operator=(const BaseRef& other) { +BaseRef &BaseRef::operator=(const BaseRef &other) { if ((m_ptr != nullptr && m_ptr == other.m_ptr) || this == &other) { return *this; } @@ -64,7 +64,7 @@ BaseRef& BaseRef::operator=(const BaseRef& other) { } // right reference -BaseRef& BaseRef::operator=(BaseRef&& other) { +BaseRef &BaseRef::operator=(BaseRef &&other) { if ((m_ptr != nullptr && m_ptr == other.m_ptr) || this == &other) { return *this; } @@ -88,7 +88,7 @@ uint32_t BaseRef::type() const { } // left reference -SetRef& SetRef::operator=(const SetRef& other) { +SetRef &SetRef::operator=(const SetRef &other) { if (elements_ == other.elements_ || this == &other) { return *this; } @@ -100,7 +100,7 @@ std::string SetRef::ToString() const { std::ostringstream buffer; bool begin = true; buffer << "set["; - for (auto& attr : elements_) { + for (auto &attr : elements_) { if (!begin) { buffer << ", "; } else { @@ -113,7 +113,7 @@ std::string SetRef::ToString() const { } // left reference -VectorRef& VectorRef::operator=(const VectorRef& other) { +VectorRef &VectorRef::operator=(const VectorRef &other) { if (elements_ == other.elements_ || this == &other) { return *this; } @@ -125,7 +125,7 @@ std::string VectorRef::ToString() const { std::ostringstream buffer; bool begin = true; buffer << "vector["; - for (auto& attr : elements_) { + for (auto &attr : elements_) { if (!begin) { buffer << ", "; } else { @@ -137,14 +137,14 @@ std::string VectorRef::ToString() const { return buffer.str(); } -bool VectorRef::operator==(const BaseRef& other) const { +bool VectorRef::operator==(const BaseRef &other) const { if (!utils::isa(other)) { return false; } return *this == utils::cast(other); } -bool VectorRef::operator==(const VectorRef& other) const { +bool VectorRef::operator==(const VectorRef &other) const { if (elements_.size() != other.elements_.size()) { return false; } @@ -156,14 +156,14 @@ bool VectorRef::operator==(const VectorRef& other) const { return true; } -bool SetRef::operator==(const BaseRef& other) const { +bool SetRef::operator==(const BaseRef &other) const { if (!utils::isa(other)) { return false; } return *this == utils::cast(other); } -bool SetRef::operator==(const SetRef& other) const { +bool SetRef::operator==(const SetRef &other) const { if (elements_.size() != other.elements_.size()) { return false; } @@ -177,21 +177,21 @@ bool SetRef::operator==(const SetRef& other) const { return true; } -bool RunFunctionRef::operator==(const BaseRef& other) const { +bool RunFunctionRef::operator==(const BaseRef &other) const { if (!utils::isa(other)) { return false; } return *this == utils::cast(other); } -bool RunFunctionRef::operator==(const RunFunctionRef& other) const { return func_ == other.func_; } +bool RunFunctionRef::operator==(const RunFunctionRef &other) const { return func_ == other.func_; } -bool PyObjectRef::operator==(const BaseRef& other) const { +bool PyObjectRef::operator==(const BaseRef &other) const { if (!utils::isa(other)) { return false; } return *this == utils::cast(other); } -bool PyObjectRef::operator==(const PyObjectRef& other) const { return object_ == other.object_; } +bool PyObjectRef::operator==(const PyObjectRef &other) const { return object_ == other.object_; } } // namespace mindspore diff --git a/mindspore/ccsrc/utils/base_ref.h b/mindspore/ccsrc/utils/base_ref.h index ed00d8280c..6e7911d0d9 100644 --- a/mindspore/ccsrc/utils/base_ref.h +++ b/mindspore/ccsrc/utils/base_ref.h @@ -40,7 +40,7 @@ using iterator = std::vector::iterator; using const_iterator = std::vector::const_iterator; using const_reverse_iterator = std::vector::const_reverse_iterator; -using RunFunc = std::function; +using RunFunc = std::function; using RunFuncPtr = std::shared_ptr; template @@ -54,9 +54,9 @@ using is_value = std::is_base_of>; template using is_base_ref = std::is_base_of>; -iterator ConstIteratorCast(std::vector* v, const_iterator iter); +iterator ConstIteratorCast(std::vector *v, const_iterator iter); -inline std::shared_ptr MakeNode(const std::vector& elements) { +inline std::shared_ptr MakeNode(const std::vector &elements) { return std::make_shared(elements); } @@ -68,34 +68,34 @@ inline std::shared_ptr MakeNode(std::initializer_list elemen template >::value && is_base::value, int>::type = 0> -inline BasePtr MakeNode(const T& v) { +inline BasePtr MakeNode(const T &v) { return v; } template >::value && !is_base_ref::value, int>::type = 0> -inline BasePtr MakeNode(const T& v) { +inline BasePtr MakeNode(const T &v) { return MakeValue(v); } -inline std::shared_ptr MakeNode(const VectorRef& a) { return std::make_shared(std::move(a)); } -inline std::shared_ptr MakeNode(const AnfNodePtrList& a) { +inline std::shared_ptr MakeNode(const VectorRef &a) { return std::make_shared(std::move(a)); } +inline std::shared_ptr MakeNode(const AnfNodePtrList &a) { std::vector ret; - (void)std::transform(a.begin(), a.end(), std::back_inserter(ret), [](const AnfNodePtr& v) { return v; }); + (void)std::transform(a.begin(), a.end(), std::back_inserter(ret), [](const AnfNodePtr &v) { return v; }); return std::make_shared(ret); } -inline std::shared_ptr MakeNode(const SetRef& a) { return std::make_shared(std::move(a)); } -inline std::shared_ptr MakeNode(const RunFuncPtr& a) { return std::make_shared(a); } -inline std::shared_ptr MakeNode(const py::object& a) { return std::make_shared(a); } -inline std::shared_ptr MakeNode(const py::tuple& a) { return std::make_shared(a); } +inline std::shared_ptr MakeNode(const SetRef &a) { return std::make_shared(std::move(a)); } +inline std::shared_ptr MakeNode(const RunFuncPtr &a) { return std::make_shared(a); } +inline std::shared_ptr MakeNode(const py::object &a) { return std::make_shared(a); } +inline std::shared_ptr MakeNode(const py::tuple &a) { return std::make_shared(a); } class BaseRef : public Base { public: BaseRef() : m_ptr(nullptr) {} - BaseRef(const BaseRef& other); + BaseRef(const BaseRef &other); virtual std::shared_ptr copy() const { return m_ptr; } - BaseRef(BaseRef&& other) : Base(other) { + BaseRef(BaseRef &&other) : Base(other) { m_ptr = other.m_ptr; other.m_ptr = nullptr; } @@ -103,7 +103,7 @@ class BaseRef : public Base { // right reference constructor template ::type, BaseRef>::value, T>::type> - BaseRef(T&& t) { // NOLINT + BaseRef(T &&t) { // NOLINT m_ptr = MakeNode(t); } @@ -111,14 +111,14 @@ class BaseRef : public Base { MS_DECLARE_PARENT(BaseRef, Base) - bool operator!=(const BaseRef& other) const { return !(operator==(other)); } + bool operator!=(const BaseRef &other) const { return !(operator==(other)); } - virtual bool operator==(const BaseRef& other) const; + virtual bool operator==(const BaseRef &other) const; // left reference - virtual BaseRef& operator=(const BaseRef& other); + virtual BaseRef &operator=(const BaseRef &other); // right reference - virtual BaseRef& operator=(BaseRef&& other); + virtual BaseRef &operator=(BaseRef &&other); std::size_t hash() const override { if (m_ptr == nullptr) { @@ -139,18 +139,18 @@ class BaseRef : public Base { using BaseRefPtr = std::shared_ptr; struct BaseRefHash { - std::size_t operator()(const BaseRef& c) const { return c.hash(); } + std::size_t operator()(const BaseRef &c) const { return c.hash(); } }; struct BaseRefLess { - bool operator()(const BaseRef& a, const BaseRef& b) const { return a.hash() < b.hash(); } + bool operator()(const BaseRef &a, const BaseRef &b) const { return a.hash() < b.hash(); } }; namespace utils { // judge isa relation // examples: isa(handle), isa(handle) template ::value && !is_base_ref::value, int>::type = 0> -bool isa(const BaseRef& handle) { +bool isa(const BaseRef &handle) { if (!handle.m_ptr) { return false; } @@ -160,7 +160,7 @@ bool isa(const BaseRef& handle) { // noderef isa ptr isa(x) or isa() template ::value, typename T::element_type>::type, typename std::enable_if::value || is_base_ref::value, int>::type = 0> -bool isa(const BaseRef& handle) { +bool isa(const BaseRef &handle) { if (handle.m_ptr == nullptr) { return typeid(handle.m_ptr) == typeid(T); } @@ -175,7 +175,7 @@ bool isa(const BaseRef& handle) { // isa(handle) template ::type::element_type> -bool isa(const BaseRef& handle) { +bool isa(const BaseRef &handle) { if (handle.m_ptr == nullptr) { return false; } @@ -184,7 +184,7 @@ bool isa(const BaseRef& handle) { // isa(handle), judge reference or ptr template ::value, int>::type = 0> -bool isa(const BaseRef& handle) { +bool isa(const BaseRef &handle) { static const uint32_t tid = Base::GetTypeId(typeid(T).name()); return handle.IsFromTypeId(tid) || (handle.m_ptr && handle.m_ptr->isa()); } @@ -192,7 +192,7 @@ bool isa(const BaseRef& handle) { // valueref -> C++ type // cast(handle) template ::value && !is_shared_ptr::value, int>::type = 0> -T cast(const BaseRef& handle) { +T cast(const BaseRef &handle) { T ret = GetValue(std::static_pointer_cast(handle.m_ptr)); return std::move(ret); } @@ -200,12 +200,12 @@ T cast(const BaseRef& handle) { // valueref -> valueref type // cast(handle) template ::value, int>::type = 0> -const T& cast(const BaseRef& handle) { +const T &cast(const BaseRef &handle) { if (handle.m_ptr) { - return static_cast(*handle.m_ptr); + return static_cast(*handle.m_ptr); } - return std::move(static_cast(handle)); + return std::move(static_cast(handle)); } // valueref -> nodeptr type @@ -213,7 +213,7 @@ const T& cast(const BaseRef& handle) { template ::value, typename T::element_type>::type, typename std::enable_if::value && std::is_base_of::value, int>::type = 0> -T cast(const BaseRef& handle) { +T cast(const BaseRef &handle) { if (!handle.m_ptr) { MS_LOG(EXCEPTION) << "Can not cast to " << typeid(T).name() << ", pointer is null"; } @@ -229,11 +229,11 @@ T cast(const BaseRef& handle) { class VectorRef : public BaseRef { public: VectorRef() {} - explicit VectorRef(const std::vector& elements) : elements_(elements) {} - VectorRef(const const_iterator& begin, const const_iterator& end) : elements_(begin, end) {} + explicit VectorRef(const std::vector &elements) : elements_(elements) {} + VectorRef(const const_iterator &begin, const const_iterator &end) : elements_(begin, end) {} // left reference - virtual VectorRef& operator=(const VectorRef& other); + virtual VectorRef &operator=(const VectorRef &other); ~VectorRef() override = default; @@ -244,7 +244,7 @@ class VectorRef : public BaseRef { std::size_t size() const { return elements_.size(); } MS_DECLARE_PARENT(VectorRef, BaseRef) - const BaseRef& operator[](const std::size_t& dim) const { + const BaseRef &operator[](const std::size_t &dim) const { if (dim >= size()) { MS_LOG(EXCEPTION) << "Out of the size of the tuple."; } @@ -253,17 +253,17 @@ class VectorRef : public BaseRef { uint32_t type() const override { return tid(); } std::string ToString() const override; - std::vector& elements() { return elements_; } + std::vector &elements() { return elements_; } void clear() { elements_.clear(); } - bool operator==(const BaseRef& other) const override; - bool operator==(const VectorRef& other) const; + bool operator==(const BaseRef &other) const override; + bool operator==(const VectorRef &other) const; - void push_back(const BaseRef& value) { elements_.push_back(value); } - void push_back(BaseRef&& value) { elements_.push_back(value); } + void push_back(const BaseRef &value) { elements_.push_back(value); } + void push_back(BaseRef &&value) { elements_.push_back(value); } - void emplace_back(const BaseRef& value) { elements_.emplace_back(value); } - void emplace_back(BaseRef&& value) { elements_.emplace_back(value); } + void emplace_back(const BaseRef &value) { elements_.emplace_back(value); } + void emplace_back(BaseRef &&value) { elements_.emplace_back(value); } template void insert(const iterator pos, const InputIt first, const InputIt last) { @@ -308,21 +308,21 @@ using set_iterator = std::set::iterator; using const_set_iterator = std::set::const_iterator; struct VectorRefHash { - std::size_t operator()(const VectorRef& c) const { return c.hash(); } + std::size_t operator()(const VectorRef &c) const { return c.hash(); } }; class SetRef : public BaseRef { public: SetRef() {} - explicit SetRef(const std::set& elements) : elements_(elements) {} + explicit SetRef(const std::set &elements) : elements_(elements) {} SetRef(const std::initializer_list elements) : elements_(elements.begin(), elements.end()) {} - SetRef(const const_set_iterator& begin, const const_set_iterator& end) : elements_(begin, end) {} + SetRef(const const_set_iterator &begin, const const_set_iterator &end) : elements_(begin, end) {} // left reference - virtual SetRef& operator=(const SetRef& other); + virtual SetRef &operator=(const SetRef &other); - bool operator==(const BaseRef& other) const override; - bool operator==(const SetRef& other) const; + bool operator==(const BaseRef &other) const override; + bool operator==(const SetRef &other) const; ~SetRef() override = default; @@ -335,10 +335,10 @@ class SetRef : public BaseRef { uint32_t type() const override { return tid(); } std::string ToString() const override; - std::set& elements() { return elements_; } + std::set &elements() { return elements_; } void clear() { elements_.clear(); } - void insert(const BaseRef& elem) { (void)elements_.insert(elem); } + void insert(const BaseRef &elem) { (void)elements_.insert(elem); } const_set_iterator begin() const { return elements_.begin(); } const_set_iterator end() const { return elements_.end(); } @@ -348,8 +348,8 @@ class SetRef : public BaseRef { (void)elements_.insert(first, last); } - std::size_t count(const BaseRef& elem) const { return elements_.count(elem); } - const_set_iterator find(const BaseRef& elem) const { return elements_.find(elem); } + std::size_t count(const BaseRef &elem) const { return elements_.count(elem); } + const_set_iterator find(const BaseRef &elem) const { return elements_.find(elem); } std::set elements_; }; @@ -358,8 +358,8 @@ using SetRefPtr = std::shared_ptr; class PyObjectRef : public BaseRef { public: - explicit PyObjectRef(const py::object& py_object) : object_(py_object) {} - explicit PyObjectRef(const py::tuple& tuple_obj) : object_(tuple_obj) {} + explicit PyObjectRef(const py::object &py_object) : object_(py_object) {} + explicit PyObjectRef(const py::tuple &tuple_obj) : object_(tuple_obj) {} ~PyObjectRef() override = default; @@ -368,8 +368,8 @@ class PyObjectRef : public BaseRef { uint32_t type() const override { return tid(); } std::string ToString() const override { return py::str(object_); } - bool operator==(const BaseRef& other) const override; - bool operator==(const PyObjectRef& other) const; + bool operator==(const BaseRef &other) const override; + bool operator==(const PyObjectRef &other) const; py::object object_; }; @@ -377,15 +377,15 @@ class PyObjectRef : public BaseRef { class RunFunctionRef : public BaseRef { public: RunFunctionRef() {} - explicit RunFunctionRef(const RunFuncPtr& ref_func) : func_(ref_func) {} + explicit RunFunctionRef(const RunFuncPtr &ref_func) : func_(ref_func) {} ~RunFunctionRef() override = default; MS_DECLARE_PARENT(RunFunctionRef, BaseRef) uint32_t type() const override { return tid(); } std::string ToString() const override { return std::string("RunFunctionRef"); } - bool operator==(const BaseRef& other) const override; - bool operator==(const RunFunctionRef& other) const; + bool operator==(const BaseRef &other) const override; + bool operator==(const RunFunctionRef &other) const; RunFuncPtr func_; }; diff --git a/mindspore/ccsrc/utils/callbacks.cc b/mindspore/ccsrc/utils/callbacks.cc index 03c6322afe..06bf1c73ab 100644 --- a/mindspore/ccsrc/utils/callbacks.cc +++ b/mindspore/ccsrc/utils/callbacks.cc @@ -37,14 +37,14 @@ const int ONE_SHAPE = 1; // Cache the summary callback data from ME session // Remove the GE module on new architecture // Output Format: [{"name": tag_name, "data": tensor}, {"name": tag_name, "data": tensor},...] -uint32_t MS_EXPORT SummarySaveCallback(uint32_t graph_id, const std::map& params_list) { +uint32_t MS_EXPORT SummarySaveCallback(uint32_t graph_id, const std::map ¶ms_list) { // Acquire GIL before calling Python code py::gil_scoped_acquire acquire; py::list summary_list = py::list(); MS_LOG(INFO) << "The Summary save callback function for graph " << graph_id << ", Param list size = " << params_list.size() << "."; - for (auto& item : params_list) { + for (auto &item : params_list) { std::string tag_name = item.first; auto tensor_ptr = item.second; if (tensor_ptr == nullptr) { diff --git a/mindspore/ccsrc/utils/callbacks.h b/mindspore/ccsrc/utils/callbacks.h index a1e4e75d5b..9f46df0414 100644 --- a/mindspore/ccsrc/utils/callbacks.h +++ b/mindspore/ccsrc/utils/callbacks.h @@ -39,9 +39,9 @@ extern const std::string kPythonCheckpointFuncName; const int kCallbackOk = 0; const int kCallbackFalied = 1; -bool GetParameterShape(const FuncGraphPtr& anf_graph, const std::string& param_name, - const std::shared_ptr>& shape); -uint32_t SummarySaveCallback(uint32_t, const std::map&); +bool GetParameterShape(const FuncGraphPtr &anf_graph, const std::string ¶m_name, + const std::shared_ptr> &shape); +uint32_t SummarySaveCallback(uint32_t, const std::map &); } // namespace callbacks } // namespace mindspore diff --git a/mindspore/ccsrc/utils/callbacks_ge.cc b/mindspore/ccsrc/utils/callbacks_ge.cc index 36bbcbf297..b4c9fda634 100644 --- a/mindspore/ccsrc/utils/callbacks_ge.cc +++ b/mindspore/ccsrc/utils/callbacks_ge.cc @@ -35,15 +35,15 @@ const int ONE_SHAPE = 1; using mindspore::transform::Status; using mindspore::transform::TransformUtil; -bool GetParameterShape(const FuncGraphPtr& graph, const std::string& param_name, - const std::shared_ptr>& shape) { +bool GetParameterShape(const FuncGraphPtr &graph, const std::string ¶m_name, + const std::shared_ptr> &shape) { if (graph == nullptr) { MS_LOG(ERROR) << "Graph is null, can not get graph parameter"; return false; } auto parameter_nodes = graph->parameters(); - for (auto& node : parameter_nodes) { + for (auto &node : parameter_nodes) { ParameterPtr param_node = std::static_pointer_cast(node); if (param_node == nullptr) { MS_LOG(ERROR) << "Parameter node is null, can not get graph parameter"; @@ -65,8 +65,8 @@ bool GetParameterShape(const FuncGraphPtr& graph, const std::string& param_name, return false; } -static TensorPtr GetMeTensorTransformed(uint32_t graph_id, const std::string& parameter_name, - const std::shared_ptr& ge_tensor_ptr) { +static TensorPtr GetMeTensorTransformed(uint32_t graph_id, const std::string ¶meter_name, + const std::shared_ptr &ge_tensor_ptr) { FuncGraphPtr anf_graph = transform::DfGraphManager::GetInstance().GetAnfGraph(graph_id); if (anf_graph == nullptr) { MS_LOG(ERROR) << "Get anf graph failed during callback"; @@ -82,13 +82,13 @@ static TensorPtr GetMeTensorTransformed(uint32_t graph_id, const std::string& pa return TransformUtil::ConvertGeTensor(ge_tensor_ptr, *parameter_shape_ptr); } -uint32_t CheckpointSaveCallback(uint32_t graph_id, const std::map& params_list) { +uint32_t CheckpointSaveCallback(uint32_t graph_id, const std::map ¶ms_list) { // Acquire GIL before calling Python code py::gil_scoped_acquire acquire; MS_LOG(DEBUG) << "Start the checkpoint save callback function in checkpoint save process."; py::list parameter_list = py::list(); - for (auto& item : params_list) { + for (auto &item : params_list) { std::string name = item.first; std::shared_ptr ge_tensor_ptr = std::make_shared(item.second); TensorPtr tensor_ptr = GetMeTensorTransformed(graph_id, name, ge_tensor_ptr); @@ -112,7 +112,7 @@ uint32_t CheckpointSaveCallback(uint32_t graph_id, const std::map& ge_tensor_ptr) { +static TensorPtr GetMeTensorForSummary(const std::string &name, const std::shared_ptr &ge_tensor_ptr) { // confirm the type by name // Format: xxx[:Scalar] xxx[:Image] xxx[:Tensor] if (name.empty()) { @@ -149,14 +149,14 @@ static TensorPtr GetMeTensorForSummary(const std::string& name, const std::share // Cache the summary callback data // Output Format: [{"name": tag_name, "data": tensor}, {"name": tag_name, "data": tensor},...] -uint32_t MS_EXPORT SummarySaveCallback(uint32_t graph_id, const std::map& params_list) { +uint32_t MS_EXPORT SummarySaveCallback(uint32_t graph_id, const std::map ¶ms_list) { // Acquire GIL before calling Python code py::gil_scoped_acquire acquire; MS_LOG(DEBUG) << "Start the summary save callback function for graph " << graph_id << "."; py::list summary_list = py::list(); MS_LOG(DEBUG) << "Param list size = " << params_list.size(); - for (auto& item : params_list) { + for (auto &item : params_list) { std::string tag_name = item.first; std::shared_ptr ge_tensor_ptr = std::make_shared(item.second); TensorPtr tensor_ptr = GetMeTensorForSummary(tag_name, ge_tensor_ptr); diff --git a/mindspore/ccsrc/utils/callbacks_ge.h b/mindspore/ccsrc/utils/callbacks_ge.h index 750ec74666..08f5bb59db 100644 --- a/mindspore/ccsrc/utils/callbacks_ge.h +++ b/mindspore/ccsrc/utils/callbacks_ge.h @@ -29,8 +29,8 @@ namespace callbacks { using mindspore::tensor::TensorPtr; -uint32_t CheckpointSaveCallback(uint32_t, const std::map&); -uint32_t SummarySaveCallback(uint32_t, const std::map&); +uint32_t CheckpointSaveCallback(uint32_t, const std::map &); +uint32_t SummarySaveCallback(uint32_t, const std::map &); } // namespace callbacks } // namespace mindspore diff --git a/mindspore/ccsrc/utils/config_manager.cc b/mindspore/ccsrc/utils/config_manager.cc index 6d66b37436..7dc559b20e 100644 --- a/mindspore/ccsrc/utils/config_manager.cc +++ b/mindspore/ccsrc/utils/config_manager.cc @@ -22,12 +22,12 @@ namespace mindspore { -ConfigManager& ConfigManager::GetInstance() noexcept { +ConfigManager &ConfigManager::GetInstance() noexcept { static ConfigManager instance; return instance; } -void ConfigManager::SetDatasetModeConfig(const std::string& mode) { +void ConfigManager::SetDatasetModeConfig(const std::string &mode) { static const std::map mode_map = {{"normal", DS_NORMAL_MODE}, {"sink", DS_SINK_MODE}}; if (mode_map.find(mode) == mode_map.end()) { MS_LOG(ERROR) << "Invalid dataset mode:" << mode; diff --git a/mindspore/ccsrc/utils/config_manager.h b/mindspore/ccsrc/utils/config_manager.h index db7d7d0c14..635f24792a 100644 --- a/mindspore/ccsrc/utils/config_manager.h +++ b/mindspore/ccsrc/utils/config_manager.h @@ -37,8 +37,8 @@ enum DatasetMode { DS_NORMAL_MODE = 0, DS_SINK_MODE }; class DatasetGraphParam { public: - DatasetGraphParam(const std::string& name, int64_t size, int64_t batch_size, const std::vector& ge_types, - const std::vector>& shapes, const std::vector& input_indexes) + DatasetGraphParam(const std::string &name, int64_t size, int64_t batch_size, const std::vector &ge_types, + const std::vector> &shapes, const std::vector &input_indexes) : queue_name_(name), loop_size_(size), batch_size_(batch_size), @@ -72,15 +72,15 @@ class DatasetGraphParam { class ConfigManager { public: - ConfigManager(const ConfigManager&) = delete; - ConfigManager& operator=(const ConfigManager&) = delete; - static ConfigManager& GetInstance() noexcept; + ConfigManager(const ConfigManager &) = delete; + ConfigManager &operator=(const ConfigManager &) = delete; + static ConfigManager &GetInstance() noexcept; ParallelStrategy parallel_strategy() const { return parallel_strategy_; } void set_parallel_strategy(ParallelStrategy strategy) { parallel_strategy_ = strategy; } - const std::map& ge_initialize_options() const { return ge_initialize_options_; } - void set_ge_initialize_options(const std::map& options) { + const std::map &ge_initialize_options() const { return ge_initialize_options_; } + void set_ge_initialize_options(const std::map &options) { ge_initialize_options_ = options; } @@ -90,12 +90,12 @@ class ConfigManager { void set_iter_num(const int64_t num) { iter_num_ = num; } std::string dataset_phase() const { return dataset_phase_; } - void set_dataset_phase(const std::string& phase) { dataset_phase_ = phase; } + void set_dataset_phase(const std::string &phase) { dataset_phase_ = phase; } DatasetGraphParam dataset_param() const { return dataset_param_; } - void set_dataset_param(const DatasetGraphParam& param) { dataset_param_ = param; } + void set_dataset_param(const DatasetGraphParam ¶m) { dataset_param_ = param; } - static void SetDatasetModeConfig(const std::string& mode); + static void SetDatasetModeConfig(const std::string &mode); void ResetConfig() noexcept; diff --git a/mindspore/ccsrc/utils/context/ms_context.cc b/mindspore/ccsrc/utils/context/ms_context.cc index bee5875f60..0a2f065140 100644 --- a/mindspore/ccsrc/utils/context/ms_context.cc +++ b/mindspore/ccsrc/utils/context/ms_context.cc @@ -45,7 +45,7 @@ std::map MsContext::policy_map_ = {{"ge", kMsBacke {"ge_only", kMsBackendGeOnly}, {"vm_prior", kMsBackendVmPrior}}; -MsContext::MsContext(const std::string& policy, const std::string& target) { +MsContext::MsContext(const std::string &policy, const std::string &target) { save_graphs_flag_ = false; save_graphs_path_ = "."; save_ms_model_flag_ = false; @@ -97,7 +97,7 @@ std::shared_ptr MsContext::GetInstance() { return inst_context_; } -bool MsContext::set_backend_policy(const std::string& policy) { +bool MsContext::set_backend_policy(const std::string &policy) { if (policy_map_.find(policy) == policy_map_.end()) { MS_LOG(ERROR) << "invalid backend policy name: " << policy; return false; @@ -110,7 +110,7 @@ bool MsContext::set_backend_policy(const std::string& policy) { std::string MsContext::backend_policy() const { auto res = std::find_if( policy_map_.begin(), policy_map_.end(), - [&, this](const std::pair& item) { return item.second == backend_policy_; }); + [&, this](const std::pair &item) { return item.second == backend_policy_; }); if (res != policy_map_.end()) { return res->first; } @@ -124,7 +124,7 @@ void MsContext::set_execution_mode(int execution_mode) { execution_mode_ = execution_mode; } -bool MsContext::set_device_target(const std::string& target) { +bool MsContext::set_device_target(const std::string &target) { if (kTargetSet.find(target) == kTargetSet.end()) { MS_LOG(ERROR) << "invalid device target name: " << target; return false; @@ -218,7 +218,7 @@ bool MsContext::CloseTsd(bool force) { MS_LOG(INFO) << "join tdt host receive process"; tdt_print_.join(); } - } catch (const std::exception& e) { + } catch (const std::exception &e) { MS_LOG(ERROR) << "tdt thread join failed: " << e.what(); } #endif @@ -241,7 +241,7 @@ bool MsContext::OpenTsd() { return true; } bool MsContext::CloseTsd(bool) { return true; } #endif -void MsContext::SetHcclOptions(std::map* ge_options) const { +void MsContext::SetHcclOptions(std::map *ge_options) const { auto env_table_file = common::GetEnv("RANK_TABLE_FILE"); auto env_rank_id = common::GetEnv("RANK_ID"); auto env_device_id = std::to_string(device_id_); @@ -274,7 +274,7 @@ void MsContext::SetHcclOptions(std::map* ge_options) c } } -void MsContext::GetGeOptions(std::map* ge_options) const { +void MsContext::GetGeOptions(std::map *ge_options) const { #ifdef ENABLE_GE (*ge_options)["device_id"] = "0"; (*ge_options)["ge.exec.enableDump"] = std::to_string(enable_dump_); @@ -365,7 +365,7 @@ void MsContext::GetGeOptions(std::map* ge_options) con #endif } -void MsContext::SetDisableReuseMemoryFlag(std::map* ge_options) const { +void MsContext::SetDisableReuseMemoryFlag(std::map *ge_options) const { auto env_disable_reuse_memory = common::GetEnv("DISABLE_REUSE_MEMORY"); if (!env_disable_reuse_memory.empty()) { (*ge_options)["ge.exec.disableReuseMemory"] = env_disable_reuse_memory; @@ -412,7 +412,7 @@ bool MsContext::FinalizeGe(bool force) { try { DfGraphManager::GetInstance().DeleteGraphRunner(); DfGraphManager::GetInstance().DeleteGeSession(); - } catch (const std::exception& e) { + } catch (const std::exception &e) { MS_LOG(ERROR) << "Error occurred when deleting GE graph runner and session fail. Error: " << e.what(); } catch (...) { std::string exName(abi::__cxa_current_exception_type()->name()); diff --git a/mindspore/ccsrc/utils/context/ms_context.h b/mindspore/ccsrc/utils/context/ms_context.h index 06704ff9c6..1d84061a8a 100644 --- a/mindspore/ccsrc/utils/context/ms_context.h +++ b/mindspore/ccsrc/utils/context/ms_context.h @@ -48,13 +48,13 @@ const std::set kTargetSet = {kCPUDevice, kGPUDevice, kAscendDevice, class MsContext { public: ~MsContext() = default; - MsContext(const MsContext&) = delete; - MsContext& operator=(const MsContext&) = delete; + MsContext(const MsContext &) = delete; + MsContext &operator=(const MsContext &) = delete; static std::shared_ptr GetInstance(); std::string backend_policy() const; - bool set_backend_policy(const std::string& policy); + bool set_backend_policy(const std::string &policy); int execution_mode() const { return execution_mode_; } void set_execution_mode(int execution_mode); @@ -69,7 +69,7 @@ class MsContext { bool precompile_only() const { return precompile_only_; } std::string device_target() const { return device_target_; } - bool set_device_target(const std::string& target); + bool set_device_target(const std::string &target); uint32_t device_id() const { return device_id_; } bool set_device_id(uint32_t device_id); @@ -78,7 +78,7 @@ class MsContext { void set_save_graphs_flag(bool save_graphs_flag) { save_graphs_flag_ = save_graphs_flag; } std::string save_graphs_path() const { return save_graphs_path_; } - void set_save_graphs_path(const std::string& save_paths) { save_graphs_path_ = save_paths; } + void set_save_graphs_path(const std::string &save_paths) { save_graphs_path_ = save_paths; } bool OpenTsd(); bool CloseTsd(bool force = false); @@ -101,7 +101,7 @@ class MsContext { void set_save_ms_model_flag(bool save_ms_model_flag) { save_ms_model_flag_ = save_ms_model_flag; } std::string save_ms_model_path() const { return save_ms_model_path_; } - void set_save_ms_model_path(const std::string& save_ms_model_path) { save_ms_model_path_ = save_ms_model_path; } + void set_save_ms_model_path(const std::string &save_ms_model_path) { save_ms_model_path_ = save_ms_model_path; } void set_enable_gpu_summary(bool enable_gpu_summary) { enable_gpu_summary_ = enable_gpu_summary; } bool enable_gpu_summary() const { return enable_gpu_summary_; } @@ -117,7 +117,7 @@ class MsContext { void set_enable_dump(bool flag) { enable_dump_ = flag; } bool enable_dump() const { return enable_dump_; } - void set_save_dump_path(const std::string& path) { save_dump_path_ = path; } + void set_save_dump_path(const std::string &path) { save_dump_path_ = path; } std::string save_dump_path() const { return save_dump_path_; } bool IsTsdOpened() const { return tsd_ref_ > 0; } @@ -128,19 +128,19 @@ class MsContext { void set_enable_dynamic_mem_pool(bool enable_dynamic_mem_pool) { enable_dynamic_mem_pool_ = enable_dynamic_mem_pool; } bool enable_dynamic_mem_pool() const { return enable_dynamic_mem_pool_; } - void set_graph_memory_max_size(const std::string& graph_memory_max_size) { + void set_graph_memory_max_size(const std::string &graph_memory_max_size) { graph_memory_max_size_ = graph_memory_max_size; } - void set_variable_memory_max_size(const std::string& variable_memory_max_size) { + void set_variable_memory_max_size(const std::string &variable_memory_max_size) { variable_memory_max_size_ = variable_memory_max_size; } private: - MsContext(const std::string& backend_policy, const std::string& target); - void GetGeOptions(std::map* ge_options) const; - void SetDisableReuseMemoryFlag(std::map* ge_options) const; - void SetHcclOptions(std::map* ge_options) const; + MsContext(const std::string &backend_policy, const std::string &target); + void GetGeOptions(std::map *ge_options) const; + void SetDisableReuseMemoryFlag(std::map *ge_options) const; + void SetHcclOptions(std::map *ge_options) const; static std::shared_ptr inst_context_; static std::map policy_map_; diff --git a/mindspore/ccsrc/utils/counter.h b/mindspore/ccsrc/utils/counter.h index 891f9c7a35..ead0ad84f2 100644 --- a/mindspore/ccsrc/utils/counter.h +++ b/mindspore/ccsrc/utils/counter.h @@ -29,17 +29,17 @@ class Counter { Counter() = default; ~Counter() = default; - Counter(const Counter& other) { data = other.data; } - Counter& operator=(const Counter& other) { + Counter(const Counter &other) { data = other.data; } + Counter &operator=(const Counter &other) { if (this != &other) { data = other.data; } return *this; } - int& operator[](const T& t) { return data[t]; } + int &operator[](const T &t) { return data[t]; } - counter_type operator-(const counter_type& other) { + counter_type operator-(const counter_type &other) { counter_type new_counter; for (auto iter = begin(); iter != end(); ++iter) { auto key = iter->first; @@ -58,7 +58,7 @@ class Counter { return new_counter; } - counter_type operator+(const counter_type& other) { + counter_type operator+(const counter_type &other) { counter_type new_counter; for (auto iter = begin(); iter != end(); ++iter) { auto key = iter->first; @@ -84,7 +84,7 @@ class Counter { std::size_t size() const { return data.size(); } - bool contains(const T& t) const { return data.find(t) != data.end(); } + bool contains(const T &t) const { return data.find(t) != data.end(); } typename OrderedMap::iterator begin() { return data.begin(); } diff --git a/mindspore/ccsrc/utils/graph_utils.cc b/mindspore/ccsrc/utils/graph_utils.cc index 55ef8dc3d5..0801622549 100644 --- a/mindspore/ccsrc/utils/graph_utils.cc +++ b/mindspore/ccsrc/utils/graph_utils.cc @@ -39,10 +39,10 @@ using SymbolicKeyTypePtr = std::shared_ptr; namespace { class DeepFirstSearcher : public AnfVisitor { public: - explicit DeepFirstSearcher(const IncludeFunc& include) : include_(include) {} + explicit DeepFirstSearcher(const IncludeFunc &include) : include_(include) {} ~DeepFirstSearcher() override = default; - std::vector Search(const AnfNodePtr& root) { + std::vector Search(const AnfNodePtr &root) { if (root == nullptr) { return res_; } @@ -50,7 +50,7 @@ class DeepFirstSearcher : public AnfVisitor { return res_; } - void Visit(const AnfNodePtr& node) override { + void Visit(const AnfNodePtr &node) override { MS_EXCEPTION_IF_NULL(node); if (seen_.count(node) != 0) { return; @@ -77,10 +77,10 @@ class DeepFirstSearcher : public AnfVisitor { class DeepScopedGraphSearcher : public DeepFirstSearcher { public: - explicit DeepScopedGraphSearcher(const IncludeFunc& include) : DeepFirstSearcher(include) {} + explicit DeepScopedGraphSearcher(const IncludeFunc &include) : DeepFirstSearcher(include) {} ~DeepScopedGraphSearcher() override = default; - void Visit(const CNodePtr& cnode) override { + void Visit(const CNodePtr &cnode) override { if (cnode->func_graph() == nullptr) { return; } @@ -90,13 +90,13 @@ class DeepScopedGraphSearcher : public DeepFirstSearcher { DeepFirstSearcher::Visit(ret); } - auto& inputs = cnode->inputs(); + auto &inputs = cnode->inputs(); for (auto iter = inputs.rbegin(); iter != inputs.rend(); ++iter) { DeepFirstSearcher::Visit(*iter); } } - void Visit(const ValueNodePtr& vnode) override { + void Visit(const ValueNodePtr &vnode) override { if (!IsValueNode(vnode)) { return; } @@ -108,7 +108,7 @@ class DeepScopedGraphSearcher : public DeepFirstSearcher { } } - void Visit(const ParameterPtr& param) override { + void Visit(const ParameterPtr ¶m) override { if (param->func_graph() == nullptr) { return; } @@ -122,17 +122,17 @@ class DeepScopedGraphSearcher : public DeepFirstSearcher { class DeepUsedGraphSearcher : public DeepFirstSearcher { public: - explicit DeepUsedGraphSearcher(const IncludeFunc& include) : DeepFirstSearcher(include) {} + explicit DeepUsedGraphSearcher(const IncludeFunc &include) : DeepFirstSearcher(include) {} ~DeepUsedGraphSearcher() override = default; - void Visit(const CNodePtr& cnode) override { - auto& inputs = cnode->inputs(); + void Visit(const CNodePtr &cnode) override { + auto &inputs = cnode->inputs(); for (auto iter = inputs.rbegin(); iter != inputs.rend(); ++iter) { DeepFirstSearcher::Visit(*iter); } } - void Visit(const ValueNodePtr& vnode) override { + void Visit(const ValueNodePtr &vnode) override { if (!IsValueNode(vnode)) { return; } @@ -147,33 +147,33 @@ class DeepUsedGraphSearcher : public DeepFirstSearcher { class DeepLinkedGraphSearcher : public DeepFirstSearcher { public: - explicit DeepLinkedGraphSearcher(const IncludeFunc& include) : DeepFirstSearcher(include) {} + explicit DeepLinkedGraphSearcher(const IncludeFunc &include) : DeepFirstSearcher(include) {} ~DeepLinkedGraphSearcher() override = default; - void Visit(const CNodePtr& cnode) override { - auto& inputs = cnode->inputs(); + void Visit(const CNodePtr &cnode) override { + auto &inputs = cnode->inputs(); for (auto iter = inputs.rbegin(); iter != inputs.rend(); ++iter) { DeepFirstSearcher::Visit(*iter); } } - void Visit(const ValueNodePtr&) override {} + void Visit(const ValueNodePtr &) override {} }; } // namespace -std::vector DeepScopedGraphSearch(const AnfNodePtr& root, const IncludeFunc& include) { +std::vector DeepScopedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include) { return DeepScopedGraphSearcher(include).Search(root); } -std::vector DeepUsedGraphSearch(const AnfNodePtr& root, const IncludeFunc& include) { +std::vector DeepUsedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include) { return DeepUsedGraphSearcher(include).Search(root); } -std::vector DeepLinkedGraphSearch(const AnfNodePtr& root, const IncludeFunc& include) { +std::vector DeepLinkedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include) { return DeepLinkedGraphSearcher(include).Search(root); } -std::vector TopoSort(const AnfNodePtr& root, const SuccFunc& succ, const IncludeFunc& include) { +std::vector TopoSort(const AnfNodePtr &root, const SuccFunc &succ, const IncludeFunc &include) { std::unordered_set done; std::list todo(1, root); std::unordered_map rank; @@ -222,7 +222,7 @@ std::vector TopoSort(const AnfNodePtr& root, const SuccFunc& succ, c return res; } -std::vector SuccDeeper(const AnfNodePtr& node) { +std::vector SuccDeeper(const AnfNodePtr &node) { std::vector vecs; if (node == nullptr) { return vecs; @@ -237,7 +237,7 @@ std::vector SuccDeeper(const AnfNodePtr& node) { return vecs; } else if (node->func_graph() != nullptr) { if (node->isa()) { - auto& inputs = node->cast()->inputs(); + auto &inputs = node->cast()->inputs(); (void)vecs.insert(vecs.end(), inputs.begin(), inputs.end()); } auto graph = node->func_graph(); @@ -250,7 +250,7 @@ std::vector SuccDeeper(const AnfNodePtr& node) { return vecs; } -std::vector SuccDeeperSimple(const AnfNodePtr& node) { +std::vector SuccDeeperSimple(const AnfNodePtr &node) { std::vector vecs; if (node == nullptr) { return vecs; @@ -265,39 +265,39 @@ std::vector SuccDeeperSimple(const AnfNodePtr& node) { return vecs; } else { if (node->isa()) { - auto& inputs = node->cast()->inputs(); + auto &inputs = node->cast()->inputs(); (void)vecs.insert(vecs.end(), inputs.begin(), inputs.end()); } return vecs; } } -std::vector SuccIncoming(const AnfNodePtr& node) { +std::vector SuccIncoming(const AnfNodePtr &node) { std::vector vecs; if (node == nullptr) { return vecs; } if (node->isa()) { - auto& inputs = node->cast()->inputs(); + auto &inputs = node->cast()->inputs(); (void)vecs.insert(vecs.end(), inputs.begin(), inputs.end()); } return vecs; } -std::vector SuccIncludeFV(const FuncGraphPtr& fg, const AnfNodePtr& node) { +std::vector SuccIncludeFV(const FuncGraphPtr &fg, const AnfNodePtr &node) { std::vector vecs; if (node == nullptr) { return vecs; } if (node->isa()) { auto cnode = node->cast(); - auto& inputs = cnode->inputs(); + auto &inputs = cnode->inputs(); // Check if free variables used. - for (const auto& input : inputs) { + for (const auto &input : inputs) { auto input_fg = GetValueNode(input); if (input_fg) { - for (auto& fv : input_fg->free_variables_nodes()) { + for (auto &fv : input_fg->free_variables_nodes()) { if (fv->func_graph() == fg && fg->nodes().contains(fv)) { vecs.push_back(fv); } @@ -309,9 +309,9 @@ std::vector SuccIncludeFV(const FuncGraphPtr& fg, const AnfNodePtr& return vecs; } -IncludeType AlwaysInclude(const AnfNodePtr&) { return FOLLOW; } +IncludeType AlwaysInclude(const AnfNodePtr &) { return FOLLOW; } -IncludeType IncludeBelongGraph(const FuncGraphPtr& fg, const AnfNodePtr& node) { +IncludeType IncludeBelongGraph(const FuncGraphPtr &fg, const AnfNodePtr &node) { if (node->func_graph() == fg) { return FOLLOW; } else { @@ -319,12 +319,12 @@ IncludeType IncludeBelongGraph(const FuncGraphPtr& fg, const AnfNodePtr& node) { } } -FuncGraphIndex::FuncGraphIndex(const FuncGraphPtr& fg, const SearchFunc& search, const IncludeFunc& include) { +FuncGraphIndex::FuncGraphIndex(const FuncGraphPtr &fg, const SearchFunc &search, const IncludeFunc &include) { MS_EXCEPTION_IF_NULL(fg); Acquire(fg); auto vec = search(fg->get_return(), include); - for (auto& node : vec) { + for (auto &node : vec) { MS_EXCEPTION_IF_NULL(node); Acquire(node); if (node->func_graph() != nullptr) { @@ -333,7 +333,7 @@ FuncGraphIndex::FuncGraphIndex(const FuncGraphPtr& fg, const SearchFunc& search, } } -std::set FuncGraphIndex::GetFuncGraphs(const std::string& key) { +std::set FuncGraphIndex::GetFuncGraphs(const std::string &key) { std::set func_graphs; if (index_func_graph_.find(key) != index_func_graph_.end()) { func_graphs = index_func_graph_[key]; @@ -341,7 +341,7 @@ std::set FuncGraphIndex::GetFuncGraphs(const std::string& key) { return func_graphs; } -std::set FuncGraphIndex::GetNodes(const std::string& key) { +std::set FuncGraphIndex::GetNodes(const std::string &key) { if (index_node_.find(key) != index_node_.end()) { return index_node_[key]; } @@ -349,7 +349,7 @@ std::set FuncGraphIndex::GetNodes(const std::string& key) { return std::set(); } -FuncGraphPtr FuncGraphIndex::GetFirstFuncGraph(const std::string& key) { +FuncGraphPtr FuncGraphIndex::GetFirstFuncGraph(const std::string &key) { if (GetFuncGraphs(key).empty()) { return nullptr; } @@ -358,7 +358,7 @@ FuncGraphPtr FuncGraphIndex::GetFirstFuncGraph(const std::string& key) { return fg; } -AnfNodePtr FuncGraphIndex::GetFirstNode(const std::string& key) { +AnfNodePtr FuncGraphIndex::GetFirstNode(const std::string &key) { if (GetNodes(key).empty()) { return nullptr; } @@ -367,14 +367,14 @@ AnfNodePtr FuncGraphIndex::GetFirstNode(const std::string& key) { return node; } -void FuncGraphIndex::Acquire(const FuncGraphPtr& key) { +void FuncGraphIndex::Acquire(const FuncGraphPtr &key) { std::string name = label_manage::Label(key->debug_info()); if (!name.empty()) { (void)index_func_graph_[name].insert(key); } } -void FuncGraphIndex::Acquire(const AnfNodePtr& key) { +void FuncGraphIndex::Acquire(const AnfNodePtr &key) { std::string name = label_manage::Label(key->debug_info()); if (!name.empty()) { (void)index_node_[name].insert(key); @@ -382,8 +382,8 @@ void FuncGraphIndex::Acquire(const AnfNodePtr& key) { } // Isomorphism -static bool SameNodeShallow(const AnfNodePtr& node1, const AnfNodePtr& node2, FuncGraphPairMapEquiv* equiv_func_graph, - NodeMapEquiv* const equiv_node) { +static bool SameNodeShallow(const AnfNodePtr &node1, const AnfNodePtr &node2, FuncGraphPairMapEquiv *equiv_func_graph, + NodeMapEquiv *const equiv_node) { if (equiv_node == nullptr) { MS_LOG(ERROR) << "Invalid equiv_node"; return false; @@ -419,13 +419,13 @@ static bool SameNodeShallow(const AnfNodePtr& node1, const AnfNodePtr& node2, Fu return false; } -static bool SameNode(const AnfNodePtr& node1, const AnfNodePtr& node2, FuncGraphPairMapEquiv* equiv_func_graph, - NodeMapEquiv* const equiv_node) { +static bool SameNode(const AnfNodePtr &node1, const AnfNodePtr &node2, FuncGraphPairMapEquiv *equiv_func_graph, + NodeMapEquiv *const equiv_node) { MS_EXCEPTION_IF_NULL(node1); MS_EXCEPTION_IF_NULL(node2); if (node1->isa() && node2->isa()) { - auto& inputs1 = node1->cast()->inputs(); - auto& inputs2 = node2->cast()->inputs(); + auto &inputs1 = node1->cast()->inputs(); + auto &inputs2 = node2->cast()->inputs(); for (std::size_t i = 0; i < inputs1.size(); ++i) { if (!SameNodeShallow(inputs1[i], inputs2[i], equiv_func_graph, equiv_node)) { return false; @@ -436,8 +436,8 @@ static bool SameNode(const AnfNodePtr& node1, const AnfNodePtr& node2, FuncGraph return SameNodeShallow(node1, node2, equiv_func_graph, equiv_node); } -static bool SameSubgraph(AnfNodePtr root1, AnfNodePtr root2, FuncGraphPairMapEquiv* equiv_func_graph, - NodeMapEquiv* const equiv_node) { +static bool SameSubgraph(AnfNodePtr root1, AnfNodePtr root2, FuncGraphPairMapEquiv *equiv_func_graph, + NodeMapEquiv *const equiv_node) { std::unordered_set done; std::stack> todo; @@ -479,8 +479,8 @@ static bool SameSubgraph(AnfNodePtr root1, AnfNodePtr root2, FuncGraphPairMapEqu return true; } -bool Isomorphic(FuncGraphPtr fg1, FuncGraphPtr fg2, FuncGraphPairMapEquiv* equiv_func_graph, - NodeMapEquiv* const equiv_node) { +bool Isomorphic(FuncGraphPtr fg1, FuncGraphPtr fg2, FuncGraphPairMapEquiv *equiv_func_graph, + NodeMapEquiv *const equiv_node) { auto fg1_fg2 = std::make_pair(fg1, fg2); if (equiv_func_graph == nullptr) { MS_LOG(ERROR) << "equiv_func_graph not init"; @@ -511,7 +511,7 @@ bool Isomorphic(FuncGraphPtr fg1, FuncGraphPtr fg2, FuncGraphPairMapEquiv* equiv return false; } -tensor::TensorPtr ScalarToTensor(const ScalarPtr& scalar) { +tensor::TensorPtr ScalarToTensor(const ScalarPtr &scalar) { if (scalar == nullptr) { MS_EXCEPTION(ArgumentError) << "Nullptr Error!"; } diff --git a/mindspore/ccsrc/utils/graph_utils.h b/mindspore/ccsrc/utils/graph_utils.h index 57bc0e42fc..d01335af82 100644 --- a/mindspore/ccsrc/utils/graph_utils.h +++ b/mindspore/ccsrc/utils/graph_utils.h @@ -38,42 +38,42 @@ namespace mindspore { enum IncludeType { FOLLOW, NOFOLLOW, EXCLUDE }; -using IncludeFunc = std::function; +using IncludeFunc = std::function; using SuccFunc = std::function(AnfNodePtr)>; -using SearchFunc = std::function(const AnfNodePtr&, const IncludeFunc&)>; +using SearchFunc = std::function(const AnfNodePtr &, const IncludeFunc &)>; -std::vector SuccDeeper(const AnfNodePtr& node); -std::vector SuccDeeperSimple(const AnfNodePtr& node); -std::vector SuccIncoming(const AnfNodePtr& node); -std::vector SuccIncludeFV(const FuncGraphPtr& fg, const AnfNodePtr& node); +std::vector SuccDeeper(const AnfNodePtr &node); +std::vector SuccDeeperSimple(const AnfNodePtr &node); +std::vector SuccIncoming(const AnfNodePtr &node); +std::vector SuccIncludeFV(const FuncGraphPtr &fg, const AnfNodePtr &node); -IncludeType AlwaysInclude(const AnfNodePtr& node); -IncludeType IncludeBelongGraph(const FuncGraphPtr& fg, const AnfNodePtr& node); +IncludeType AlwaysInclude(const AnfNodePtr &node); +IncludeType IncludeBelongGraph(const FuncGraphPtr &fg, const AnfNodePtr &node); -std::vector DeepScopedGraphSearch(const AnfNodePtr& root, const IncludeFunc& include = AlwaysInclude); -std::vector DeepUsedGraphSearch(const AnfNodePtr& root, const IncludeFunc& include = AlwaysInclude); -std::vector DeepLinkedGraphSearch(const AnfNodePtr& root, const IncludeFunc& include = AlwaysInclude); +std::vector DeepScopedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include = AlwaysInclude); +std::vector DeepUsedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include = AlwaysInclude); +std::vector DeepLinkedGraphSearch(const AnfNodePtr &root, const IncludeFunc &include = AlwaysInclude); -std::vector TopoSort(const AnfNodePtr& root, const SuccFunc& succ = SuccIncoming, - const IncludeFunc& include = AlwaysInclude); +std::vector TopoSort(const AnfNodePtr &root, const SuccFunc &succ = SuccIncoming, + const IncludeFunc &include = AlwaysInclude); class FuncGraphIndex { public: - explicit FuncGraphIndex(const FuncGraphPtr& fg, const SearchFunc& search = DeepScopedGraphSearch, - const IncludeFunc& include = AlwaysInclude); - FuncGraphIndex(const FuncGraphIndex&) = delete; - FuncGraphIndex& operator=(const FuncGraphIndex&) = delete; + explicit FuncGraphIndex(const FuncGraphPtr &fg, const SearchFunc &search = DeepScopedGraphSearch, + const IncludeFunc &include = AlwaysInclude); + FuncGraphIndex(const FuncGraphIndex &) = delete; + FuncGraphIndex &operator=(const FuncGraphIndex &) = delete; virtual ~FuncGraphIndex() {} - std::set GetFuncGraphs(const std::string& key); - std::set GetNodes(const std::string& key); - FuncGraphPtr GetFirstFuncGraph(const std::string& key); - AnfNodePtr GetFirstNode(const std::string& key); + std::set GetFuncGraphs(const std::string &key); + std::set GetNodes(const std::string &key); + FuncGraphPtr GetFirstFuncGraph(const std::string &key); + AnfNodePtr GetFirstNode(const std::string &key); private: - void Acquire(const FuncGraphPtr& key); - void Acquire(const AnfNodePtr& key); + void Acquire(const FuncGraphPtr &key); + void Acquire(const AnfNodePtr &key); std::map> index_func_graph_; std::map> index_node_; @@ -83,7 +83,7 @@ class FuncGraphIndex { struct PairHasher { template - std::size_t operator()(const std::pair& p) const { + std::size_t operator()(const std::pair &p) const { auto h1 = std::hash{}(p.first); auto h2 = std::hash{}(p.second); return h1 ^ h2; @@ -95,9 +95,9 @@ enum EquivState { kNotEquiv = 0, kEquiv = 1, kPending = 2 }; using FuncGraphPairMapEquiv = std::unordered_map, EquivState, PairHasher>; using NodeMapEquiv = std::unordered_map; -bool Isomorphic(FuncGraphPtr g1, FuncGraphPtr g2, FuncGraphPairMapEquiv* equiv_func_graph, NodeMapEquiv* equiv_node); +bool Isomorphic(FuncGraphPtr g1, FuncGraphPtr g2, FuncGraphPairMapEquiv *equiv_func_graph, NodeMapEquiv *equiv_node); -tensor::TensorPtr ScalarToTensor(const ScalarPtr& scalar); +tensor::TensorPtr ScalarToTensor(const ScalarPtr &scalar); } // namespace mindspore #endif // MINDSPORE_CCSRC_UTILS_GRAPH_UTILS_H_ diff --git a/mindspore/ccsrc/utils/hashing.h b/mindspore/ccsrc/utils/hashing.h index 730657ce7a..cc8cc5b991 100644 --- a/mindspore/ccsrc/utils/hashing.h +++ b/mindspore/ccsrc/utils/hashing.h @@ -25,7 +25,7 @@ inline std::size_t hash_combine(std::size_t hash_sum, std::size_t hash_val) { return ((hash_sum << 6) + (hash_sum >> 2) + 0x9e3779b9 + hash_val) ^ hash_sum; } -inline std::size_t hash_combine(const std::initializer_list& hash_vals) { +inline std::size_t hash_combine(const std::initializer_list &hash_vals) { std::size_t hash_sum = 0; for (auto hash_val : hash_vals) { hash_sum = hash_combine(hash_sum, hash_val); diff --git a/mindspore/ccsrc/utils/misc.cc b/mindspore/ccsrc/utils/misc.cc index 47e675a341..a9eb8071ef 100644 --- a/mindspore/ccsrc/utils/misc.cc +++ b/mindspore/ccsrc/utils/misc.cc @@ -23,9 +23,9 @@ const int RET_FAILED = 1; const int RET_CONTINUE = 2; const int RET_BREAK = 3; -std::string demangle(const char* name) { +std::string demangle(const char *name) { int status = -1; - std::unique_ptr res{abi::__cxa_demangle(name, nullptr, nullptr, &status), std::free}; + std::unique_ptr res{abi::__cxa_demangle(name, nullptr, nullptr, &status), std::free}; return (status == 0) ? res.get() : name; } } // namespace mindspore diff --git a/mindspore/ccsrc/utils/misc.h b/mindspore/ccsrc/utils/misc.h index 66e8937f9c..e2cdebe98a 100644 --- a/mindspore/ccsrc/utils/misc.h +++ b/mindspore/ccsrc/utils/misc.h @@ -33,7 +33,7 @@ extern const int RET_CONTINUE; extern const int RET_BREAK; // demangle the name to make it human reablable. -extern std::string demangle(const char* name); +extern std::string demangle(const char *name); } // namespace mindspore #endif // MINDSPORE_CCSRC_UTILS_MISC_H_ diff --git a/mindspore/ccsrc/utils/ordered_set.h b/mindspore/ccsrc/utils/ordered_set.h index b22053f196..f393ce74f2 100644 --- a/mindspore/ccsrc/utils/ordered_set.h +++ b/mindspore/ccsrc/utils/ordered_set.h @@ -53,28 +53,28 @@ class OrderedSet { // OrderedSet use an iterator to list as mapped value to improve the performance of insertion and deletion, // So copy of OrderedSet should re-build value of the map key to make it pointer to the new list,, thus we use // traversal to build elements. - OrderedSet(const OrderedSet& os) { - for (auto& item : os.ordered_data_) { + OrderedSet(const OrderedSet &os) { + for (auto &item : os.ordered_data_) { add(item); } } - explicit OrderedSet(const sequential_type& other) { - for (auto& item : other) { + explicit OrderedSet(const sequential_type &other) { + for (auto &item : other) { add(item); } } // Explicitly construct an OrderedSet use vector - explicit OrderedSet(const vector_type& other) { - for (auto& item : other) { + explicit OrderedSet(const vector_type &other) { + for (auto &item : other) { add(item); } } - OrderedSet& operator=(const OrderedSet& os) { + OrderedSet &operator=(const OrderedSet &os) { if (this != &os) { - for (auto& item : os.ordered_data_) { + for (auto &item : os.ordered_data_) { add(item); } } @@ -82,14 +82,14 @@ class OrderedSet { } // Add an element to the OrderedSet, without judging return value - void add(const element_type& e) { (void)insert(e); } + void add(const element_type &e) { (void)insert(e); } // insert an element to the OrderedSet - std::pair insert(const element_type& e) { + std::pair insert(const element_type &e) { iterator empty_itr; std::pair map_pair = std::make_pair(e, empty_itr); auto result = mapped_data_.insert(map_pair); - auto& seq_idx = result.first->second; + auto &seq_idx = result.first->second; // if insert success; if (result.second) { auto it = ordered_data_.insert(ordered_data_.end(), e); @@ -99,7 +99,7 @@ class OrderedSet { } // Remove an element, if removed return true, otherwise return false - bool erase(const element_type& e) { + bool erase(const element_type &e) { auto pos = mapped_data_.find(e); if (pos == mapped_data_.end()) { return false; @@ -119,7 +119,7 @@ class OrderedSet { std::string toString() { std::ostringstream res; res << "orderset content:\n"; - for (auto& item : ordered_data_) { + for (auto &item : ordered_data_) { res << std::to_string(reinterpret_cast(item.get())) << " "; } return res.str(); @@ -132,7 +132,7 @@ class OrderedSet { } // Compare two orderedset, if the order is not equal shall return false - bool operator==(const OrderedSet& other) const { return ordered_data_ == other.ordered_data_; } + bool operator==(const OrderedSet &other) const { return ordered_data_ == other.ordered_data_; } // Remove and return the first element in the OrderedSet T pop() { @@ -153,8 +153,8 @@ class OrderedSet { } // Return true if there are no common elements - bool is_disjoint(const OrderedSet& other) { - for (auto& item : other.ordered_data_) { + bool is_disjoint(const OrderedSet &other) { + for (auto &item : other.ordered_data_) { if (mapped_data_.find(item) != mapped_data_.end()) { return false; } @@ -163,8 +163,8 @@ class OrderedSet { } // Test whether this is subset of other - bool is_subset(const OrderedSet& other) { - for (auto& item : ordered_data_) { + bool is_subset(const OrderedSet &other) { + for (auto &item : ordered_data_) { if (other.mapped_data_.find(item) == other.mapped_data_.end()) { return false; } @@ -173,51 +173,51 @@ class OrderedSet { } // Add elements in other to this orderedset - void update(const OrderedSet& other) { - for (auto& item : other.ordered_data_) { + void update(const OrderedSet &other) { + for (auto &item : other.ordered_data_) { add(item); } } - void update(const std::shared_ptr& other) { update(*other); } + void update(const std::shared_ptr &other) { update(*other); } - void update(const sequential_type& other) { - for (auto& item : other) { + void update(const sequential_type &other) { + for (auto &item : other) { add(item); } } - void update(const vector_type& other) { - for (auto& item : other) { + void update(const vector_type &other) { + for (auto &item : other) { add(item); } } - ordered_set_type get_union(const OrderedSet& other) { + ordered_set_type get_union(const OrderedSet &other) { ordered_set_type res(ordered_data_); res.update(other); return res; } // Get the union with other set, this operator may cost time because of copy - ordered_set_type operator|(const OrderedSet& other) { return get_union(other); } + ordered_set_type operator|(const OrderedSet &other) { return get_union(other); } // Return the intersection of two sets - ordered_set_type intersection(const OrderedSet& other) { + ordered_set_type intersection(const OrderedSet &other) { ordered_set_type res(ordered_data_); - for (auto& item : ordered_data_) { + for (auto &item : ordered_data_) { if (other.mapped_data_.find(item) == other.mapped_data_.end()) { (void)res.erase(item); } } return res; } - ordered_set_type operator&(const OrderedSet& other) { return intersection(other); } + ordered_set_type operator&(const OrderedSet &other) { return intersection(other); } // Return the symmetric difference of two sets - ordered_set_type symmetric_difference(const OrderedSet& other) { + ordered_set_type symmetric_difference(const OrderedSet &other) { ordered_set_type res(ordered_data_); - for (auto& item : other.ordered_data_) { + for (auto &item : other.ordered_data_) { if (mapped_data_.find(item) != mapped_data_.end()) { (void)res.erase(item); } else { @@ -227,40 +227,40 @@ class OrderedSet { return res; } - ordered_set_type operator^(const OrderedSet& other) { return symmetric_difference(other); } + ordered_set_type operator^(const OrderedSet &other) { return symmetric_difference(other); } // Remove elements which is also in others. - void difference_update(const OrderedSet& other) { + void difference_update(const OrderedSet &other) { // use vector traversal, to keep ordrer - for (auto& item : other.ordered_data_) { + for (auto &item : other.ordered_data_) { (void)erase(item); } } - void difference_update(const sequential_type& other) { - for (auto& item : other) { + void difference_update(const sequential_type &other) { + for (auto &item : other) { (void)erase(item); } } - void difference_update(const vector_type& other) { - for (auto& item : other) { + void difference_update(const vector_type &other) { + for (auto &item : other) { (void)erase(item); } } // Return the set with elements that are not in the others - ordered_set_type difference(const OrderedSet& other) { + ordered_set_type difference(const OrderedSet &other) { ordered_set_type res(ordered_data_); res.difference_update(other); return res; } - ordered_set_type operator-(const OrderedSet& other) { return difference(other); } + ordered_set_type operator-(const OrderedSet &other) { return difference(other); } - bool contains(const element_type& e) const { return (mapped_data_.find(e) != mapped_data_.end()); } + bool contains(const element_type &e) const { return (mapped_data_.find(e) != mapped_data_.end()); } // Return the count of an element in set - std::size_t count(const element_type& e) const { return mapped_data_.count(e); } + std::size_t count(const element_type &e) const { return mapped_data_.count(e); } iterator begin() { return ordered_data_.begin(); } iterator end() { return ordered_data_.end(); } diff --git a/mindspore/ccsrc/utils/profile.cc b/mindspore/ccsrc/utils/profile.cc index ba490549f8..997cc1b56d 100644 --- a/mindspore/ccsrc/utils/profile.cc +++ b/mindspore/ccsrc/utils/profile.cc @@ -33,11 +33,11 @@ namespace { constexpr size_t TIME_INFO_PREFIX_NUM_LEN = 4; const char KEY_PROF_TOTAL[] = "__total__"; -void PrintProfile(std::ostringstream& oss, const TimeInfo& time_info, int indent = 0, - std::map* sums = nullptr, const std::string& prefix = ""); +void PrintProfile(std::ostringstream &oss, const TimeInfo &time_info, int indent = 0, + std::map *sums = nullptr, const std::string &prefix = ""); -void PrintTimeInfoMap(std::ostringstream& oss, const TimeInfoMap& dict, int indent = 0, - std::map* sums = nullptr, const std::string& prefix = "") { +void PrintTimeInfoMap(std::ostringstream &oss, const TimeInfoMap &dict, int indent = 0, + std::map *sums = nullptr, const std::string &prefix = "") { for (auto iter = dict.begin(); iter != dict.end(); ++iter) { if (iter->second == nullptr) { continue; @@ -62,8 +62,8 @@ void PrintTimeInfoMap(std::ostringstream& oss, const TimeInfoMap& dict, int inde } } -void PrintProfile(std::ostringstream& oss, const TimeInfo& time_info, int indent, std::map* sums, - const std::string& prefix) { +void PrintProfile(std::ostringstream &oss, const TimeInfo &time_info, int indent, std::map *sums, + const std::string &prefix) { bool need_free = false; if (sums == nullptr) { sums = new (std::nothrow) std::map(); @@ -95,7 +95,7 @@ void PrintProfile(std::ostringstream& oss, const TimeInfo& time_info, int indent } oss << "Sums\n"; if (total >= 0.0 + DBL_EPSILON) { - for (auto& iter : *sums) { + for (auto &iter : *sums) { std::string name = iter.first; name.erase(0, TIME_INFO_PREFIX_NUM_LEN); std::size_t pos = 0; @@ -159,7 +159,7 @@ void Profile::Print(void) { // Start a step in the current context with the given name. // Nomes must be unique otherwise the previous record will be overwritten. -ProfContext* Profile::Step(const std::string& name) { +ProfContext *Profile::Step(const std::string &name) { ctx_ptr_ = new (std::nothrow) ProfContext(name, this); if (ctx_ptr_ == nullptr) { MS_LOG(ERROR) << "memory allocation failed"; @@ -170,7 +170,7 @@ ProfContext* Profile::Step(const std::string& name) { // Creates subcontext for a repeated action. // Count should be monotonically increasing. -ProfContext* Profile::Lap(int count) { +ProfContext *Profile::Lap(int count) { std::ostringstream oss; oss << "Cycle " << count; ctx_ptr_ = new (std::nothrow) ProfContext(oss.str(), this); @@ -188,7 +188,7 @@ void Profile::Pop(void) noexcept { ctx_ptr_ = ctx_ptr_->parent_; } -ProfContext::ProfContext(const std::string& name, ProfileBase* const prof) : name_(name), prof_(prof) { +ProfContext::ProfContext(const std::string &name, ProfileBase *const prof) : name_(name), prof_(prof) { // Initialize a subcontext. time_info_ = nullptr; if (prof == nullptr || IsTopContext()) { @@ -227,7 +227,7 @@ void ProfContext::SetTime(double time) noexcept { time_info_->time_ = time; } -void ProfContext::Insert(const std::string& name, const TimeInfo* time) noexcept { +void ProfContext::Insert(const std::string &name, const TimeInfo *time) noexcept { if (time_info_ == nullptr) { time_info_ = new (std::nothrow) TimeInfo(); if (time_info_ == nullptr) { @@ -266,7 +266,7 @@ void ProfContext::Insert(const std::string& name, const TimeInfo* time) noexcept bool ProfContext::IsTopContext() const noexcept { return (prof_ != nullptr) && (this == &prof_->context_); } -ProfTransaction::ProfTransaction(const ProfileBase* prof) { ctx_ = (prof != nullptr ? prof->ctx_ptr_ : nullptr); } +ProfTransaction::ProfTransaction(const ProfileBase *prof) { ctx_ = (prof != nullptr ? prof->ctx_ptr_ : nullptr); } ProfTransaction::~ProfTransaction() { if (ctx_ != nullptr && !ctx_->IsTopContext()) { @@ -275,7 +275,7 @@ ProfTransaction::~ProfTransaction() { ctx_ = nullptr; } -void DumpTime::Record(const std::string& step_name, const double time, const bool is_start) { +void DumpTime::Record(const std::string &step_name, const double time, const bool is_start) { file_ss_ << " {" << std::endl; file_ss_ << " \"name\": " << "\"" << step_name << "\"," << std::endl; @@ -298,7 +298,7 @@ void DumpTime::Record(const std::string& step_name, const double time, const boo void DumpTime::Save() { try { file_out_.open(file_path_, std::ios::trunc | std::ios::out); - } catch (const std::exception& e) { + } catch (const std::exception &e) { MS_LOG(EXCEPTION) << "Cannot open file in " << (file_path_); } file_out_ << "{\n"; @@ -317,10 +317,10 @@ struct TimeInfoGroup { std::list::const_iterator> items; }; -static void PrintTimeStat(std::ostringstream& oss, const TimeInfoGroup& group, const std::string& prefix) { +static void PrintTimeStat(std::ostringstream &oss, const TimeInfoGroup &group, const std::string &prefix) { oss << "------[" << prefix << "] " << std::setw(10) << std::fixed << std::setprecision(6) << group.total_time << std::setw(6) << group.total_count << "\n"; - for (const auto& iter : group.items) { + for (const auto &iter : group.items) { oss << std::setw(5) << std::fixed << std::setprecision(2) << iter->second.time_ / group.total_time * 100 << "% : " << std::setw(12) << std::fixed << std::setprecision(6) << iter->second.time_ << "s : " << std::setw(6) << iter->second.count_ << ": " << iter->first << "\n"; @@ -332,7 +332,7 @@ void MsProfile::Print() { std::vector items = {"substitution.", "renormalize.", "replace.", "match.", "func_graph_cloner_run.", "meta_graph.", "manager."}; std::vector groups(items.size() + 1); - const auto& stat = GetSingleton().time_stat_; + const auto &stat = GetSingleton().time_stat_; // group all time infos for (auto iter = stat.cbegin(); iter != stat.cend(); ++iter) { auto matched_idx = items.size(); diff --git a/mindspore/ccsrc/utils/profile.h b/mindspore/ccsrc/utils/profile.h index 6892b0b4f6..bd3723d5bb 100644 --- a/mindspore/ccsrc/utils/profile.h +++ b/mindspore/ccsrc/utils/profile.h @@ -27,7 +27,7 @@ namespace mindspore { struct TimeInfo; -using TimeInfoMap = std::map; +using TimeInfoMap = std::map; extern double GetTime(); @@ -35,11 +35,11 @@ class ProfileBase; struct TimeInfo { explicit TimeInfo(double time = -1.0) : time_(time), dict_(nullptr), actionNum_(0) {} - TimeInfo(const TimeInfo&) = delete; + TimeInfo(const TimeInfo &) = delete; ~TimeInfo(); double time_; - TimeInfoMap* dict_; + TimeInfoMap *dict_; size_t actionNum_; }; @@ -50,21 +50,21 @@ class ProfContext { friend class ProfTransaction; public: - ProfContext(const std::string& name, ProfileBase* prof); + ProfContext(const std::string &name, ProfileBase *prof); ~ProfContext(); - ProfContext(const ProfContext&) = delete; - ProfContext& operator=(const ProfContext&) = delete; + ProfContext(const ProfContext &) = delete; + ProfContext &operator=(const ProfContext &) = delete; void SetTime(double time) noexcept; - void Insert(const std::string& name, const TimeInfo* time) noexcept; + void Insert(const std::string &name, const TimeInfo *time) noexcept; bool IsTopContext() const noexcept; private: std::string name_; - ProfileBase* prof_; - ProfContext* parent_; - TimeInfo* time_info_; + ProfileBase *prof_; + ProfContext *parent_; + TimeInfo *time_info_; }; class ProfileBase { @@ -76,38 +76,38 @@ class ProfileBase { virtual ~ProfileBase(); virtual void Print(void) {} - virtual ProfContext* Step(const std::string&) { return nullptr; } - virtual ProfContext* Lap(int) { return nullptr; } + virtual ProfContext *Step(const std::string &) { return nullptr; } + virtual ProfContext *Lap(int) { return nullptr; } virtual void Pop(void) {} // top level profile context ProfContext context_; // profile context pointer, act as a stack pointer - ProfContext* ctx_ptr_ = nullptr; + ProfContext *ctx_ptr_ = nullptr; }; class Profile : public ProfileBase { public: Profile() = default; ~Profile() override = default; - Profile(const Profile&) = delete; - Profile& operator=(const Profile&) = delete; + Profile(const Profile &) = delete; + Profile &operator=(const Profile &) = delete; void Print(void) override; - ProfContext* Step(const std::string& name) override; - ProfContext* Lap(int count) override; + ProfContext *Step(const std::string &name) override; + ProfContext *Lap(int count) override; void Pop(void) noexcept override; }; class ProfTransaction { public: - explicit ProfTransaction(const ProfileBase* prof); - explicit ProfTransaction(ProfContext* const ctx) : ctx_(ctx) {} - ProfTransaction(const ProfTransaction&) = delete; + explicit ProfTransaction(const ProfileBase *prof); + explicit ProfTransaction(ProfContext *const ctx) : ctx_(ctx) {} + ProfTransaction(const ProfTransaction &) = delete; ~ProfTransaction(); template - void operator-(const Function& func) { + void operator-(const Function &func) { double start_time = GetTime(); func(); double end_time = GetTime(); @@ -117,17 +117,17 @@ class ProfTransaction { } private: - ProfContext* ctx_ = nullptr; + ProfContext *ctx_ = nullptr; }; class NoProfTransaction { public: - explicit NoProfTransaction(ProfileBase* prof) {} - explicit NoProfTransaction(ProfContext* ctx) {} + explicit NoProfTransaction(ProfileBase *prof) {} + explicit NoProfTransaction(ProfContext *ctx) {} ~NoProfTransaction() = default; template - void operator-(const Function& func) { + void operator-(const Function &func) { func(); } }; @@ -137,20 +137,20 @@ class DumpTime { ~DumpTime() { try { Save(); - } catch (const std::exception& e) { + } catch (const std::exception &e) { MS_LOG(ERROR) << "Cannot save file by profile::DumpTime::save"; } catch (...) { MS_LOG(ERROR) << "Uncaught exception"; } } - DumpTime(const DumpTime&) = delete; - DumpTime& operator=(const DumpTime&) = delete; - static DumpTime& GetInstance() { + DumpTime(const DumpTime &) = delete; + DumpTime &operator=(const DumpTime &) = delete; + static DumpTime &GetInstance() { static DumpTime instance; return instance; } - void set_file_path(const std::string& save_path) { file_path_ = save_path; } - void Record(const std::string& name, const double time, const bool is_start); + void set_file_path(const std::string &save_path) { file_path_ = save_path; } + void Record(const std::string &name, const double time, const bool is_start); void Save(); private: @@ -188,8 +188,8 @@ class MsProfile { static void Reset() { GetSingleton().Clear(); } - static ProfileBase* GetProfile() { - MsProfile& ms_prof = GetSingleton(); + static ProfileBase *GetProfile() { + MsProfile &ms_prof = GetSingleton(); if (ms_prof.profile_ == nullptr) { #ifdef ENABLE_PROFILE ms_prof.profile_ = new Profile(); @@ -199,14 +199,14 @@ class MsProfile { } return ms_prof.profile_; } - static void StatTime(const std::string& id, double time) { GetSingleton().time_stat_[id] += time; } + static void StatTime(const std::string &id, double time) { GetSingleton().time_stat_[id] += time; } static void Print(); private: MsProfile() = default; - static MsProfile& GetSingleton() { + static MsProfile &GetSingleton() { static MsProfile profile; return profile; } @@ -220,7 +220,7 @@ class MsProfile { } std::map time_stat_; // record time and count info from some activity - ProfileBase* profile_ = nullptr; // record hierarchical profile info + ProfileBase *profile_ = nullptr; // record hierarchical profile info }; } // namespace mindspore diff --git a/mindspore/ccsrc/utils/signal.h b/mindspore/ccsrc/utils/signal.h index af7b36a8b5..9a43e23814 100644 --- a/mindspore/ccsrc/utils/signal.h +++ b/mindspore/ccsrc/utils/signal.h @@ -24,14 +24,14 @@ namespace mindspore { template -std::function bind_member(Type* instance, Return (Type::*method)(Args...)) { - return [=](Args&&... args) -> Return { return (instance->*method)(std::forward(args)...); }; +std::function bind_member(Type *instance, Return (Type::*method)(Args...)) { + return [=](Args &&... args) -> Return { return (instance->*method)(std::forward(args)...); }; } template class Slot { public: - explicit Slot(const std::function& callback) : callback(callback) {} + explicit Slot(const std::function &callback) : callback(callback) {} ~Slot() {} @@ -42,15 +42,15 @@ template class Signal { public: template - void operator()(Args&&... args) { - for (auto& slot : slots_) { + void operator()(Args &&... args) { + for (auto &slot : slots_) { if (slot->callback != nullptr) { slot->callback(std::forward(args)...); } } } - void add_slot(const std::function& func) { + void add_slot(const std::function &func) { auto slot = std::make_shared>(func); slots_.push_back(slot); } diff --git a/mindspore/ccsrc/utils/symbolic.cc b/mindspore/ccsrc/utils/symbolic.cc index 8764678288..8ad16e50c8 100644 --- a/mindspore/ccsrc/utils/symbolic.cc +++ b/mindspore/ccsrc/utils/symbolic.cc @@ -22,29 +22,29 @@ namespace mindspore { -std::ostream& operator<<(std::ostream& out, const std::shared_ptr& objPtr) { +std::ostream &operator<<(std::ostream &out, const std::shared_ptr &objPtr) { out << "("; MS_EXCEPTION_IF_NULL(objPtr); - for (auto& iter : objPtr->contents_) { + for (auto &iter : objPtr->contents_) { out << iter.first << ":" << iter.second << ";"; } out << ")"; return out; } -bool EnvInstance::operator==(const EnvInstance& other) const { +bool EnvInstance::operator==(const EnvInstance &other) const { if (Len() != other.Len()) { return false; } bool equal = std::all_of(contents_.begin(), contents_.end(), - [&other](const std::pair& item) -> bool { + [&other](const std::pair &item) -> bool { return other.contents_.find(item.first) != other.contents_.end(); }); return equal; } -bool EnvInstance::operator==(const Value& other) const { +bool EnvInstance::operator==(const Value &other) const { if (other.isa()) { - auto other_env_inst = static_cast(&other); + auto other_env_inst = static_cast(&other); return *this == *other_env_inst; } return false; diff --git a/mindspore/ccsrc/utils/symbolic.h b/mindspore/ccsrc/utils/symbolic.h index 3c712483ee..a373c23573 100644 --- a/mindspore/ccsrc/utils/symbolic.h +++ b/mindspore/ccsrc/utils/symbolic.h @@ -32,18 +32,18 @@ namespace mindspore { class SymbolicKeyInstance : public Value { public: - SymbolicKeyInstance(const AnfNodePtr& node, const abstract::AbstractBasePtr& abstract) + SymbolicKeyInstance(const AnfNodePtr &node, const abstract::AbstractBasePtr &abstract) : node_(node), abstract_(abstract) {} ~SymbolicKeyInstance() override = default; MS_DECLARE_PARENT(SymbolicKeyInstance, Value); AnfNodePtr node() const { return node_; } abstract::AbstractBasePtr abstract() const { return abstract_; } - bool operator==(const SymbolicKeyInstance& other) const { + bool operator==(const SymbolicKeyInstance &other) const { return (*node_ == *other.node_) && (*abstract_ == *other.abstract_); } std::size_t hash() const override { return std::hash{}(node_); } - friend std::ostream& operator<<(std::ostream& os, const std::shared_ptr& inst) { + friend std::ostream &operator<<(std::ostream &os, const std::shared_ptr &inst) { if (inst == nullptr) { os << "[Key][" << "Invalid symbolic key instance" @@ -56,9 +56,9 @@ class SymbolicKeyInstance : public Value { std::string ToString() const override { return node_ == nullptr ? "Invalid node" : "[Key][" + node_->type_name() + "]" + node_->ToString(); } - bool operator==(const Value& other) const override { + bool operator==(const Value &other) const override { if (other.isa()) { - auto other_ = static_cast(other); + auto other_ = static_cast(other); return *this == other_; } else { return false; @@ -106,19 +106,19 @@ using EnvInstanceContentsMap = // with inferred properties. class EnvInstance : public Value { public: - friend std::ostream& operator<<(std::ostream& out, const std::shared_ptr& env); + friend std::ostream &operator<<(std::ostream &out, const std::shared_ptr &env); - explicit EnvInstance(const EnvInstanceContentsMap& contents = {}) : contents_(contents) {} + explicit EnvInstance(const EnvInstanceContentsMap &contents = {}) : contents_(contents) {} ~EnvInstance() override = default; MS_DECLARE_PARENT(EnvInstance, Value); abstract::AbstractBasePtr ToAbstract() override { return std::make_shared(shared_from_base(), std::make_shared()); } - bool operator==(const EnvInstance& other) const; - bool operator==(const Value& other) const override; - EnvInstance(const EnvInstance& v) : Value(v), contents_(v.contents_) {} - EnvInstance(EnvInstance&& v) = default; - EnvInstance& operator=(EnvInstance&& src) noexcept { + bool operator==(const EnvInstance &other) const; + bool operator==(const Value &other) const override; + EnvInstance(const EnvInstance &v) : Value(v), contents_(v.contents_) {} + EnvInstance(EnvInstance &&v) = default; + EnvInstance &operator=(EnvInstance &&src) noexcept { if (&src != this) { contents_ = src.contents_; } @@ -126,7 +126,7 @@ class EnvInstance : public Value { }; // Get the sensitivity list for the given key - const Any& Get(const SymbolicKeyInstancePtr& key, const Any& def) const { + const Any &Get(const SymbolicKeyInstancePtr &key, const Any &def) const { auto iterator = contents_.find(key); if (iterator != contents_.end()) { return iterator->second; @@ -135,14 +135,14 @@ class EnvInstance : public Value { } // Set a value for the given key. - EnvInstance Set(const SymbolicKeyInstancePtr& key, const Any& value) const { + EnvInstance Set(const SymbolicKeyInstancePtr &key, const Any &value) const { EnvInstance rval(contents_); rval.contents_[key] = value; return rval; } // Add two EnvInstances. - EnvInstance Add(const EnvInstance& other) const { + EnvInstance Add(const EnvInstance &other) const { EnvInstance rval(contents_); for (auto iter_other : other.contents_) { auto item_self = contents_.find(iter_other.first); diff --git a/mindspore/ccsrc/utils/system/base.h b/mindspore/ccsrc/utils/system/base.h index dace2e7178..4cfb5b312d 100644 --- a/mindspore/ccsrc/utils/system/base.h +++ b/mindspore/ccsrc/utils/system/base.h @@ -108,7 +108,7 @@ constexpr bool kLittleEndian = true; // implement common define function // Get the 32 bits align value -inline uint32 DecodeFixed32(const char* ptr) { +inline uint32 DecodeFixed32(const char *ptr) { uint32 result; if (EOK != memcpy_s(&result, sizeof(result), ptr, sizeof(result))) { MS_LOG(EXCEPTION) << "Call DecodeFixed32 memcpy value failure."; @@ -116,14 +116,14 @@ inline uint32 DecodeFixed32(const char* ptr) { return result; } // Used to fetch a naturally-aligned 32-bit word in little endian byte-order -inline uint32 LE_LOAD32(const uint8_t* p) { return DecodeFixed32(reinterpret_cast(p)); } +inline uint32 LE_LOAD32(const uint8_t *p) { return DecodeFixed32(reinterpret_cast(p)); } // Encode the data to buffer -inline void EncodeFixed32(char* buf, uint32 value) { +inline void EncodeFixed32(char *buf, uint32 value) { if (EOK != memcpy_s(buf, sizeof(value), &value, sizeof(value))) { MS_LOG(EXCEPTION) << "Call EncodeFixed32 memcpy value failure."; } } -inline void EncodeFixed64(char* buf, const unsigned int array_len, int64 value) { +inline void EncodeFixed64(char *buf, const unsigned int array_len, int64 value) { if (sizeof(value) > array_len) { MS_LOG(EXCEPTION) << "Buffer overflow, real size is " << array_len << ", but required " << sizeof(value) << "."; } diff --git a/mindspore/ccsrc/utils/system/crc32c.h b/mindspore/ccsrc/utils/system/crc32c.h index 4411423bab..d23b9ad463 100644 --- a/mindspore/ccsrc/utils/system/crc32c.h +++ b/mindspore/ccsrc/utils/system/crc32c.h @@ -40,10 +40,10 @@ class Crc32c { ~Crc32c() = default; // Calculate the crc32c value, use the 8 table method - static uint32 MakeCrc32c(uint32 init_crc, const char* data, size_t size); + static uint32 MakeCrc32c(uint32 init_crc, const char *data, size_t size); // retrun the crc32c value(need mask) - static uint32 GetMaskCrc32cValue(const char* data, size_t n) { + static uint32 GetMaskCrc32cValue(const char *data, size_t n) { auto crc = MakeCrc32c(0, data, n); // Rotate right by kRightShift bits and add kMaskDelta(a constant). return ((crc >> kRightShift) | (crc << kLeftShift)) + kMaskDelta; diff --git a/mindspore/ccsrc/utils/system/file_system.cc b/mindspore/ccsrc/utils/system/file_system.cc index aee89d4b7b..ce27108a39 100644 --- a/mindspore/ccsrc/utils/system/file_system.cc +++ b/mindspore/ccsrc/utils/system/file_system.cc @@ -25,7 +25,7 @@ namespace system { #if defined(SYSTEM_ENV_POSIX) // Implement the Posix file systen -WriteFilePtr PosixFileSystem::CreateWriteFile(const string& file_name) { +WriteFilePtr PosixFileSystem::CreateWriteFile(const string &file_name) { if (file_name.empty()) { MS_LOG(ERROR) << "Create write file failed because the file name is null."; return nullptr; @@ -43,7 +43,7 @@ WriteFilePtr PosixFileSystem::CreateWriteFile(const string& file_name) { return fp; } -bool PosixFileSystem::FileExist(const string& file_name) { +bool PosixFileSystem::FileExist(const string &file_name) { if (file_name.empty()) { MS_LOG(WARNING) << "The file name is null."; return false; @@ -56,7 +56,7 @@ bool PosixFileSystem::FileExist(const string& file_name) { return true; } -bool PosixFileSystem::DeleteFile(const string& file_name) { +bool PosixFileSystem::DeleteFile(const string &file_name) { if (file_name.empty()) { MS_LOG(WARNING) << "The file name is null."; return false; @@ -70,7 +70,7 @@ bool PosixFileSystem::DeleteFile(const string& file_name) { } static const int DEFAULT_MKDIR_MODE = 0700; -bool PosixFileSystem::CreateDir(const string& dir_name) { +bool PosixFileSystem::CreateDir(const string &dir_name) { if (dir_name.empty()) { MS_LOG(WARNING) << "The directory name is null."; return false; @@ -83,7 +83,7 @@ bool PosixFileSystem::CreateDir(const string& dir_name) { return true; } -bool PosixFileSystem::DeleteDir(const string& dir_name) { +bool PosixFileSystem::DeleteDir(const string &dir_name) { if (dir_name.empty()) { MS_LOG(WARNING) << "The directory name is null."; return false; diff --git a/mindspore/ccsrc/utils/system/file_system.h b/mindspore/ccsrc/utils/system/file_system.h index ef0cf885be..ed9db874c8 100644 --- a/mindspore/ccsrc/utils/system/file_system.h +++ b/mindspore/ccsrc/utils/system/file_system.h @@ -45,25 +45,25 @@ class FileSystem { virtual ~FileSystem() = default; // Create a new read/write file - virtual WriteFilePtr CreateWriteFile(const string& file_name) = 0; + virtual WriteFilePtr CreateWriteFile(const string &file_name) = 0; // Check the file is exist? - virtual bool FileExist(const string& file_name) = 0; + virtual bool FileExist(const string &file_name) = 0; // Delete the file - virtual bool DeleteFile(const string& file_name) = 0; + virtual bool DeleteFile(const string &file_name) = 0; // Create a directory - virtual bool CreateDir(const string& dir_name) = 0; + virtual bool CreateDir(const string &dir_name) = 0; // Delete the specified directory - virtual bool DeleteDir(const string& dir_name) = 0; + virtual bool DeleteDir(const string &dir_name) = 0; }; // A file that can be read and write class WriteFile { public: - explicit WriteFile(const string& file_name) : file_name_(file_name) {} + explicit WriteFile(const string &file_name) : file_name_(file_name) {} virtual ~WriteFile() = default; @@ -71,7 +71,7 @@ class WriteFile { virtual bool Open() = 0; // append the content to file - virtual bool Write(const std::string& data) { + virtual bool Write(const std::string &data) { MS_LOG(WARNING) << "Attention: Maybe not call the function."; return true; } @@ -101,27 +101,27 @@ class PosixFileSystem : public FileSystem { ~PosixFileSystem() override = default; // create a new write file - WriteFilePtr CreateWriteFile(const string& file_name) override; + WriteFilePtr CreateWriteFile(const string &file_name) override; // check the file is exist? - bool FileExist(const string& file_name) override; + bool FileExist(const string &file_name) override; // delete the file - bool DeleteFile(const string& file_name) override; + bool DeleteFile(const string &file_name) override; // Create a Directory - bool CreateDir(const string& dir_name) override; + bool CreateDir(const string &dir_name) override; // Delete the specified directory. - bool DeleteDir(const string& dir_name) override; + bool DeleteDir(const string &dir_name) override; }; // A file that can be read and write for posix class PosixWriteFile : public WriteFile { public: - explicit PosixWriteFile(const string& file_name) : WriteFile(file_name), file_(nullptr) {} - PosixWriteFile(const PosixWriteFile&); - PosixWriteFile& operator=(const PosixWriteFile&); + explicit PosixWriteFile(const string &file_name) : WriteFile(file_name), file_(nullptr) {} + PosixWriteFile(const PosixWriteFile &); + PosixWriteFile &operator=(const PosixWriteFile &); ~PosixWriteFile() override { try { @@ -129,7 +129,7 @@ class PosixWriteFile : public WriteFile { (void)fclose(file_); file_ = nullptr; } - } catch (const std::exception& e) { + } catch (const std::exception &e) { MS_LOG(ERROR) << "Exception when closing file."; } catch (...) { MS_LOG(ERROR) << "Non standard exception when closing file."; @@ -159,7 +159,7 @@ class PosixWriteFile : public WriteFile { return true; } - bool Write(const std::string& data) override { + bool Write(const std::string &data) override { MS_LOG(DEBUG) << "Write data(" << data.size() << ") to file(" << this->file_name_ << ")."; size_t r = fwrite(data.data(), 1, data.size(), file_); if (r != data.size()) { @@ -194,7 +194,7 @@ class PosixWriteFile : public WriteFile { bool Sync() override { return Flush(); } private: - FILE* file_; + FILE *file_; }; #endif diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index eac1b86273..f05eda69bf 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -213,7 +213,7 @@ const std::set kOptOperatorSet = { const std::set kNeedTransFormatSet = {kOpFormat_FRAC_Z, kOpFormat_NC1KHKWHWC0, kOpFormat_NC1HWC0, kOpFormat_FRAC_NZ, kOpFormat_C1HWNCoC0}; -static inline void ChangeFileMode(const std::string& file_name, mode_t mode) { +static inline void ChangeFileMode(const std::string &file_name, mode_t mode) { if (access(file_name.c_str(), F_OK) != 0) { MS_LOG(DEBUG) << "File `" << file_name << "` does not exist."; return; diff --git a/mindspore/ccsrc/vm/segment_runner.cc b/mindspore/ccsrc/vm/segment_runner.cc index d7d5a4c096..ae052770ff 100644 --- a/mindspore/ccsrc/vm/segment_runner.cc +++ b/mindspore/ccsrc/vm/segment_runner.cc @@ -47,7 +47,7 @@ void ClearConvertCache() { g_ConvertCache.clear(); } // lst: list of nodes (the segment) // users: dict mapping each node to its users (globally) // seen: set of nodes that are part of the segment -AnfNodePtrList GetOutput(const AnfNodePtrList& lst, const NodeUsersMap& users, const std::vector& seen) { +AnfNodePtrList GetOutput(const AnfNodePtrList &lst, const NodeUsersMap &users, const std::vector &seen) { AnfNodePtrList output; if (users.size() == 0) { return output; @@ -57,7 +57,7 @@ AnfNodePtrList GetOutput(const AnfNodePtrList& lst, const NodeUsersMap& users, c std::begin(lst), std::end(lst), std::back_inserter(output), [&users, &seen](AnfNodePtr n) -> AnfNodePtr { auto usersn = users.find(n); bool is_referred_out_of_segment = std::any_of( - std::begin(usersn->second), std::end(usersn->second), [&seen](const std::pair& u) -> bool { + std::begin(usersn->second), std::end(usersn->second), [&seen](const std::pair &u) -> bool { return std::find(std::begin(seen), std::end(seen), u.first) == std::end(seen); }); if (n->isa() && is_referred_out_of_segment) { @@ -78,7 +78,7 @@ AnfNodePtrList GetOutput(const AnfNodePtrList& lst, const NodeUsersMap& users, c return output; } -std::tuple TransformSegmentToAnfGraph(const AnfNodePtrList& lst) { +std::tuple TransformSegmentToAnfGraph(const AnfNodePtrList &lst) { auto fg = std::make_shared(); AnfNodePtrList inputs; AnfNodePtrToAnfNodePtrMap eqv; @@ -86,7 +86,7 @@ std::tuple TransformSegmentToAnfGr MS_LOG(EXCEPTION) << "Input anf node list is empty"; } - auto ref = [&eqv, &inputs, &fg](const AnfNodePtr& a) -> AnfNodePtr { + auto ref = [&eqv, &inputs, &fg](const AnfNodePtr &a) -> AnfNodePtr { if (a->isa() && !IsValueNode(a)) { eqv[a] = a; } else if (eqv.find(a) == eqv.end()) { @@ -102,7 +102,7 @@ std::tuple TransformSegmentToAnfGr if (!n->isa()) { MS_LOG(EXCEPTION) << "Inst is not CNode"; } - auto& inps = n->cast()->inputs(); + auto &inps = n->cast()->inputs(); if (inps.empty()) { MS_LOG(EXCEPTION) << "Input is empty"; @@ -120,13 +120,13 @@ std::tuple TransformSegmentToAnfGr std::vector eqv_keys; (void)std::transform(std::begin(eqv), std::end(eqv), std::back_inserter(eqv_keys), - [](const std::pair& elem) -> AnfNodePtr { return elem.first; }); + [](const std::pair &elem) -> AnfNodePtr { return elem.first; }); auto outputs = GetOutput(lst, lst[0]->func_graph()->manager()->node_users(), eqv_keys); std::vector output_args; output_args.push_back(NewValueNode(prim::kPrimMakeTuple)); (void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(output_args), - [&eqv](const AnfNodePtr& o) -> AnfNodePtr { return eqv[o]; }); + [&eqv](const AnfNodePtr &o) -> AnfNodePtr { return eqv[o]; }); // Set output for AnfGraph auto fg_output = fg->NewCNode(output_args); @@ -148,7 +148,7 @@ std::tuple TransformSegmentToAnfGr // This implementation will convert the nodes into a subgraph // that will run using the MsVM. template -LinConvertResult Convert(const AnfNodePtrList& lst) { +LinConvertResult Convert(const AnfNodePtrList &lst) { auto cached = g_ConvertCache.find(lst); if (cached != g_ConvertCache.end()) { return cached->second; @@ -168,7 +168,7 @@ LinConvertResult Convert(const AnfNodePtrList& lst) { std::shared_ptr vm = std::make_shared(); result.run = - std::make_shared([fg, vm](const VectorRef& args) -> VectorRef { return vm->RunGraph(fg, args); }); + std::make_shared([fg, vm](const VectorRef &args) -> VectorRef { return vm->RunGraph(fg, args); }); result.inputs = inputs; result.outputs = outputs; result.graph_id = UINT32_MAX; diff --git a/mindspore/ccsrc/vm/segment_runner.h b/mindspore/ccsrc/vm/segment_runner.h index 112a770de8..8ea87da50c 100644 --- a/mindspore/ccsrc/vm/segment_runner.h +++ b/mindspore/ccsrc/vm/segment_runner.h @@ -43,7 +43,7 @@ struct LinConvertResult { uint32_t graph_id; }; -using LinkFuncType = std::function; +using LinkFuncType = std::function; using ConvertCache = std::unordered_map; extern LinkFuncType MsVmConvert; extern LinkFuncType GeVmConvert; @@ -53,7 +53,7 @@ extern std::set backend_list; void ClearConvertCache(); -std::tuple TransformSegmentToAnfGraph(const AnfNodePtrList& lst); +std::tuple TransformSegmentToAnfGraph(const AnfNodePtrList &lst); } // namespace compile } // namespace mindspore diff --git a/mindspore/ccsrc/vm/transform.cc b/mindspore/ccsrc/vm/transform.cc index 92976e0ddb..1c3c917dae 100644 --- a/mindspore/ccsrc/vm/transform.cc +++ b/mindspore/ccsrc/vm/transform.cc @@ -41,12 +41,12 @@ using TypedPrimitiveAbstractClosurePtr = std::shared_ptr nonlinear_ops = {prim::kPrimReturn, prim::kPrimPartial, prim::kPrimSwitch, prim::kPrimMakeTuple}; -const std::vector& GetMsNonlinearOps() { +const std::vector &GetMsNonlinearOps() { static const std::vector ms_nonlinear_ops = {prim::kPrimReturn, prim::kPrimPartial, prim::kPrimSwitch}; return ms_nonlinear_ops; } -CompileGraph::CompileGraph(const BackendPtr& backend, const std::vector& cut_list) +CompileGraph::CompileGraph(const BackendPtr &backend, const std::vector &cut_list) : backend_(backend), cut_list_(cut_list) { MS_EXCEPTION_IF_NULL(backend_); lin_convert_ = backend_->convert_fn(); @@ -61,11 +61,11 @@ CompileGraph::CompileGraph(const BackendPtr& backend, const std::vectorisa()) { auto cnode = node->cast(); - auto& inputs = cnode->inputs(); + auto &inputs = cnode->inputs(); if (inputs.empty()) { MS_LOG(EXCEPTION) << "Inputs of apply node is empty"; } @@ -76,7 +76,7 @@ bool CompileGraph::IsCut(const AnfNodePtr& node) { } PrimitivePtr node_prim = GetValueNode(fn); - for (auto& prim : cut_list_) { + for (auto &prim : cut_list_) { MS_EXCEPTION_IF_NULL(prim); if (prim->name() == node_prim->name()) { return true; @@ -97,14 +97,14 @@ bool CompileGraph::IsCut(const AnfNodePtr& node) { return false; } -VectorRef CompileGraph::SplitNodes(const FuncGraphPtr& graph) { +VectorRef CompileGraph::SplitNodes(const FuncGraphPtr &graph) { MS_EXCEPTION_IF_NULL(graph); VectorRef splits; VectorRef split; std::vector nodes = TopoSort(graph->get_return()); MS_LOG(DEBUG) << "Split all nodes size:" << nodes.size(); - for (auto& node : nodes) { + for (auto &node : nodes) { MS_EXCEPTION_IF_NULL(node); if (IsCut(node)) { MS_LOG(DEBUG) << "Cut node:" << node->DebugString(10) << ", size:" << split.size(); @@ -123,7 +123,7 @@ VectorRef CompileGraph::SplitNodes(const FuncGraphPtr& graph) { } // Push the value node on the stack. -void CompileGraph::Push(const AnfNodePtr& node) { +void CompileGraph::Push(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); if (slots_.count(node) > 0) { MS_LOG(EXCEPTION) << "Push failed node in slots:" << node->DebugString() @@ -135,25 +135,25 @@ void CompileGraph::Push(const AnfNodePtr& node) { set_height(height_ + 1); } -void CompileGraph::AddInst(const Instruction& inst, const int& arg) { +void CompileGraph::AddInst(const Instruction &inst, const int &arg) { VectorRef args; args.push_back(arg); AddInst(inst, args); } -void CompileGraph::AddInst(const Instruction& inst, const ValuePtr& arg) { +void CompileGraph::AddInst(const Instruction &inst, const ValuePtr &arg) { VectorRef args; args.push_back(arg); AddInst(inst, args); } -void CompileGraph::AddInst(const Instruction& inst, const VectorRef& args) { +void CompileGraph::AddInst(const Instruction &inst, const VectorRef &args) { inst_.push_back(std::make_pair(inst, args)); } // Gets the stack reference for the node value. If the node is a constant, // it may actually cause the push in to not be mentioned before. -int CompileGraph::Ref(const AnfNodePtr& node) { +int CompileGraph::Ref(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); MS_LOG(DEBUG) << "Start Ref node " << node->DebugString(true) << " height_: " << height_; if (slots_.count(node) == 0 && node->isa()) { @@ -176,7 +176,7 @@ int CompileGraph::Ref(const AnfNodePtr& node) { } // Make sure the value of node is at the top of the stack. -void CompileGraph::AddInput(const AnfNodePtr& node) { +void CompileGraph::AddInput(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); if (slots_.count(node) == 0) { MS_LOG(DEBUG) << "Input node is null " << node->DebugString(true); @@ -190,7 +190,7 @@ void CompileGraph::AddInput(const AnfNodePtr& node) { // Call back effect in stack void CompileGraph::Ret(int nargs) { set_height(height_ - nargs); } -void CompileGraph::PushParameters(const FuncGraphPtr& graph) { +void CompileGraph::PushParameters(const FuncGraphPtr &graph) { MS_EXCEPTION_IF_NULL(graph); std::vector parameters = graph->parameters(); for (size_t i = parameters.size(); i != 0; i--) { @@ -199,7 +199,7 @@ void CompileGraph::PushParameters(const FuncGraphPtr& graph) { } } -int CompileGraph::LinConvert(const FuncGraphPtr& graph, const AnfNodePtrList& node_list) { +int CompileGraph::LinConvert(const FuncGraphPtr &graph, const AnfNodePtrList &node_list) { MS_LOG(DEBUG) << "LinConvert start"; LinConvertResult result; @@ -227,14 +227,14 @@ int CompileGraph::LinConvert(const FuncGraphPtr& graph, const AnfNodePtrList& no } } AddExternal(result); - for (auto& o : result.outputs) { + for (auto &o : result.outputs) { Push(o); } return RET_SUCCESS; } -void CompileGraph::AddSinkSwitch(const CNodePtr& node) { +void CompileGraph::AddSinkSwitch(const CNodePtr &node) { MS_LOG(DEBUG) << "AddSinkSwitch:" << node->ToString(); if (backend_->is_multi_graph_sink()) { VectorRef args; @@ -255,7 +255,7 @@ void CompileGraph::AddSinkSwitch(const CNodePtr& node) { } } -int CompileGraph::InterpretNode(const FuncGraphPtr& graph, const CNodePtr& node) { +int CompileGraph::InterpretNode(const FuncGraphPtr &graph, const CNodePtr &node) { MS_EXCEPTION_IF_NULL(node); MS_LOG(DEBUG) << "Interpret node: " << node->DebugString(true); std::vector node_inputs = node->inputs(); @@ -293,7 +293,7 @@ int CompileGraph::InterpretNode(const FuncGraphPtr& graph, const CNodePtr& node) return RET_SUCCESS; } -void CompileGraph::GenMultiGraphsRun(const FuncGraphPtr& graph) { +void CompileGraph::GenMultiGraphsRun(const FuncGraphPtr &graph) { auto ret = LinConvert(graph, {}); if (ret == RET_FAILED) { MS_LOG(EXCEPTION) << "MultiGraphRun failed."; @@ -301,20 +301,20 @@ void CompileGraph::GenMultiGraphsRun(const FuncGraphPtr& graph) { AddReturn(nullptr); } -bool CompileGraph::SplitGraph(const FuncGraphPtr& graph) { +bool CompileGraph::SplitGraph(const FuncGraphPtr &graph) { MS_LOG(DEBUG) << "Start split graph"; MS_EXCEPTION_IF_NULL(graph); VectorRef splits = SplitNodes(graph); MS_LOG(DEBUG) << "Split nodes size:" << splits.size(); - for (auto& split : splits) { + for (auto &split : splits) { int ret = RET_SUCCESS; if (utils::isa(split)) { MS_LOG(DEBUG) << "Start a extern LinConvert"; std::vector args; auto vec_ref = utils::cast(split); (void)std::transform(vec_ref.begin(), vec_ref.end(), std::back_inserter(args), - [](const BaseRef& v) { return utils::cast(v); }); + [](const BaseRef &v) { return utils::cast(v); }); ret = LinConvert(graph, args); MS_LOG(DEBUG) << "End a extern LinConvert"; if (ret == RET_FAILED) { @@ -340,12 +340,12 @@ bool CompileGraph::SplitGraph(const FuncGraphPtr& graph) { return true; } -InstSet CompileGraph::GenMultiGraphsSinkInst(const FuncGraphPtr& graph) { +InstSet CompileGraph::GenMultiGraphsSinkInst(const FuncGraphPtr &graph) { InstSet inst = Run(graph); return inst; } -InstSet CompileGraph::Run(const FuncGraphPtr& graph) { +InstSet CompileGraph::Run(const FuncGraphPtr &graph) { MS_EXCEPTION_IF_NULL(graph); MS_LOG(DEBUG) << "Compile start graph: " << graph->ToString(); @@ -378,7 +378,7 @@ void CompileGraph::AddPadStack(int param_height) { } } -void CompileGraph::AddTailCall(const AnfNodePtr& fn, size_t size) { +void CompileGraph::AddTailCall(const AnfNodePtr &fn, size_t size) { VectorRef args; args.emplace_back(Ref(fn)); args.emplace_back(height_); @@ -387,7 +387,7 @@ void CompileGraph::AddTailCall(const AnfNodePtr& fn, size_t size) { AddInst(Instruction::kTailCall, args); } -void CompileGraph::AddPartial(const CNodePtr& node) { +void CompileGraph::AddPartial(const CNodePtr &node) { auto inputs = node->inputs(); VectorRef args; for (size_t i = 1; i < inputs.size(); i++) { @@ -396,7 +396,7 @@ void CompileGraph::AddPartial(const CNodePtr& node) { AddInst(Instruction::kPartial, args); } -void CompileGraph::AddMakeTuple(const CNodePtr& node) { +void CompileGraph::AddMakeTuple(const CNodePtr &node) { auto inputs = node->inputs(); VectorRef args; for (size_t i = 1; i < inputs.size(); i++) { @@ -405,7 +405,7 @@ void CompileGraph::AddMakeTuple(const CNodePtr& node) { AddInst(Instruction::kTuple, args); } -void CompileGraph::AddSwitch(const CNodePtr& node) { +void CompileGraph::AddSwitch(const CNodePtr &node) { auto inputs = node->inputs(); if (inputs.size() < 4) { MS_LOG(EXCEPTION) << "Length of inputs of primitive " << prim::kPrimSwitch->name() << " is less than 4"; @@ -420,7 +420,7 @@ void CompileGraph::AddSwitch(const CNodePtr& node) { AddInst(Instruction::kSwitch, args); } -void CompileGraph::AddReturn(const CNodePtr& node) { +void CompileGraph::AddReturn(const CNodePtr &node) { VectorRef args; if (backend_->simu_flag()) { args.emplace_back(Ref(backend_->final_output())); @@ -431,7 +431,7 @@ void CompileGraph::AddReturn(const CNodePtr& node) { AddInst(Instruction::kReturn, args); } -void CompileGraph::AddPrimitive(const CNodePtr& node, const PrimitivePtr& prim) { +void CompileGraph::AddPrimitive(const CNodePtr &node, const PrimitivePtr &prim) { auto inputs = node->inputs(); VectorRef args; args.push_back(prim); @@ -441,7 +441,7 @@ void CompileGraph::AddPrimitive(const CNodePtr& node, const PrimitivePtr& prim) AddInst(Instruction::kPrim, args); } -int CompileGraph::AddCall(const FuncGraphPtr& graph, const CNodePtr& node) { +int CompileGraph::AddCall(const FuncGraphPtr &graph, const CNodePtr &node) { auto node_inputs = node->inputs(); AnfNodePtr fn = node_inputs[0]; (void)Ref(fn); @@ -459,7 +459,7 @@ int CompileGraph::AddCall(const FuncGraphPtr& graph, const CNodePtr& node) { return RET_SUCCESS; } -void CompileGraph::AddExternal(const LinConvertResult& result) { +void CompileGraph::AddExternal(const LinConvertResult &result) { VectorRef args; args.push_back(result.run); args.push_back(result.simu_run); @@ -471,16 +471,16 @@ void CompileGraph::AddExternal(const LinConvertResult& result) { } void TraverseGraphMap( - const FuncGraphManagerPtr& manager_ptr, FuncGraphTransaction* const tr, const FuncGraphToAnfNodeCounterMap& cts, - const std::function(const PrimitivePtr, const AbstractFunctionPtr)>& get_prim_graph) { + const FuncGraphManagerPtr &manager_ptr, FuncGraphTransaction *const tr, const FuncGraphToAnfNodeCounterMap &cts, + const std::function(const PrimitivePtr, const AbstractFunctionPtr)> &get_prim_graph) { MS_EXCEPTION_IF_NULL(manager_ptr); MS_EXCEPTION_IF_NULL(tr); - for (const auto& ct_graphs : cts) { - for (const auto& ct_any : ct_graphs.second) { + for (const auto &ct_graphs : cts) { + for (const auto &ct_any : ct_graphs.second) { AnfNodePtr const_primitive_node = ct_any.first; if (const_primitive_node != nullptr && IsValueNode(const_primitive_node)) { auto users = manager_ptr->node_users()[const_primitive_node]; - for (auto& use : users) { + for (auto &use : users) { CNodePtr node = use.first->cast(); MS_EXCEPTION_IF_NULL(node); int key = use.second; @@ -503,12 +503,12 @@ void TraverseGraphMap( } } -FuncGraphPtr WrapPrimitives(const FuncGraphPtr& graph) { +FuncGraphPtr WrapPrimitives(const FuncGraphPtr &graph) { MS_EXCEPTION_IF_NULL(graph); FuncGraphManagerPtr manager_ptr = graph->manager(); MS_EXCEPTION_IF_NULL(manager_ptr); MapPrimTypeFuncGraph prim_graphs; - auto get_prim_graph = [&](const PrimitivePtr& prim, const AbstractFunctionPtr& type) { + auto get_prim_graph = [&](const PrimitivePtr &prim, const AbstractFunctionPtr &type) { PrimTypePair prim_type = std::make_pair(prim, type); if (prim_graphs.end() == prim_graphs.find(prim_type)) { FuncGraphPtr g = std::make_shared(); @@ -536,13 +536,13 @@ FuncGraphPtr WrapPrimitives(const FuncGraphPtr& graph) { }; FuncGraphTransaction tr = manager_ptr->Transact(); - auto& cts = manager_ptr->valuenodes(); + auto &cts = manager_ptr->valuenodes(); TraverseGraphMap(manager_ptr, &tr, cts, get_prim_graph); return graph; } -CompileGraphs::CompileGraphs(const BackendPtr& backend, const std::vector& cut_list) : backend_(backend) { +CompileGraphs::CompileGraphs(const BackendPtr &backend, const std::vector &cut_list) : backend_(backend) { MS_EXCEPTION_IF_NULL(backend); MS_LOG(DEBUG) << "Start vm: " << backend->name(); transform_ = std::make_shared(backend, cut_list); @@ -550,12 +550,12 @@ CompileGraphs::CompileGraphs(const BackendPtr& backend, const std::vectormanager(); MS_EXCEPTION_IF_NULL(graph_manager); FuncGraphSet graphs = graph_manager->func_graphs(); - for (auto& g : graphs) { + for (auto &g : graphs) { mapping_[g] = static_cast(insts_.size()); if (transform_ != nullptr) { InstSet insts = transform_->Run(g); @@ -568,7 +568,7 @@ void CompileGraphs::Compile(const FuncGraphPtr& graph) { } // Link instructions from multiple function graphs together. -FinalVMPtr CompileGraphs::Link(const FuncGraphPtr& graph) { +FinalVMPtr CompileGraphs::Link(const FuncGraphPtr &graph) { MS_LOG(DEBUG) << "Start"; for (std::size_t i = 0; i < insts_.size(); i++) { InstType inst = insts_[i]; @@ -600,7 +600,7 @@ FinalVMPtr CompileGraphs::Link(const FuncGraphPtr& graph) { } // Convert all graphs to unlinked instructions and link them. -FinalVMPtr CompileGraphs::CompileAndLink(const FuncGraphPtr& graph) { +FinalVMPtr CompileGraphs::CompileAndLink(const FuncGraphPtr &graph) { MS_EXCEPTION_IF_NULL(graph); MS_LOG(DEBUG) << "Start"; Reset(); diff --git a/mindspore/ccsrc/vm/transform.h b/mindspore/ccsrc/vm/transform.h index 290af10049..711c1777ab 100644 --- a/mindspore/ccsrc/vm/transform.h +++ b/mindspore/ccsrc/vm/transform.h @@ -42,26 +42,26 @@ extern const char kGeVm[]; // A sub namespace in ME to support compile related definition. namespace compile { extern std::vector nonlinear_ops; -const std::vector& GetMsNonlinearOps(); +const std::vector &GetMsNonlinearOps(); -using VmEvalFunc = std::function; -using VmEvalFuncPtr = std::shared_ptr>; +using VmEvalFunc = std::function; +using VmEvalFuncPtr = std::shared_ptr>; class CompileGraph { public: - explicit CompileGraph(const BackendPtr& backend, const std::vector& cut_list = nonlinear_ops); + explicit CompileGraph(const BackendPtr &backend, const std::vector &cut_list = nonlinear_ops); ~CompileGraph() = default; - InstSet Run(const FuncGraphPtr& func_graph); - InstSet GenMultiGraphsSinkInst(const FuncGraphPtr& graph); - bool IsCut(const AnfNodePtr& node); - void Push(const AnfNodePtr& node); - void Tie(const AnfNodePtr& n1, const AnfNodePtr& n2) { slots_[n2] = slots_[n1]; } + InstSet Run(const FuncGraphPtr &func_graph); + InstSet GenMultiGraphsSinkInst(const FuncGraphPtr &graph); + bool IsCut(const AnfNodePtr &node); + void Push(const AnfNodePtr &node); + void Tie(const AnfNodePtr &n1, const AnfNodePtr &n2) { slots_[n2] = slots_[n1]; } void Ret(int nargs); - void GenMultiGraphsRun(const FuncGraphPtr& graph); - int Ref(const AnfNodePtr& node); - VectorRef SplitNodes(const FuncGraphPtr& func_graph); + void GenMultiGraphsRun(const FuncGraphPtr &graph); + int Ref(const AnfNodePtr &node); + VectorRef SplitNodes(const FuncGraphPtr &func_graph); void set_height(int h) { height_ = h; @@ -78,24 +78,24 @@ class CompileGraph { } private: - void PushParameters(const FuncGraphPtr& func_graph); - bool SplitGraph(const FuncGraphPtr& func_graph); - int LinConvert(const FuncGraphPtr& func_graph, const AnfNodePtrList& node_list); - int InterpretNode(const FuncGraphPtr& func_graph, const CNodePtr& node); - int AddCall(const FuncGraphPtr& graph, const CNodePtr& node); - void AddSinkSwitch(const CNodePtr& node); + void PushParameters(const FuncGraphPtr &func_graph); + bool SplitGraph(const FuncGraphPtr &func_graph); + int LinConvert(const FuncGraphPtr &func_graph, const AnfNodePtrList &node_list); + int InterpretNode(const FuncGraphPtr &func_graph, const CNodePtr &node); + int AddCall(const FuncGraphPtr &graph, const CNodePtr &node); + void AddSinkSwitch(const CNodePtr &node); void AddPadStack(int param_height); - void AddTailCall(const AnfNodePtr& fn, size_t size); - void AddPartial(const CNodePtr& node); - void AddMakeTuple(const CNodePtr& node); - void AddSwitch(const CNodePtr& node); - void AddReturn(const CNodePtr& node); - void AddPrimitive(const CNodePtr& node, const PrimitivePtr& prim); - void AddInput(const AnfNodePtr& node); - void AddExternal(const LinConvertResult& result); - void AddInst(const Instruction& inst, const int& arg); - void AddInst(const Instruction& inst, const ValuePtr& arg); - void AddInst(const Instruction& inst, const VectorRef& args); + void AddTailCall(const AnfNodePtr &fn, size_t size); + void AddPartial(const CNodePtr &node); + void AddMakeTuple(const CNodePtr &node); + void AddSwitch(const CNodePtr &node); + void AddReturn(const CNodePtr &node); + void AddPrimitive(const CNodePtr &node, const PrimitivePtr &prim); + void AddInput(const AnfNodePtr &node); + void AddExternal(const LinConvertResult &result); + void AddInst(const Instruction &inst, const int &arg); + void AddInst(const Instruction &inst, const ValuePtr &arg); + void AddInst(const Instruction &inst, const VectorRef &args); BackendPtr backend_; LinkFuncType lin_convert_; @@ -112,7 +112,7 @@ using CompileGraphPtr = std::shared_ptr; // CompileGraphs is used to Convert a graph cluster into instruction lists. class CompileGraphs { public: - explicit CompileGraphs(const BackendPtr& backend, const std::vector& cut_list = nonlinear_ops); + explicit CompileGraphs(const BackendPtr &backend, const std::vector &cut_list = nonlinear_ops); ~CompileGraphs() = default; @@ -121,9 +121,9 @@ class CompileGraphs { mapping_.clear(); } - void Compile(const FuncGraphPtr& func_graph); - FinalVMPtr Link(const FuncGraphPtr& func_graph); - FinalVMPtr CompileAndLink(const FuncGraphPtr& func_graph); + void Compile(const FuncGraphPtr &func_graph); + FinalVMPtr Link(const FuncGraphPtr &func_graph); + FinalVMPtr CompileAndLink(const FuncGraphPtr &func_graph); private: InstSet insts_; diff --git a/mindspore/ccsrc/vm/vm.cc b/mindspore/ccsrc/vm/vm.cc index 493873b0bc..95ceceb67f 100644 --- a/mindspore/ccsrc/vm/vm.cc +++ b/mindspore/ccsrc/vm/vm.cc @@ -32,29 +32,29 @@ namespace compile { // Arguments: // fn_: Callable function. // args_: Sequence of function args. -StructPartial::StructPartial(int fn, const VectorRef& args) : fn_(fn), args_(args) {} +StructPartial::StructPartial(int fn, const VectorRef &args) : fn_(fn), args_(args) {} -std::ostream& operator<<(std::ostream& os, const StructPartial& other) { +std::ostream &operator<<(std::ostream &os, const StructPartial &other) { os << "partial(" << other.fn_ << ", " << other.args_.ToString() << ")"; return os; } -bool operator==(const StructPartial& lhs, const StructPartial& rhs) { +bool operator==(const StructPartial &lhs, const StructPartial &rhs) { return (lhs.fn_ == rhs.fn_ && lhs.args_ == rhs.args_); } -StructSimuSwitch::StructSimuSwitch(const BaseRef& fn, const BaseRef& value) : fn_(fn), value_(value) {} +StructSimuSwitch::StructSimuSwitch(const BaseRef &fn, const BaseRef &value) : fn_(fn), value_(value) {} -std::ostream& operator<<(std::ostream& os, const StructSimuSwitch& other) { +std::ostream &operator<<(std::ostream &os, const StructSimuSwitch &other) { os << "SimulSwitch(" << other.fn_.ToString() << ", " << other.value_.ToString() << ")"; return os; } -bool operator==(const StructSimuSwitch& lhs, const StructSimuSwitch& rhs) { +bool operator==(const StructSimuSwitch &lhs, const StructSimuSwitch &rhs) { return (lhs.fn_ == rhs.fn_ && lhs.value_ == rhs.value_); } -std::ostream& operator<<(std::ostream& os, const SwitchCondStatus& other) { +std::ostream &operator<<(std::ostream &os, const SwitchCondStatus &other) { os << "SwitchCondStatus(" << static_cast(other) << ")"; return os; } @@ -66,13 +66,13 @@ std::ostream& operator<<(std::ostream& os, const SwitchCondStatus& other) { // retp_: The call stack. // pc_: program counter (next instruction) // sp_: stack pointer (for the value stack) -FinalVM::FinalVM(const InstSet& insts, const BackendPtr& backend) : insts_(insts), pc_(0), sp_(0), backend_(backend) { +FinalVM::FinalVM(const InstSet &insts, const BackendPtr &backend) : insts_(insts), pc_(0), sp_(0), backend_(backend) { MS_LOG(DEBUG) << "InstSet size:" << insts_.size(); insts_stack_.emplace_back(BaseRef()); retp_.push(-1); } -void FinalVM::Push(const BaseRef& v) { +void FinalVM::Push(const BaseRef &v) { MS_LOG(DEBUG) << "Push " << v.ToString() << " sp_:" << sp_; insts_stack_[IntToSize(sp_++)] = v; } @@ -140,7 +140,7 @@ void FinalVM::Popsp() { } } -void FinalVM::DoJmp(const BaseRef& jmp_orig) { +void FinalVM::DoJmp(const BaseRef &jmp_orig) { MS_LOG(DEBUG) << "Start"; BaseRef jmp = jmp_orig; @@ -173,7 +173,7 @@ void FinalVM::DoJmp(const BaseRef& jmp_orig) { MS_LOG(DEBUG) << "End do jump pc_:" << pc_; } -BaseRef FinalVM::Eval(const VectorRef& args) { +BaseRef FinalVM::Eval(const VectorRef &args) { MS_LOG(DEBUG) << "Start: " << args.size(); insts_stack_.clear(); insts_stack_.resize(args.size()); @@ -212,7 +212,7 @@ BaseRef FinalVM::Eval(const VectorRef& args) { return insts_stack_[0]; } -void FinalVM::InstCall(const VectorRef& args) { +void FinalVM::InstCall(const VectorRef &args) { MS_LOG(DEBUG) << "Start"; const size_t args_size = 1; if (args.size() != args_size) { @@ -228,7 +228,7 @@ void FinalVM::InstCall(const VectorRef& args) { MS_LOG(DEBUG) << "Instcall end sp :" << sp_; } -void FinalVM::InstTailCall(const VectorRef& args) { +void FinalVM::InstTailCall(const VectorRef &args) { MS_LOG(DEBUG) << "Start"; const size_t args_size = 3; if (args.size() != args_size) { @@ -258,7 +258,7 @@ void FinalVM::InstTailCall(const VectorRef& args) { MS_LOG(DEBUG) << "End"; } -void FinalVM::InstSwitchReturn(const VectorRef& args) { +void FinalVM::InstSwitchReturn(const VectorRef &args) { MS_LOG(DEBUG) << "Start"; if (args.size() != 1) { MS_LOG(ERROR) << "" << __FUNCTION__ << " requires one parameter, while the input size is " << args.size() << "."; @@ -268,7 +268,7 @@ void FinalVM::InstSwitchReturn(const VectorRef& args) { Popsp(); } -void FinalVM::InstReturn(const VectorRef& args) { +void FinalVM::InstReturn(const VectorRef &args) { MS_LOG(DEBUG) << "Start"; const size_t args_size = 2; if (args.size() != args_size) { @@ -291,7 +291,7 @@ void FinalVM::InstReturn(const VectorRef& args) { MS_LOG(DEBUG) << "End"; } -void FinalVM::InstPartial(const VectorRef& args) { +void FinalVM::InstPartial(const VectorRef &args) { MS_LOG(DEBUG) << "Start"; const size_t args_size = 1; if (args.size() < args_size) { @@ -306,12 +306,12 @@ void FinalVM::InstPartial(const VectorRef& args) { std::vector outs(args.size() - 1); (void)std::transform(args.begin() + 1, args.end(), outs.begin(), - [&, this](const BaseRef& a) { return Ref(utils::cast(a)); }); + [&, this](const BaseRef &a) { return Ref(utils::cast(a)); }); Push(std::make_shared(fn, VectorRef(outs))); MS_LOG(DEBUG) << "End"; } -void FinalVM::InstSimuSwitch(const VectorRef& args) { +void FinalVM::InstSimuSwitch(const VectorRef &args) { const size_t args_size = 4; if (args.size() != args_size) { MS_LOG(ERROR) << "" << __FUNCTION__ << " requires " << args_size << " parameters, while the input size is " @@ -365,7 +365,7 @@ void FinalVM::InstSimuSwitch(const VectorRef& args) { } } -void FinalVM::InstRealSwitch(const VectorRef& args) { +void FinalVM::InstRealSwitch(const VectorRef &args) { const size_t args_size = 3; if (args.size() != args_size) { MS_LOG(ERROR) << "" << __FUNCTION__ << " requires " << args_size << " parameters, while the input size is " @@ -392,7 +392,7 @@ void FinalVM::InstRealSwitch(const VectorRef& args) { } } -void FinalVM::InstSwitch(const VectorRef& args) { +void FinalVM::InstSwitch(const VectorRef &args) { MS_LOG(DEBUG) << "Start"; if (backend_->is_multi_graph_sink()) { InstSimuSwitch(args); @@ -401,7 +401,7 @@ void FinalVM::InstSwitch(const VectorRef& args) { } } -void FinalVM::InstTuple(const VectorRef& args) { +void FinalVM::InstTuple(const VectorRef &args) { MS_LOG(DEBUG) << "Start"; VectorRef tuple; auto iter = args.begin(); @@ -413,7 +413,7 @@ void FinalVM::InstTuple(const VectorRef& args) { MS_LOG(DEBUG) << "End"; } -void FinalVM::InstPush(const VectorRef& args) { +void FinalVM::InstPush(const VectorRef &args) { MS_LOG(DEBUG) << "Start"; const size_t args_size = 1; if (args.size() != args_size) { @@ -427,7 +427,7 @@ void FinalVM::InstPush(const VectorRef& args) { MS_LOG(DEBUG) << "End"; } -void FinalVM::InstInput(const VectorRef& args) { +void FinalVM::InstInput(const VectorRef &args) { MS_LOG(DEBUG) << "Start"; const size_t args_size = 1; if (args.size() != args_size) { @@ -441,7 +441,7 @@ void FinalVM::InstInput(const VectorRef& args) { MS_LOG(DEBUG) << "End"; } -void FinalVM::InstPadStack(const VectorRef& args) { +void FinalVM::InstPadStack(const VectorRef &args) { MS_LOG(DEBUG) << "Start"; const size_t args_size = 1; if (args.size() != args_size) { @@ -461,7 +461,7 @@ void FinalVM::InstPadStack(const VectorRef& args) { MS_LOG(DEBUG) << "End"; } -void FinalVM::InstExternal(const VectorRef& args) { +void FinalVM::InstExternal(const VectorRef &args) { MS_LOG(DEBUG) << "Start:" << args.size(); if (args.empty()) { @@ -490,14 +490,14 @@ void FinalVM::InstExternal(const VectorRef& args) { auto outs = (*fn)(tuple); MS_LOG(DEBUG) << "'fn' out size:" << outs.size(); - for (auto& o : outs) { + for (auto &o : outs) { MS_LOG(DEBUG) << "InstExternal value:" << o.ToString(); Push(o); } MS_LOG(DEBUG) << "End"; } -void FinalVM::InstPushPrim(const VectorRef& args) { +void FinalVM::InstPushPrim(const VectorRef &args) { MS_LOG(DEBUG) << "Start: " << args.size(); const size_t args_size = 2; if (args.size() < args_size) { diff --git a/mindspore/ccsrc/vm/vm.h b/mindspore/ccsrc/vm/vm.h index 3e1e5b5c08..eab726a9b7 100644 --- a/mindspore/ccsrc/vm/vm.h +++ b/mindspore/ccsrc/vm/vm.h @@ -53,14 +53,14 @@ enum Instruction { using InstType = std::pair; using InstSet = std::vector; -using InstFunctionMap = std::map>; +using InstFunctionMap = std::map>; const std::vector inst_str{"call", "tail_call", "return", "partial", "switch", "switch_return", "tuple", "input", "external", "push", "primitive", "graph", "pad_stack"}; class StructPartial : public Base { public: // Initialize StructPartial. - StructPartial(int fn, const VectorRef& args); + StructPartial(int fn, const VectorRef &args); virtual ~StructPartial() = default; MS_DECLARE_PARENT(StructPartial, Base) @@ -69,12 +69,12 @@ class StructPartial : public Base { VectorRef args_; }; -std::ostream& operator<<(std::ostream& os, const StructPartial& other); -bool operator==(const StructPartial& lhs, const StructPartial& rhs); +std::ostream &operator<<(std::ostream &os, const StructPartial &other); +bool operator==(const StructPartial &lhs, const StructPartial &rhs); class StructSimuSwitch : public Base { public: - StructSimuSwitch(const BaseRef& fn, const BaseRef& value); + StructSimuSwitch(const BaseRef &fn, const BaseRef &value); virtual ~StructSimuSwitch() = default; MS_DECLARE_PARENT(StructSimuSwitch, Base) @@ -83,43 +83,43 @@ class StructSimuSwitch : public Base { BaseRef value_; }; -std::ostream& operator<<(std::ostream& os, const StructSimuSwitch& other); -bool operator==(const StructSimuSwitch& lhs, const StructSimuSwitch& rhs); +std::ostream &operator<<(std::ostream &os, const StructSimuSwitch &other); +bool operator==(const StructSimuSwitch &lhs, const StructSimuSwitch &rhs); class FinalVM { public: // Create a VM with the specified instructions and backend. - explicit FinalVM(const InstSet& insts, const BackendPtr& backend); + explicit FinalVM(const InstSet &insts, const BackendPtr &backend); virtual ~FinalVM() = default; - BaseRef Eval(const VectorRef& args); - void InstCall(const VectorRef& args); - void InstTailCall(const VectorRef& args); - void InstReturn(const VectorRef& args); - void InstPartial(const VectorRef& args); - void InstSwitch(const VectorRef& args); - void InstSimuSwitch(const VectorRef& args); - void InstRealSwitch(const VectorRef& args); - void InstTuple(const VectorRef& args); - void InstPush(const VectorRef& args); - void InstInput(const VectorRef& args); - void InstPadStack(const VectorRef& args); - void InstExternal(const VectorRef& args); - void InstPushPrim(const VectorRef& args); - void InstSwitchReturn(const VectorRef& args); - void set_insts(const InstSet& value) { insts_ = value; } + BaseRef Eval(const VectorRef &args); + void InstCall(const VectorRef &args); + void InstTailCall(const VectorRef &args); + void InstReturn(const VectorRef &args); + void InstPartial(const VectorRef &args); + void InstSwitch(const VectorRef &args); + void InstSimuSwitch(const VectorRef &args); + void InstRealSwitch(const VectorRef &args); + void InstTuple(const VectorRef &args); + void InstPush(const VectorRef &args); + void InstInput(const VectorRef &args); + void InstPadStack(const VectorRef &args); + void InstExternal(const VectorRef &args); + void InstPushPrim(const VectorRef &args); + void InstSwitchReturn(const VectorRef &args); + void set_insts(const InstSet &value) { insts_ = value; } protected: BaseRef Ref(int i); - void Push(const BaseRef& v); + void Push(const BaseRef &v); void Pop(int n = 1); void MoveStack(int nitems, int height); void Pushp(); void Popp(); void Pushsp(); void Popsp(); - void DoJmp(const BaseRef& jmp); + void DoJmp(const BaseRef &jmp); private: InstSet insts_; @@ -130,18 +130,18 @@ class FinalVM { int sp_; BackendPtr backend_; const InstFunctionMap inst_function_map = { - {Instruction::kCall, [this](const VectorRef& args) { InstCall(args); }}, - {Instruction::kTailCall, [this](const VectorRef& args) { InstTailCall(args); }}, - {Instruction::kReturn, [this](const VectorRef& args) { InstReturn(args); }}, - {Instruction::kPartial, [this](const VectorRef& args) { InstPartial(args); }}, - {Instruction::kSwitch, [this](const VectorRef& args) { InstSwitch(args); }}, - {Instruction::kTuple, [this](const VectorRef& args) { InstTuple(args); }}, - {Instruction::kPush, [this](const VectorRef& args) { InstPush(args); }}, - {Instruction::kInput, [this](const VectorRef& args) { InstInput(args); }}, - {Instruction::kPadStack, [this](const VectorRef& args) { InstPadStack(args); }}, - {Instruction::kExternal, [this](const VectorRef& args) { InstExternal(args); }}, - {Instruction::kPrim, [this](const VectorRef& args) { InstPushPrim(args); }}, - {Instruction::kSwitchReturn, [this](const VectorRef& args) { InstSwitchReturn(args); }}, + {Instruction::kCall, [this](const VectorRef &args) { InstCall(args); }}, + {Instruction::kTailCall, [this](const VectorRef &args) { InstTailCall(args); }}, + {Instruction::kReturn, [this](const VectorRef &args) { InstReturn(args); }}, + {Instruction::kPartial, [this](const VectorRef &args) { InstPartial(args); }}, + {Instruction::kSwitch, [this](const VectorRef &args) { InstSwitch(args); }}, + {Instruction::kTuple, [this](const VectorRef &args) { InstTuple(args); }}, + {Instruction::kPush, [this](const VectorRef &args) { InstPush(args); }}, + {Instruction::kInput, [this](const VectorRef &args) { InstInput(args); }}, + {Instruction::kPadStack, [this](const VectorRef &args) { InstPadStack(args); }}, + {Instruction::kExternal, [this](const VectorRef &args) { InstExternal(args); }}, + {Instruction::kPrim, [this](const VectorRef &args) { InstPushPrim(args); }}, + {Instruction::kSwitchReturn, [this](const VectorRef &args) { InstSwitchReturn(args); }}, }; }; diff --git a/mindspore/ccsrc/vm/vmimpl.cc b/mindspore/ccsrc/vm/vmimpl.cc index ee9a817dd8..017121f334 100644 --- a/mindspore/ccsrc/vm/vmimpl.cc +++ b/mindspore/ccsrc/vm/vmimpl.cc @@ -40,25 +40,25 @@ using PrimitivePyPtr = std::shared_ptr; // Indicate a call to a new frame. struct CallWrap : public Base { - explicit CallWrap(const VMFramePtr& vm_frame) : frame(vm_frame) {} + explicit CallWrap(const VMFramePtr &vm_frame) : frame(vm_frame) {} VMFramePtr frame{nullptr}; }; using CallWrapPtr = std::shared_ptr; // Indicates a return with its value. struct ReturnWrap : public Base { - explicit ReturnWrap(const BaseRef& r_value) : value(r_value) {} + explicit ReturnWrap(const BaseRef &r_value) : value(r_value) {} BaseRef value{BaseRef()}; }; using ReturnWrapPtr = std::shared_ptr; -VMFrame::VMFrame(const AnfNodePtrList& nodes, const AnfNodePtrToBaseRefMap& values, - const AnfNodePtrToBaseRefMap& closure) +VMFrame::VMFrame(const AnfNodePtrList &nodes, const AnfNodePtrToBaseRefMap &values, + const AnfNodePtrToBaseRefMap &closure) : values_(values), todo_(nodes), closure_(closure) { std::reverse(std::begin(todo_), std::end(todo_)); } -const BaseRef VMFrame::operator[](const AnfNodePtr& node) { +const BaseRef VMFrame::operator[](const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); auto ret = values_.find(node); if (ret != values_.end()) { @@ -77,31 +77,31 @@ const BaseRef VMFrame::operator[](const AnfNodePtr& node) { MS_LOG(EXCEPTION) << "ValueError " << node->type_name(); } -Closure::Closure(const FuncGraphPtr& graph, const AnfNodePtrToBaseRefMap& values) +Closure::Closure(const FuncGraphPtr &graph, const AnfNodePtrToBaseRefMap &values) : func_graph_(graph), values_(values) {} -BaseRef Closure::operator()(const VectorRef& args) { +BaseRef Closure::operator()(const VectorRef &args) { MS_LOG(DEBUG) << "start closure"; return vm_->Evaluate(func_graph_, args, values_); } -Partial::Partial(const BaseRef& fn, const VectorRef& args, const VMPtr& vm) : fn_(fn), args_(args), vm_(vm) {} +Partial::Partial(const BaseRef &fn, const VectorRef &args, const VMPtr &vm) : fn_(fn), args_(args), vm_(vm) {} -BaseRef Partial::operator()(const VectorRef& nodes) { +BaseRef Partial::operator()(const VectorRef &nodes) { VectorRef arglist; (void)arglist.insert(arglist.end(), args_.begin(), args_.end()); (void)arglist.insert(arglist.end(), nodes.begin(), nodes.end()); return vm_->Call(fn_, arglist); } -SetRef VM::ComputeFvs(const FuncGraphPtr& graph) { +SetRef VM::ComputeFvs(const FuncGraphPtr &graph) { MS_EXCEPTION_IF_NULL(graph); SetRef rval; - for (auto& fkv : graph->free_variables_total()) { + for (auto &fkv : graph->free_variables_total()) { if (utils::isa(fkv.first)) { // Add all value_nodes of g that refer to a fv graph auto g = utils::cast(fkv.first); - for (auto& ctkv : g->value_nodes()) { + for (auto &ctkv : g->value_nodes()) { auto ct = ctkv.first; if (GetValueNode(ct) == g) { (void)rval.insert(ct); @@ -116,7 +116,7 @@ SetRef VM::ComputeFvs(const FuncGraphPtr& graph) { return rval; } -void VM::AcquireGraph(const FuncGraphPtr& graph) { +void VM::AcquireGraph(const FuncGraphPtr &graph) { // Already acquired if (vars_.find(graph) != vars_.end()) { return; @@ -130,30 +130,30 @@ void VM::AcquireGraph(const FuncGraphPtr& graph) { } } -VectorRef VM::ExportSequence(const VectorRef& seq) { +VectorRef VM::ExportSequence(const VectorRef &seq) { std::vector ret; (void)std::transform(std::begin(seq), std::end(seq), std::back_inserter(ret), - [&, this](const BaseRef& x) -> BaseRef { return Export(x); }); + [&, this](const BaseRef &x) -> BaseRef { return Export(x); }); return VectorRef(ret); } -ClosurePtr VM::ExportClosure(const ClosurePtr& clos) { +ClosurePtr VM::ExportClosure(const ClosurePtr &clos) { MS_EXCEPTION_IF_NULL(clos); clos->set_vm(shared_from_this()); return clos; } // transform graph to executable closure -ClosurePtr VM::ExportGraph(const FuncGraphPtr& g) { +ClosurePtr VM::ExportGraph(const FuncGraphPtr &g) { auto c = std::make_shared(g, AnfNodePtrToBaseRefMap()); MS_EXCEPTION_IF_NULL(c); c->set_vm(shared_from_this()); return c; } -BaseRef VM::ExportObj(const BaseRef& obj) const { return obj; } +BaseRef VM::ExportObj(const BaseRef &obj) const { return obj; } -BaseRef VM::Export(const BaseRef& value) { +BaseRef VM::Export(const BaseRef &value) { if (utils::isa(value) && utils::cast(value)->isa()) { return ExportGraph(utils::cast(value)->cast()); } @@ -183,7 +183,7 @@ BaseRef VM::Export(const BaseRef& value) { // Run a graph. // This will evaluate the passed-in graph and return the resulting value. -BaseRef VM::Evaluate(const FuncGraphPtr& graph, const VectorRef& args, const AnfNodePtrToBaseRefMap& closure) { +BaseRef VM::Evaluate(const FuncGraphPtr &graph, const VectorRef &args, const AnfNodePtrToBaseRefMap &closure) { AcquireGraph(graph); MS_LOG(DEBUG) << "evalue arg size: " << args.size(); if (args.size() != graph->parameters().size()) { @@ -237,15 +237,15 @@ BaseRef VM::Evaluate(const FuncGraphPtr& graph, const VectorRef& args, const Anf MS_LOG(EXCEPTION) << "VM Evaluate error"; } -SuccFunc VM::SuccVm(const FuncGraphPtr& graph) { - auto fn = [&, this](const AnfNodePtr& node) -> AnfNodePtrList { +SuccFunc VM::SuccVm(const FuncGraphPtr &graph) { + auto fn = [&, this](const AnfNodePtr &node) -> AnfNodePtrList { MS_EXCEPTION_IF_NULL(node); AnfNodePtrList ret; // Follow node.incoming if (node->isa()) { - auto& inputs = node->cast()->inputs(); - for (auto& i : inputs) { + auto &inputs = node->cast()->inputs(); + for (auto &i : inputs) { if (i->func_graph() == node->func_graph() || (IsValueNode(i) && GetValueNode(i)->parent() == graph)) { ret.push_back(i); @@ -257,7 +257,7 @@ SuccFunc VM::SuccVm(const FuncGraphPtr& graph) { if (IsValueNode(node) && GetValueNode(node)->parent() == graph) { auto fvs = utils::cast(vars_[GetValueNode(node)]); (void)std::transform(fvs.begin(), fvs.end(), std::back_inserter(ret), - [](const BaseRef& value) -> AnfNodePtr { return utils::cast(value); }); + [](const BaseRef &value) -> AnfNodePtr { return utils::cast(value); }); } return ret; @@ -265,7 +265,7 @@ SuccFunc VM::SuccVm(const FuncGraphPtr& graph) { return fn; } -BaseRef VM::Call(const BaseRef& fn, const VectorRef& args) { +BaseRef VM::Call(const BaseRef &fn, const VectorRef &args) { if (utils::isa(fn)) { return RunOperation(utils::cast(fn), args); } @@ -283,7 +283,7 @@ BaseRef VM::Call(const BaseRef& fn, const VectorRef& args) { } // make call frame for graph -BaseRef VM::_Call(const BaseRef& graph, const VectorRef& args) { +BaseRef VM::_Call(const BaseRef &graph, const VectorRef &args) { AnfNodePtrToBaseRefMap clos; auto func_graph = graph; if (utils::isa(func_graph)) { @@ -319,11 +319,11 @@ BaseRef VM::_Call(const BaseRef& graph, const VectorRef& args) { } // make closure out of graph with fv values from frame -ClosurePtr VM::MakeClosure(const FuncGraphPtr& graph, const VMFramePtr& frame) { +ClosurePtr VM::MakeClosure(const FuncGraphPtr &graph, const VMFramePtr &frame) { MS_EXCEPTION_IF_NULL(frame); AnfNodePtrToBaseRefMap clos; - for (auto& v : utils::cast(vars_[graph])) { + for (auto &v : utils::cast(vars_[graph])) { auto anf = utils::cast(v); clos[anf] = (*frame)[anf]; } @@ -331,7 +331,7 @@ ClosurePtr VM::MakeClosure(const FuncGraphPtr& graph, const VMFramePtr& frame) { return std::make_shared(graph, clos); } -BaseRef VM::DispatchCall(const AnfNodePtr& node, const VMFramePtr& frame, const BaseRef& fn, const VectorRef& args) { +BaseRef VM::DispatchCall(const AnfNodePtr &node, const VMFramePtr &frame, const BaseRef &fn, const VectorRef &args) { if (utils::isa(fn) && utils::cast(fn)->isa()) { auto fnval = utils::cast(fn)->cast(); MS_LOG(DEBUG) << "DispatchCall prim:" << fnval->name() << ", node:" << node->DebugString(true); @@ -384,7 +384,7 @@ BaseRef VM::DispatchCall(const AnfNodePtr& node, const VMFramePtr& frame, const MS_LOG(EXCEPTION) << "Invalid fn to call"; } -BaseRef VM::HandleNode(const AnfNodePtr& node, const VMFramePtr& frame) { +BaseRef VM::HandleNode(const AnfNodePtr &node, const VMFramePtr &frame) { MS_EXCEPTION_IF_NULL(node); if (node->isa()) { // pass @@ -409,10 +409,10 @@ BaseRef VM::HandleNode(const AnfNodePtr& node, const VMFramePtr& frame) { if (node->isa()) { std::vector fnArgs; - auto& inputs = node->cast()->inputs(); + auto &inputs = node->cast()->inputs(); // set args' values in frame (void)std::transform(std::begin(inputs), std::end(inputs), std::back_inserter(fnArgs), - [&](const AnfNodePtr& inp) -> BaseRef { return (*frame)[inp]; }); + [&](const AnfNodePtr &inp) -> BaseRef { return (*frame)[inp]; }); if (fnArgs.empty()) { MS_LOG(EXCEPTION) << "function arguments is empty"; } else { @@ -425,7 +425,7 @@ BaseRef VM::HandleNode(const AnfNodePtr& node, const VMFramePtr& frame) { MS_LOG(EXCEPTION) << "Unknown node type"; } -VectorRef VM::RunGraph(const FuncGraphPtr& g, const VectorRef& args) { +VectorRef VM::RunGraph(const FuncGraphPtr &g, const VectorRef &args) { this->manager_ = Manage(g); auto fn = utils::cast(Export(g)); @@ -439,7 +439,7 @@ VectorRef VM::RunGraph(const FuncGraphPtr& g, const VectorRef& args) { } } -BaseRef RunOperation(const PrimitivePtr& prim, const VectorRef& args) { +BaseRef RunOperation(const PrimitivePtr &prim, const VectorRef &args) { PrimitivePyPtr operation = dyn_cast(prim); MS_LOG(DEBUG) << "operation start " << prim->name(); @@ -451,7 +451,7 @@ BaseRef RunOperation(const PrimitivePtr& prim, const VectorRef& args) { py::tuple py_args = py::tuple(args.size()); MS_LOG(DEBUG) << "input for operation:"; size_t i = 0; - for (auto& arg : args) { + for (auto &arg : args) { py_args[i] = BaseRefToPyData(arg); MS_LOG(DEBUG) << "arg: " << i << ":"; i++; diff --git a/mindspore/ccsrc/vm/vmimpl.h b/mindspore/ccsrc/vm/vmimpl.h index 4ef507af82..11d026fe72 100644 --- a/mindspore/ccsrc/vm/vmimpl.h +++ b/mindspore/ccsrc/vm/vmimpl.h @@ -53,14 +53,14 @@ using VMPtr = std::shared_ptr; class Partial; using PartialPtr = std::shared_ptr; -using RunFunc = std::function; +using RunFunc = std::function; using RunFuncPtr = std::shared_ptr; using SuccFunc = std::function; class VMImpl { public: - virtual VectorRef RunGraph(const FuncGraphPtr& fg, const VectorRef& args) = 0; + virtual VectorRef RunGraph(const FuncGraphPtr &fg, const VectorRef &args) = 0; virtual ~VMImpl() = default; }; @@ -76,11 +76,11 @@ class VMImpl { // closure: values for the closure if the current application is a closure class VMFrame { public: - VMFrame(const AnfNodePtrList& nodes, const AnfNodePtrToBaseRefMap& values, const AnfNodePtrToBaseRefMap& closure); - const BaseRef operator[](const AnfNodePtr& node); - const AnfNodePtrList& todo() const { return todo_; } + VMFrame(const AnfNodePtrList &nodes, const AnfNodePtrToBaseRefMap &values, const AnfNodePtrToBaseRefMap &closure); + const BaseRef operator[](const AnfNodePtr &node); + const AnfNodePtrList &todo() const { return todo_; } - AnfNodePtrToBaseRefMap& values() { return values_; } + AnfNodePtrToBaseRefMap &values() { return values_; } virtual ~VMFrame() = default; @@ -94,16 +94,16 @@ class VMFrame { // Representation of a closure. class Closure : public Base { public: - Closure(const FuncGraphPtr& func_graph, const AnfNodePtrToBaseRefMap& values); - BaseRef operator()(const VectorRef& args); + Closure(const FuncGraphPtr &func_graph, const AnfNodePtrToBaseRefMap &values); + BaseRef operator()(const VectorRef &args); - const VMPtr& vm() const { return vm_; } + const VMPtr &vm() const { return vm_; } - void set_vm(const VMPtr& vm) { vm_ = vm; } + void set_vm(const VMPtr &vm) { vm_ = vm; } - const FuncGraphPtr& func_graph() const { return func_graph_; } + const FuncGraphPtr &func_graph() const { return func_graph_; } - const AnfNodePtrToBaseRefMap& values() const { return values_; } + const AnfNodePtrToBaseRefMap &values() const { return values_; } virtual ~Closure() = default; @@ -118,11 +118,11 @@ class Closure : public Base { // Representation of a partial application. class Partial : public Base { public: - Partial(const BaseRef& fn, const VectorRef& args, const VMPtr& vm); - BaseRef operator()(const VectorRef& nodes); - const BaseRef& fn() const { return fn_; } + Partial(const BaseRef &fn, const VectorRef &args, const VMPtr &vm); + BaseRef operator()(const VectorRef &nodes); + const BaseRef &fn() const { return fn_; } - const VectorRef& args() const { return args_; } + const VectorRef &args() const { return args_; } virtual ~Partial() = default; MS_DECLARE_PARENT(Partial, Base) @@ -136,52 +136,52 @@ class Partial : public Base { // Virtual Machine interface. class VM : public std::enable_shared_from_this, public VMImpl { public: - SetRef ComputeFvs(const FuncGraphPtr& func_graph); + SetRef ComputeFvs(const FuncGraphPtr &func_graph); - void AcquireGraph(const FuncGraphPtr& func_graph); + void AcquireGraph(const FuncGraphPtr &func_graph); - VectorRef ExportSequence(const VectorRef& seq); + VectorRef ExportSequence(const VectorRef &seq); - BaseRef ExportPrimitive(const PrimitivePtr&) const { return kAnyValue; } + BaseRef ExportPrimitive(const PrimitivePtr &) const { return kAnyValue; } - ClosurePtr ExportClosure(const ClosurePtr& clos); + ClosurePtr ExportClosure(const ClosurePtr &clos); // Return an object that executes `fg` when called on arguments. - ClosurePtr ExportGraph(const FuncGraphPtr& fg); + ClosurePtr ExportGraph(const FuncGraphPtr &fg); - BaseRef ExportObj(const BaseRef& obj) const; + BaseRef ExportObj(const BaseRef &obj) const; - BaseRef Export(const BaseRef& value); + BaseRef Export(const BaseRef &value); // Run a graph. // This will evaluate the passed-in graph and return the // resulting value. - BaseRef Evaluate(const FuncGraphPtr& func_graph, const VectorRef& args, - const AnfNodePtrToBaseRefMap& closure = AnfNodePtrToBaseRefMap()); + BaseRef Evaluate(const FuncGraphPtr &func_graph, const VectorRef &args, + const AnfNodePtrToBaseRefMap &closure = AnfNodePtrToBaseRefMap()); // Return a visitor for the graph. - SuccFunc SuccVm(const FuncGraphPtr& func_graph); + SuccFunc SuccVm(const FuncGraphPtr &func_graph); // Call the `fn` object. // `fn` can be anything that would be valid as the first element of an apply. - BaseRef Call(const BaseRef& fn, const VectorRef& args); + BaseRef Call(const BaseRef &fn, const VectorRef &args); - BaseRef _Call(const BaseRef& graph, const VectorRef& args); + BaseRef _Call(const BaseRef &graph, const VectorRef &args); - ClosurePtr MakeClosure(const FuncGraphPtr& func_graph, const VMFramePtr& frame); + ClosurePtr MakeClosure(const FuncGraphPtr &func_graph, const VMFramePtr &frame); - BaseRef DispatchCall(const AnfNodePtr& node, const VMFramePtr& frame, const BaseRef& fn, const VectorRef& args); + BaseRef DispatchCall(const AnfNodePtr &node, const VMFramePtr &frame, const BaseRef &fn, const VectorRef &args); - BaseRef HandleNode(const AnfNodePtr& node, const VMFramePtr& frame); + BaseRef HandleNode(const AnfNodePtr &node, const VMFramePtr &frame); - VectorRef RunGraph(const FuncGraphPtr& fg, const VectorRef& args) override; + VectorRef RunGraph(const FuncGraphPtr &fg, const VectorRef &args) override; private: FuncGraphManagerPtr manager_; FuncGraphPtrToBaseRefMap vars_; }; -extern BaseRef RunOperation(const PrimitivePtr& prim, const VectorRef& args); +extern BaseRef RunOperation(const PrimitivePtr &prim, const VectorRef &args); } // namespace compile } // namespace mindspore From 108eeb8e3d3efdbaf03b3e0312d9d381cea1fa8c Mon Sep 17 00:00:00 2001 From: xiefangqi Date: Mon, 20 Apr 2020 12:49:05 +0800 Subject: [PATCH 065/142] fix voc test cases --- mindspore/dataset/engine/datasets.py | 38 ++++++++----------- .../dataset/engine/serializer_deserializer.py | 3 +- mindspore/dataset/engine/validators.py | 5 +-- 3 files changed, 19 insertions(+), 27 deletions(-) diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 8de56a6dff..5fb4b2537f 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -2837,14 +2837,17 @@ class VOCDataset(SourceDataset): decode (bool, optional): Decode the images after reading (default=False). sampler (Sampler, optional): Object used to choose samples from the dataset (default=None, expected order behavior shown in the table). - distribution (str, optional): Path to the json distribution file to configure - dataset sharding (default=None). This argument should be specified - only when no 'sampler' is used. + num_shards (int, optional): Number of shards that the dataset should be divided + into (default=None). + shard_id (int, optional): The shard ID within num_shards (default=None). This + argument should be specified only when num_shards is also specified. Raises: - RuntimeError: If distribution and sampler are specified at the same time. - RuntimeError: If distribution is failed to read. - RuntimeError: If shuffle and sampler are specified at the same time. + RuntimeError: If sampler and shuffle are specified at the same time. + RuntimeError: If sampler and sharding are specified at the same time. + RuntimeError: If num_shards is specified but shard_id is None. + RuntimeError: If shard_id is specified but num_shards is None. + ValueError: If shard_id is invalid (< 0 or >= num_shards). Examples: >>> import mindspore.dataset as ds @@ -2858,27 +2861,15 @@ class VOCDataset(SourceDataset): @check_vocdataset def __init__(self, dataset_dir, num_samples=None, num_parallel_workers=None, - shuffle=None, decode=False, sampler=None, distribution=None): + shuffle=None, decode=False, sampler=None, num_shards=None, shard_id=None): super().__init__(num_parallel_workers) self.dataset_dir = dataset_dir - self.sampler = sampler - if distribution is not None: - if sampler is not None: - raise RuntimeError("Cannot specify distribution and sampler at the same time.") - try: - with open(distribution, 'r') as load_d: - json.load(load_d) - except json.decoder.JSONDecodeError: - raise RuntimeError("Json decode error when load distribution file") - except Exception: - raise RuntimeError("Distribution file has failed to load.") - elif shuffle is not None: - if sampler is not None: - raise RuntimeError("Cannot specify shuffle and sampler at the same time.") + self.sampler = _select_sampler(num_samples, sampler, shuffle, num_shards, shard_id) self.num_samples = num_samples self.decode = decode - self.distribution = distribution self.shuffle_level = shuffle + self.num_shards = num_shards + self.shard_id = shard_id def get_args(self): args = super().get_args() @@ -2887,7 +2878,8 @@ class VOCDataset(SourceDataset): args["sampler"] = self.sampler args["decode"] = self.decode args["shuffle"] = self.shuffle_level - args["distribution"] = self.distribution + args["num_shards"] = self.num_shards + args["shard_id"] = self.shard_id return args def get_dataset_size(self): diff --git a/mindspore/dataset/engine/serializer_deserializer.py b/mindspore/dataset/engine/serializer_deserializer.py index 61417e4d52..f588d572bb 100644 --- a/mindspore/dataset/engine/serializer_deserializer.py +++ b/mindspore/dataset/engine/serializer_deserializer.py @@ -286,7 +286,8 @@ def create_node(node): elif dataset_op == 'VOCDataset': sampler = construct_sampler(node.get('sampler')) pyobj = pyclass(node['dataset_dir'], node.get('num_samples'), node.get('num_parallel_workers'), - node.get('shuffle'), node.get('decode'), sampler, node.get('distribution')) + node.get('shuffle'), node.get('decode'), sampler, node.get('num_shards'), + node.get('shard_id')) elif dataset_op == 'CelebADataset': sampler = construct_sampler(node.get('sampler')) diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index b74e913202..34981c5218 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -445,9 +445,8 @@ def check_vocdataset(method): def new_method(*args, **kwargs): param_dict = make_param_dict(method, args, kwargs) - nreq_param_int = ['num_samples', 'num_parallel_workers'] + nreq_param_int = ['num_samples', 'num_parallel_workers', 'num_shards', 'shard_id'] nreq_param_bool = ['shuffle', 'decode'] - nreq_param_str = ['distribution'] # check dataset_dir; required argument dataset_dir = param_dict.get('dataset_dir') @@ -459,7 +458,7 @@ def check_vocdataset(method): check_param_type(nreq_param_bool, param_dict, bool) - check_param_type(nreq_param_str, param_dict, str) + check_sampler_shuffle_shard_options(param_dict) return method(*args, **kwargs) From 422bc304df5b690e217b473b7d968f233290fc60 Mon Sep 17 00:00:00 2001 From: zhaojichen Date: Tue, 21 Apr 2020 09:32:37 -0400 Subject: [PATCH 066/142] add AvgPooling layer --- mindspore/nn/layer/pooling.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/mindspore/nn/layer/pooling.py b/mindspore/nn/layer/pooling.py index fef9494ea4..1628f8d1c3 100644 --- a/mindspore/nn/layer/pooling.py +++ b/mindspore/nn/layer/pooling.py @@ -14,6 +14,7 @@ # ============================================================================ """pooling""" from mindspore.ops import operations as P +from mindspore.ops import functional as F from mindspore._checkparam import Validator as validator from ... import context from ..cell import Cell @@ -272,6 +273,17 @@ class AvgPool1d(_PoolNd): self.avg_pool = P.AvgPool(ksize=self.kernel_size, strides=self.stride, padding=self.pad_mode) + self.shape = F.shape + self.reduce_mean = P.ReduceMean(keep_dims=True) + self.slice = P.Slice() def construct(self, x): - return self.avg_pool(x) + batch, channel, high, width = self.shape(x) + if width == self.kernel_size[1]: + x = self.reduce_mean(x, 3) + elif width - self.kernel_size[1] < self.stride[1]: + x = self.slice(x, (0, 0, 0, 0), (batch, channel, high, self.kernel_size[1])) + x = self.reduce_mean(x, 3) + else: + x = self.avg_pool(x) + return x From e1b6addefd6afd4c71fade2f13d8d8bbdf560af6 Mon Sep 17 00:00:00 2001 From: zhaojichen Date: Tue, 21 Apr 2020 09:48:35 -0400 Subject: [PATCH 067/142] add AvgPooling layer --- mindspore/nn/layer/pooling.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/mindspore/nn/layer/pooling.py b/mindspore/nn/layer/pooling.py index 1628f8d1c3..359e75a4c0 100644 --- a/mindspore/nn/layer/pooling.py +++ b/mindspore/nn/layer/pooling.py @@ -18,6 +18,7 @@ from mindspore.ops import functional as F from mindspore._checkparam import Validator as validator from ... import context from ..cell import Cell +from ..._checkparam import Rel class _PoolNd(Cell): @@ -263,10 +264,15 @@ class AvgPool1d(_PoolNd): stride=1, pad_mode="valid"): super(AvgPool1d, self).__init__(kernel_size, stride, pad_mode) + validator.check_type('kernel_size', kernel_size, [int,]) + validator.check_type('stride', stride, [int,]) + self.padding = validator.check_string('pad_mode', pad_mode.upper(), ['VALID', 'SAME']) if not isinstance(kernel_size, int): + validator.check_integer("kernel_size", kernel_size, 1, Rel.GE) raise ValueError("kernel_size should be 1 int number but got {}". format(kernel_size)) if not isinstance(stride, int): + validator.check_integer("stride", stride, 1, Rel.GE) raise ValueError("stride should be 1 int number but got {}".format(stride)) self.kernel_size = (1, kernel_size) self.stride = (1, stride) From 5d289ef029a40cdbb103e5ac6cffeec0d8d99bec Mon Sep 17 00:00:00 2001 From: zhaojichen Date: Tue, 21 Apr 2020 10:10:49 -0400 Subject: [PATCH 068/142] add AvgPooling layer --- mindspore/nn/layer/pooling.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/mindspore/nn/layer/pooling.py b/mindspore/nn/layer/pooling.py index 359e75a4c0..a19ef06b7e 100644 --- a/mindspore/nn/layer/pooling.py +++ b/mindspore/nn/layer/pooling.py @@ -19,6 +19,7 @@ from mindspore._checkparam import Validator as validator from ... import context from ..cell import Cell from ..._checkparam import Rel +from ..._checkparam import ParamValidator class _PoolNd(Cell): @@ -264,15 +265,15 @@ class AvgPool1d(_PoolNd): stride=1, pad_mode="valid"): super(AvgPool1d, self).__init__(kernel_size, stride, pad_mode) - validator.check_type('kernel_size', kernel_size, [int,]) - validator.check_type('stride', stride, [int,]) - self.padding = validator.check_string('pad_mode', pad_mode.upper(), ['VALID', 'SAME']) + ParamValidator.check_type('kernel_size', kernel_size, [int,]) + ParamValidator.check_type('stride', stride, [int,]) + self.pad_mode = ParamValidator.check_string('pad_mode', pad_mode.upper(), ['VALID', 'SAME']) if not isinstance(kernel_size, int): - validator.check_integer("kernel_size", kernel_size, 1, Rel.GE) + ParamValidator.check_integer("kernel_size", kernel_size, 1, Rel.GE) raise ValueError("kernel_size should be 1 int number but got {}". format(kernel_size)) if not isinstance(stride, int): - validator.check_integer("stride", stride, 1, Rel.GE) + ParamValidator.check_integer("stride", stride, 1, Rel.GE) raise ValueError("stride should be 1 int number but got {}".format(stride)) self.kernel_size = (1, kernel_size) self.stride = (1, stride) From 227da6e7205e0470ebf377d1b7403e8c94ea1f7a Mon Sep 17 00:00:00 2001 From: zhaojichen Date: Tue, 21 Apr 2020 10:27:26 -0400 Subject: [PATCH 069/142] add AvgPooling layer --- mindspore/nn/layer/pooling.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/mindspore/nn/layer/pooling.py b/mindspore/nn/layer/pooling.py index a19ef06b7e..28826c88bb 100644 --- a/mindspore/nn/layer/pooling.py +++ b/mindspore/nn/layer/pooling.py @@ -268,13 +268,8 @@ class AvgPool1d(_PoolNd): ParamValidator.check_type('kernel_size', kernel_size, [int,]) ParamValidator.check_type('stride', stride, [int,]) self.pad_mode = ParamValidator.check_string('pad_mode', pad_mode.upper(), ['VALID', 'SAME']) - if not isinstance(kernel_size, int): - ParamValidator.check_integer("kernel_size", kernel_size, 1, Rel.GE) - raise ValueError("kernel_size should be 1 int number but got {}". - format(kernel_size)) - if not isinstance(stride, int): - ParamValidator.check_integer("stride", stride, 1, Rel.GE) - raise ValueError("stride should be 1 int number but got {}".format(stride)) + ParamValidator.check_integer("kernel_size", kernel_size, 1, Rel.GE) + ParamValidator.check_integer("stride", stride, 1, Rel.GE) self.kernel_size = (1, kernel_size) self.stride = (1, stride) self.avg_pool = P.AvgPool(ksize=self.kernel_size, From b13e7bc31ab2945b90fd014f4c8d7ce46f210991 Mon Sep 17 00:00:00 2001 From: Junhan Hu Date: Mon, 20 Apr 2020 16:24:19 -0400 Subject: [PATCH 070/142] Add python multiprocessing support for Mindspore.dataset --- mindspore/dataset/engine/datasets.py | 97 +++++++++++++++++++++- mindspore/dataset/engine/iterators.py | 4 + tests/ut/python/dataset/test_pyfunc.py | 106 +++++++++++++++++++++++++ 3 files changed, 204 insertions(+), 3 deletions(-) diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 7c4857a580..13a643d3f3 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -24,6 +24,7 @@ import math import os import random import uuid +import multiprocessing from enum import Enum from importlib import import_module @@ -231,7 +232,7 @@ class Dataset: @check_map def map(self, input_columns=None, operations=None, output_columns=None, columns_order=None, - num_parallel_workers=None): + num_parallel_workers=None, python_multiprocessing=False): """ Applies each operation in operations to this dataset. @@ -270,6 +271,8 @@ class Dataset: same). num_parallel_workers (int, optional): Number of threads used to process the dataset in parallel (default=None, the value from the config will be used). + python_multiprocessing (bool, optional): Parallelize python operations with multiple worker process. This + option could be beneficial if the python operation is computational heavy (default=False). Returns: MapDataset, dataset after mapping operation. @@ -383,7 +386,8 @@ class Dataset: >>> columns_order = ["mod7", "mod3", "col1"] >>> ds_mapped = ds_pyfunc.map(input_columns, operations, output_columns, columns_order) """ - return MapDataset(self, input_columns, operations, output_columns, columns_order, num_parallel_workers) + return MapDataset(self, input_columns, operations, output_columns, columns_order, num_parallel_workers, + python_multiprocessing) @check_repeat def repeat(self, count=None): @@ -1041,6 +1045,55 @@ class ShuffleDataset(DatasetOp): return args +# Pyfunc collection for multiprocess pyfunc +# This global variable will only be used within subprocesses +_GLOBAL_PYFUNC_LIST = [] + + +# Pyfunc worker init function +# Python multiprocessing library forbid sending lambda function through pipe. +# This init function allow us to add all python function to a global collection and then fork afterwards. +def _pyfunc_worker_init(pyfunc_list): + global _GLOBAL_PYFUNC_LIST + _GLOBAL_PYFUNC_LIST = pyfunc_list + + +# Pyfunc worker execution function +# All exceptions will be raised to main processes +def _pyfunc_worker_exec(index, *args): + try: + return _GLOBAL_PYFUNC_LIST[index](*args) + except KeyboardInterrupt: + raise Exception("Multiprocess MapOp worker receives KeyboardInterrupt") + + +# PythonCallable wrapper for multiprocess pyfunc +class _PythonCallable: + """ + Internal python function wrapper for multiprocessing pyfunc + """ + def __init__(self, py_callable, idx, pool=None): + # Original python callable from user. + self.py_callable = py_callable + # Process pool created for current iterator. + self.pool = pool + # Python callable index for subprocess _GLOBAL_PYFUNC_LIST + self.idx = idx + + def __call__(self, *args): + if self.pool is not None: + try: + # This call will send the tensors along with Python callable index to the process pool. + # Block, yield GIL. Current thread will reacquire GIL once result is returned. + return self.pool.apply(_pyfunc_worker_exec, [self.idx, *args]) + except KeyboardInterrupt: + self.pool.terminate() + self.pool.join() + raise Exception("Multiprocess MapOp worker receives KeyboardInterrupt") + # Invoke original python callable in master process in case the pool is gone. + return self.py_callable(*args) + + class MapDataset(DatasetOp): """ The result of applying Map operator to the input Dataset. @@ -1060,13 +1113,15 @@ class MapDataset(DatasetOp): The argument is mandatory if len(input_columns) != len(output_columns). num_parallel_workers (int, optional): Number of workers to process the Dataset in parallel (default=None). + python_multiprocessing (bool, optional): Parallelize python operations with multiple worker process. This + option could be beneficial if the python operation is computational heavy (default=False). Raises: ValueError: If len(input_columns) != len(output_columns) and columns_order is not specified. """ def __init__(self, input_dataset, input_columns=None, operations=None, output_columns=None, columns_order=None, - num_parallel_workers=None): + num_parallel_workers=None, python_multiprocessing=False): super().__init__(num_parallel_workers) self.input.append(input_dataset) if input_columns is not None and not isinstance(input_columns, list): @@ -1087,6 +1142,8 @@ class MapDataset(DatasetOp): input_dataset.output.append(self) self._input_indexs = input_dataset.input_indexs + self.python_multiprocessing = python_multiprocessing + self.process_pool = None def get_args(self): args = super().get_args() @@ -1104,6 +1161,40 @@ class MapDataset(DatasetOp): """ return self.input[0].get_dataset_size() + # Iterator bootstrap will be called on iterator construction. + # A deep copy of Dataset object is created prior of iterator_bootstrap. + # This method will create per iterator process pool and bind pyfunc execution to the pool. + def iterator_bootstrap(self): + """ + Per iterator bootstrap callback. + """ + if self.python_multiprocessing: + iter_specific_operations = [] + callable_list = [] + + # Pass #1, look for python callables and build list + for op in self.operations: + if callable(op): + callable_list.append(op) + + if callable_list: + # Construct pool with the callable list + # The callable list and _pyfunc_worker_init are used to pass lambda function in to subprocesses + self.process_pool = multiprocessing.Pool(processes=self.num_parallel_workers, + initializer=_pyfunc_worker_init, + initargs=(callable_list,)) + # Pass #2 + idx = 0 + for op in self.operations: + if callable(op): + # Wrap python callable into _PythonCallable + iter_specific_operations.append(_PythonCallable(op, idx, self.process_pool)) + idx += 1 + else: + # CPP ops remain the same + iter_specific_operations.append(op) + self.operations = iter_specific_operations + class RepeatDataset(DatasetOp): """ diff --git a/mindspore/dataset/engine/iterators.py b/mindspore/dataset/engine/iterators.py index a74d69b9c7..d70805ecc7 100644 --- a/mindspore/dataset/engine/iterators.py +++ b/mindspore/dataset/engine/iterators.py @@ -63,6 +63,10 @@ def _alter_node(node): return new_shuffle if isinstance(node, de.MapDataset): + if node.python_multiprocessing: + # Bootstrap can only be performed on a copy of the original dataset node. + # Bootstrap on original dataset node will make all iterators share the same process pool + node.iterator_bootstrap() if node.columns_order is not None: # Remove the connection between the parent's node to the current node because we are inserting a node. if node.output: diff --git a/tests/ut/python/dataset/test_pyfunc.py b/tests/ut/python/dataset/test_pyfunc.py index 4b0672a1f2..e7bdc48639 100644 --- a/tests/ut/python/dataset/test_pyfunc.py +++ b/tests/ut/python/dataset/test_pyfunc.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================== import numpy as np +import pytest import mindspore.dataset as ds from mindspore import log as logger @@ -181,6 +182,106 @@ def test_case_6(): i = i + 4 +def test_case_7(): + """ + Test PyFunc + """ + logger.info("Test 1-1 PyFunc Multiprocess: lambda x : x + x") + + # apply dataset operations + data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False) + + data1 = data1.map(input_columns="col0", output_columns="out", operations=(lambda x: x + x), + num_parallel_workers=4, python_multiprocessing = True) + + i = 0 + for item in data1.create_dict_iterator(): # each data is a dictionary + # In this test, the dataset is 2x2 sequential tensors + golden = np.array([[i * 2, (i + 1) * 2], [(i + 2) * 2, (i + 3) * 2]]) + assert np.array_equal(item["out"], golden) + i = i + 4 + + +def test_case_8(): + """ + Test PyFunc + """ + logger.info("Test Multiprocess n-m PyFunc : lambda x, y : (x , x + 1, x + y)") + + col = ["col0", "col1"] + + # apply dataset operations + data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False) + + data1 = data1.map(input_columns=col, output_columns=["out0", "out1", "out2"], num_parallel_workers=4, + operations=(lambda x, y: (x, x + y, x + y + 1)), columns_order=["out0", "out1", "out2"], + python_multiprocessing=True) + + i = 0 + for item in data1.create_dict_iterator(): # each data is a dictionary + # In this test, the dataset is 2x2 sequential tensors + golden = np.array([[i, i + 1], [i + 2, i + 3]]) + assert np.array_equal(item["out0"], golden) + golden = np.array([[i * 2, (i + 1) * 2], [(i + 2) * 2, (i + 3) * 2]]) + assert np.array_equal(item["out1"], golden) + golden = np.array([[i * 2 + 1, (i + 1) * 2 + 1], [(i + 2) * 2 + 1, (i + 3) * 2 + 1]]) + assert np.array_equal(item["out2"], golden) + i = i + 4 + + +def test_case_9(): + """ + Test PyFunc + """ + logger.info("Test multiple 1-1 PyFunc Multiprocess: lambda x : x + x") + + # apply dataset operations + data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False) + + data1 = data1.map(input_columns="col0", output_columns="out", operations=[(lambda x: x + x), (lambda x: x + 1), + (lambda x: x + 2)], + num_parallel_workers=4, python_multiprocessing=True) + + i = 0 + for item in data1.create_dict_iterator(): # each data is a dictionary + # In this test, the dataset is 2x2 sequential tensors + golden = np.array([[i * 2 + 3, (i + 1) * 2 + 3], [(i + 2) * 2 + 3, (i + 3) * 2 + 3]]) + assert np.array_equal(item["out"], golden) + i = i + 4 + + +def test_pyfunc_execption(): + logger.info("Test PyFunc Execption Throw: lambda x : raise Execption()") + + def pyfunc(x): + raise Exception("Pyfunc Throw") + + with pytest.raises(RuntimeError) as info: + # apply dataset operations + data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False) + data1 = data1.map(input_columns="col0", output_columns="out", operations= pyfunc, + num_parallel_workers=4) + for _ in data1: + pass + assert "Pyfunc Throw" in str(info.value) + + +def test_pyfunc_execption_multiprocess(): + logger.info("Test Multiprocess PyFunc Execption Throw: lambda x : raise Execption()") + + def pyfunc(x): + raise Exception("MP Pyfunc Throw") + + with pytest.raises(RuntimeError) as info: + # apply dataset operations + data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False) + data1 = data1.map(input_columns="col0", output_columns="out", operations= pyfunc, + num_parallel_workers=4, python_multiprocessing = True) + for _ in data1: + pass + assert "MP Pyfunc Throw" in str(info.value) + + if __name__ == "__main__": test_case_0() test_case_1() @@ -189,3 +290,8 @@ if __name__ == "__main__": test_case_4() test_case_5() test_case_6() + test_case_7() + test_case_8() + test_case_9() + test_pyfunc_execption() + test_pyfunc_execption_multiprocess() From 78001ac9e6e6c4e2c8e89dd54ba24029703628e3 Mon Sep 17 00:00:00 2001 From: Junhan Hu Date: Mon, 20 Apr 2020 17:26:23 -0400 Subject: [PATCH 071/142] Add multiprocessing support for Mindspore.Dataset.GeneratorDataset --- mindspore/dataset/engine/datasets.py | 148 +++++++++++++++++++++- tests/ut/python/dataset/test_generator.py | 99 +++++++++++++++ 2 files changed, 245 insertions(+), 2 deletions(-) diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 3225ebc806..62c8e75ca9 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -25,6 +25,7 @@ import os import random import uuid import multiprocessing +import queue from enum import Enum from importlib import import_module @@ -2124,6 +2125,142 @@ def _cpp_sampler_fn(sampler, dataset): yield tuple([np.array(x) for x in val]) +def _cpp_sampler_fn_mp(sampler, dataset, num_worker): + """ + Multiprocessing generator function wrapper for mappable dataset with cpp sampler + """ + indices = sampler.get_indices() + return _sampler_fn_mp(indices, dataset, num_worker) + + +def _py_sampler_fn_mp(sampler, num_samples, dataset, num_worker): + """ + Multiprocessing generator function wrapper for mappable dataset with python sampler + """ + indices = _fetch_py_sampler_indices(sampler, num_samples) + return _sampler_fn_mp(indices, dataset, num_worker) + + +def _fetch_py_sampler_indices(sampler, num_samples): + """ + Indices fetcher for python sampler + """ + if num_samples is not None: + sampler_iter = iter(sampler) + ret = [] + for _ in range(num_samples): + try: + val = next(sampler_iter) + ret.append(val) + except StopIteration: + break + return ret + return [i for i in sampler] + + +def _fill_worker_indices(workers, indices, idx): + """ + Worker index queue filler, fill worker index queue in round robin order + """ + num_worker = len(workers) + while idx < len(indices): + try: + workers[idx % num_worker].put(indices[idx]) + idx += 1 + except queue.Full: + break + return idx + + +def _sampler_fn_mp(indices, dataset, num_worker): + """ + Multiprocessing generator function wrapper master process + """ + workers = [] + # Event for end of epoch + eoe = multiprocessing.Event() + + # Create workers + for _ in range(num_worker): + worker = _GeneratorWorker(dataset, eoe) + worker.daemon = True + workers.append(worker) + + # Fill initial index queues + idx_cursor = 0 + idx_cursor = _fill_worker_indices(workers, indices, idx_cursor) + + # Start all workers + for w in workers: + w.start() + + # Fetch results + for i in range(len(indices)): + # Fetch result and put index + try: + result = workers[i % num_worker].get() + except queue.Empty: + raise Exception("Generator worker process timeout") + except KeyboardInterrupt: + for w in workers: + w.terminate() + w.join() + raise Exception("Generator worker receives KeyboardInterrupt") + if idx_cursor < len(indices): + idx_cursor = _fill_worker_indices(workers, indices, idx_cursor) + # Set eoe event once all indices are sent + if idx_cursor == len(indices) and not eoe.is_set(): + eoe.set() + yield tuple([np.array(x) for x in result]) + + +def _generator_worker_loop(dataset, idx_queue, result_queue, eoe): + """ + Multiprocessing generator worker process loop + """ + while True: + # Fetch index, block + try: + idx = idx_queue.get() + except KeyboardInterrupt: + raise Exception("Generator worker receives KeyboardInterrupt") + if idx is None: + # When the queue is out of scope from master process, a None item can be fetched from the queue. + # Upon receiving None, worker process should check if EOE is set. + assert eoe.is_set(), "" + return + # Fetch data, any exception from __getitem__ will terminate worker and timeout master process + result = dataset[idx] + # Send data, block + try: + result_queue.put(result) + except KeyboardInterrupt: + raise Exception("Generator worker receives KeyboardInterrupt") + del result, idx + + +class _GeneratorWorker(multiprocessing.Process): + """ + Worker process for multiprocess Generator + """ + def __init__(self, dataset, eoe): + self.idx_queue = multiprocessing.Queue(16) + self.res_queue = multiprocessing.Queue(16) + super().__init__(target=_generator_worker_loop, args=(dataset, self.idx_queue, self.res_queue, eoe)) + + def put(self, item): + """ + Put function for worker index queue. Never block. Raise queue.Full on failure. + """ + self.idx_queue.put_nowait(item) + + def get(self): + """ + Get function for worker result queue. Block with timeout. + """ + return self.res_queue.get(timeout=5) + + class GeneratorDataset(SourceDataset): """ A source dataset that generate data from python by invoking python data source each epoch. @@ -2171,6 +2308,7 @@ class GeneratorDataset(SourceDataset): If the schema is not provided, the meta data from column_names and column_types is considered the schema. num_samples (int, optional): The number of samples to be included in the dataset (default=None, all images). + num_parallel_workers (int, optional): Number of subprocesses used to fetch the dataset in parallel (default=1). shuffle (bool, optional): Whether or not to perform shuffle on the dataset. Random accessible input is required. (default=None, expected order behavior shown in the table). sampler (Sampler/Iterable, optional): Object used to choose samples from the dataset. Random accessible input is @@ -2229,9 +2367,15 @@ class GeneratorDataset(SourceDataset): sampler_instance.set_num_rows(len(source)) sampler_instance.set_num_samples(num_samples) sampler_instance.initialize() - self.source = (lambda: _cpp_sampler_fn(sampler_instance, source)) + if num_parallel_workers > 1: + self.source = (lambda: _cpp_sampler_fn_mp(sampler_instance, source, num_parallel_workers)) + else: + self.source = (lambda: _cpp_sampler_fn(sampler_instance, source)) else: - self.source = (lambda: _py_sampler_fn(self.sampler, num_samples, source)) + if num_parallel_workers > 1: + self.source = (lambda: _py_sampler_fn_mp(self.sampler, num_samples, source, num_parallel_workers)) + else: + self.source = (lambda: _py_sampler_fn(self.sampler, num_samples, source)) else: try: iter(source) diff --git a/tests/ut/python/dataset/test_generator.py b/tests/ut/python/dataset/test_generator.py index c224c5a2ea..4daf952eba 100644 --- a/tests/ut/python/dataset/test_generator.py +++ b/tests/ut/python/dataset/test_generator.py @@ -391,6 +391,80 @@ def test_case_13(): i = i + 1 +def test_case_14(): + """ + Test 1D Generator MP + CPP sampler + """ + logger.info("Test 1D Generator MP : 0 - 63") + + source = [(np.array([x]),) for x in range(256)] + ds1 = ds.GeneratorDataset(source, ["data"], sampler=ds.SequentialSampler(), num_parallel_workers=4).repeat(2) + i = 0 + for data in ds1.create_dict_iterator(): # each data is a dictionary + golden = np.array([i]) + assert np.array_equal(data["data"], golden) + i = i + 1 + if i == 256: + i = 0 + + +def test_case_15(): + """ + Test 1D Generator MP + Python sampler + """ + logger.info("Test 1D Generator MP : 0 - 63") + + sampler = [x for x in range(256)] + source = [(np.array([x]),) for x in range(256)] + ds1 = ds.GeneratorDataset(source, ["data"], sampler=sampler, num_parallel_workers=4).repeat(2) + i = 0 + for data in ds1.create_dict_iterator(): # each data is a dictionary + golden = np.array([i]) + assert np.array_equal(data["data"], golden) + i = i + 1 + if i == 256: + i = 0 + + +def test_case_16(): + """ + Test multi column generator Mp + CPP sampler + """ + logger.info("Test multi column generator") + + source = [(np.array([x]), np.array([x + 1])) for x in range(256)] + # apply dataset operations + data1 = ds.GeneratorDataset(source, ["col0", "col1"], sampler=ds.SequentialSampler()) + + i = 0 + for item in data1.create_dict_iterator(): # each data is a dictionary + golden = np.array([i]) + assert np.array_equal(item["col0"], golden) + golden = np.array([i + 1]) + assert np.array_equal(item["col1"], golden) + i = i + 1 + + +def test_case_17(): + """ + Test multi column generator Mp + Python sampler + """ + logger.info("Test multi column generator") + + sampler = [x for x in range(256)] + source = [(np.array([x]), np.array([x + 1])) for x in range(256)] + # apply dataset operations + data1 = ds.GeneratorDataset(source, ["col0", "col1"], sampler=sampler) + + i = 0 + for item in data1.create_dict_iterator(): # each data is a dictionary + golden = np.array([i]) + assert np.array_equal(item["col0"], golden) + golden = np.array([i + 1]) + assert np.array_equal(item["col1"], golden) + i = i + 1 + + def test_case_error_1(): def generator_np(): for i in range(64): @@ -506,6 +580,25 @@ def test_num_samples_underflow(): count = count + 1 assert count == 64 +def manual_test_keyborad_interrupt(): + """ + Test keyborad_interrupt + """ + logger.info("Test 1D Generator MP : 0 - 63") + + class MyDS(): + def __getitem__(self, item): + while True: + pass + + def __len__(self): + return 1024 + + ds1 = ds.GeneratorDataset(MyDS(), ["data"], num_parallel_workers=4).repeat(2) + i = 0 + for data in ds1.create_dict_iterator(): # each data is a dictionary + pass + if __name__ == "__main__": test_case_0() @@ -522,6 +615,10 @@ if __name__ == "__main__": test_case_11() test_case_12() test_case_13() + test_case_14() + test_case_15() + test_case_16() + test_case_17() test_case_error_1() test_case_error_2() test_case_error_3() @@ -529,3 +626,5 @@ if __name__ == "__main__": test_sequential_sampler() test_distributed_sampler() test_random_sampler() + + From 5fcd3f01a63702005404416e71ec3709a5bfc4d0 Mon Sep 17 00:00:00 2001 From: Adel Shafiei Date: Mon, 20 Apr 2020 15:46:59 -0400 Subject: [PATCH 072/142] Added C++ UniformAugOp support --- .../ccsrc/dataset/api/python_bindings.cc | 5 ++ .../dataset/kernels/image/CMakeLists.txt | 2 + .../dataset/kernels/image/uniform_aug_op.cc | 87 +++++++++++++++++++ .../dataset/kernels/image/uniform_aug_op.h | 60 +++++++++++++ .../dataset/transforms/vision/c_transforms.py | 18 +++- .../dataset/transforms/vision/validators.py | 33 +++++++ 6 files changed, 204 insertions(+), 1 deletion(-) create mode 100644 mindspore/ccsrc/dataset/kernels/image/uniform_aug_op.cc create mode 100644 mindspore/ccsrc/dataset/kernels/image/uniform_aug_op.h diff --git a/mindspore/ccsrc/dataset/api/python_bindings.cc b/mindspore/ccsrc/dataset/api/python_bindings.cc index 214ce4c153..655fad7d55 100644 --- a/mindspore/ccsrc/dataset/api/python_bindings.cc +++ b/mindspore/ccsrc/dataset/api/python_bindings.cc @@ -40,6 +40,7 @@ #include "dataset/kernels/image/rescale_op.h" #include "dataset/kernels/image/resize_bilinear_op.h" #include "dataset/kernels/image/resize_op.h" +#include "dataset/kernels/image/uniform_aug_op.h" #include "dataset/kernels/data/type_cast_op.h" #include "dataset/engine/datasetops/source/cifar_op.h" #include "dataset/engine/datasetops/source/image_folder_op.h" @@ -264,6 +265,10 @@ void bindTensorOps1(py::module *m) { .def(py::init(), py::arg("targetHeight"), py::arg("targetWidth") = ResizeOp::kDefWidth, py::arg("interpolation") = ResizeOp::kDefInterpolation); + (void)py::class_>( + *m, "UniformAugOp", "Tensor operation to apply random augmentation(s).") + .def(py::init(), py::arg("operations"), py::arg("NumOps") = UniformAugOp::kDefNumOps); + (void)py::class_>( *m, "ResizeBilinearOp", "Tensor operation to resize an image using " diff --git a/mindspore/ccsrc/dataset/kernels/image/CMakeLists.txt b/mindspore/ccsrc/dataset/kernels/image/CMakeLists.txt index 23a26d5214..33e681337c 100644 --- a/mindspore/ccsrc/dataset/kernels/image/CMakeLists.txt +++ b/mindspore/ccsrc/dataset/kernels/image/CMakeLists.txt @@ -19,6 +19,7 @@ if (WIN32) rescale_op.cc resize_bilinear_op.cc resize_op.cc + uniform_aug_op.cc ) else() add_library(kernels-image OBJECT @@ -42,5 +43,6 @@ else() rescale_op.cc resize_bilinear_op.cc resize_op.cc + uniform_aug_op.cc ) endif() diff --git a/mindspore/ccsrc/dataset/kernels/image/uniform_aug_op.cc b/mindspore/ccsrc/dataset/kernels/image/uniform_aug_op.cc new file mode 100644 index 0000000000..5725c10908 --- /dev/null +++ b/mindspore/ccsrc/dataset/kernels/image/uniform_aug_op.cc @@ -0,0 +1,87 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ +#include "dataset/kernels/image/uniform_aug_op.h" +#include "dataset/kernels/py_func_op.h" +#include "dataset/util/random.h" + +namespace mindspore { +namespace dataset { +const int UniformAugOp::kDefNumOps = 2; + +UniformAugOp::UniformAugOp(py::list op_list, int32_t num_ops) : num_ops_(num_ops) { + std::shared_ptr tensor_op; + // iterate over the op list, cast them to TensorOp and add them to tensor_op_list_ + for (auto op : op_list) { + if (py::isinstance(op)) { + // python op + tensor_op = std::make_shared(op.cast()); + } else if (py::isinstance(op)) { + // C++ op + tensor_op = op.cast>(); + } + tensor_op_list_.insert(tensor_op_list_.begin(), tensor_op); + } + + rnd_.seed(GetSeed()); +} +// compute method to apply uniformly random selected augmentations from a list +Status UniformAugOp::Compute(const std::vector> &input, + std::vector> *output) { + IO_CHECK_VECTOR(input, output); + + // variables to generate random number to select ops from the list + std::vector random_indexes; + + // variables to copy the result to output if it is not already + std::vector> even_out; + std::vector> *even_out_ptr = &even_out; + int count = 1; + + // select random indexes for candidates to be applied + for (int i = 0; i < num_ops_; ++i) { + random_indexes.insert(random_indexes.end(), + std::uniform_int_distribution(0, tensor_op_list_.size() - 1)(rnd_)); + } + + for (auto it = random_indexes.begin(); it != random_indexes.end(); ++it) { + // Do NOT apply the op, if second random generator returned zero + if (std::uniform_int_distribution(0, 1)(rnd_)) { + continue; + } + std::shared_ptr tensor_op = tensor_op_list_[*it]; + + // apply python/C++ op + if (count == 1) { + (*tensor_op).Compute(input, output); + } else if (count % 2 == 0) { + (*tensor_op).Compute(*output, even_out_ptr); + } else { + (*tensor_op).Compute(even_out, output); + } + count++; + } + + // copy the result to output if it is not in output + if (count == 1) { + *output = input; + } else if ((count % 2 == 1)) { + (*output).swap(even_out); + } + + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/kernels/image/uniform_aug_op.h b/mindspore/ccsrc/dataset/kernels/image/uniform_aug_op.h new file mode 100644 index 0000000000..336bc8c859 --- /dev/null +++ b/mindspore/ccsrc/dataset/kernels/image/uniform_aug_op.h @@ -0,0 +1,60 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + + * http://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. +*/ +#ifndef DATASET_KERNELS_IMAGE_UNIFORM_AUG_OP_H_ +#define DATASET_KERNELS_IMAGE_UNIFORM_AUG_OP_H_ + +#include +#include +#include +#include + +#include "dataset/core/tensor.h" +#include "dataset/kernels/tensor_op.h" +#include "dataset/util/status.h" +#include "dataset/kernels/py_func_op.h" + +#include "pybind11/stl.h" + +namespace mindspore { +namespace dataset { +class UniformAugOp : public TensorOp { + public: + // Default number of Operations to be applied + static const int kDefNumOps; + + // Constructor for UniformAugOp + // @param list op_list: list of candidate python operations + // @param list num_ops: number of augemtation operations to applied + UniformAugOp(py::list op_list, int32_t num_ops); + + ~UniformAugOp() override = default; + + void Print(std::ostream &out) const override { out << "UniformAugOp:: number of ops " << num_ops_; } + + // Overrides the base class compute function + // @return Status - The error code return + Status Compute(const std::vector> &input, + std::vector> *output) override; + + private: + int32_t num_ops_; + std::vector> tensor_op_list_; + std::mt19937 rnd_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_KERNELS_IMAGE_UNIFORM_AUG_OP_H_ diff --git a/mindspore/dataset/transforms/vision/c_transforms.py b/mindspore/dataset/transforms/vision/c_transforms.py index 171eb846a8..07011b1d53 100644 --- a/mindspore/dataset/transforms/vision/c_transforms.py +++ b/mindspore/dataset/transforms/vision/c_transforms.py @@ -45,7 +45,7 @@ import mindspore._c_dataengine as cde from .utils import Inter, Border from .validators import check_prob, check_crop, check_resize_interpolation, check_random_resize_crop, \ check_normalize_c, check_random_crop, check_random_color_adjust, check_random_rotation, \ - check_resize, check_rescale, check_pad, check_cutout + check_resize, check_rescale, check_pad, check_cutout, check_uniform_augmentation DE_C_INTER_MODE = {Inter.NEAREST: cde.InterpolationMode.DE_INTER_NEAREST_NEIGHBOUR, Inter.LINEAR: cde.InterpolationMode.DE_INTER_LINEAR, @@ -447,3 +447,19 @@ class Pad(cde.PadOp): fill_value = tuple([fill_value] * 3) padding_mode = DE_C_BORDER_TYPE[padding_mode] super().__init__(*padding, padding_mode, *fill_value) + + +class UniformAugment(cde.UniformAugOp): + """ + Tensor operation to perform randomly selected augmentation + + Args: + operations: list of python operations. + NumOps (int): number of OPs to be selected and applied. + """ + + @check_uniform_augmentation + def __init__(self, operations, num_ops=2): + self.operations = operations + self.num_ops = num_ops + super().__init__(operations, num_ops) diff --git a/mindspore/dataset/transforms/vision/validators.py b/mindspore/dataset/transforms/vision/validators.py index ef4b879f8c..713d9c5714 100644 --- a/mindspore/dataset/transforms/vision/validators.py +++ b/mindspore/dataset/transforms/vision/validators.py @@ -812,3 +812,36 @@ def check_rescale(method): return method(self, **kwargs) return new_method + + +def check_uniform_augmentation(method): + """Wrapper method to check the parameters of UniformAugmentation.""" + + @wraps(method) + def new_method(self, *args, **kwargs): + operations, num_ops = (list(args) + 2 * [None])[:2] + if "operations" in kwargs: + operations = kwargs.get("operations") + else: + raise ValueError("operations list required") + if "num_ops" in kwargs: + num_ops = kwargs.get("num_ops") + else: + num_ops = 2 + + if num_ops <= 0: + raise ValueError("num_ops should be greater than zero") + if num_ops > len(operations): + raise ValueError("num_ops is greater than operations list size") + if not isinstance(operations, list): + raise ValueError("operations is not a python list") + for op in operations: + if not callable(op): + raise ValueError("non-callable op in operations list") + + kwargs["num_ops"] = num_ops + kwargs["operations"] = operations + + return method(self, **kwargs) + + return new_method From cd94518769ba56638a02ccfa8fe1a56ae399dd8d Mon Sep 17 00:00:00 2001 From: eric Date: Thu, 2 Apr 2020 18:58:41 -0400 Subject: [PATCH 073/142] X# This is a combination of 2 commits. Initial commit for dataset op python Added signature to barrier Adde compiling barrier code Rebasing, fixed new compile errors Final fix for make_unique Added pybind API for barrier Fixed pyfunc invocation python interface - sync_wait !1 sync_wait python interface * python interface - sync_wait fix test update test update test Added new test case add test case test for shuffle + batch Added two-sync test case Restrited that no shuffle after sync Added sync to pipeline info block first databuffer as well Intelligently get batch size Fix default case Lock Pair shares among all iterators Added fix for empty character Fixed up test case formatting Fix end of epoch in sync_wait Fixing CI --- mindspore/ccsrc/dataset/api/de_pipeline.cc | 25 ++ mindspore/ccsrc/dataset/api/de_pipeline.h | 3 + .../ccsrc/dataset/api/python_bindings.cc | 1 + mindspore/ccsrc/dataset/core/client.h | 1 + .../dataset/engine/datasetops/CMakeLists.txt | 1 + .../dataset/engine/datasetops/barrier_op.cc | 235 ++++++++++++++++++ .../dataset/engine/datasetops/barrier_op.h | 172 +++++++++++++ .../ccsrc/dataset/engine/datasetops/zip_op.h | 18 +- mindspore/dataset/engine/datasets.py | 184 +++++++++++++- mindspore/dataset/engine/iterators.py | 2 + mindspore/dataset/engine/validators.py | 16 ++ tests/ut/python/dataset/test_config.py | 38 +++ tests/ut/python/dataset/test_sync_wait.py | 182 ++++++++++++++ 13 files changed, 868 insertions(+), 10 deletions(-) create mode 100644 mindspore/ccsrc/dataset/engine/datasetops/barrier_op.cc create mode 100644 mindspore/ccsrc/dataset/engine/datasetops/barrier_op.h create mode 100644 tests/ut/python/dataset/test_sync_wait.py diff --git a/mindspore/ccsrc/dataset/api/de_pipeline.cc b/mindspore/ccsrc/dataset/api/de_pipeline.cc index a02d995147..c3dfeafe48 100644 --- a/mindspore/ccsrc/dataset/api/de_pipeline.cc +++ b/mindspore/ccsrc/dataset/api/de_pipeline.cc @@ -48,6 +48,7 @@ static std::unordered_map g_parse_op_func_ = {{kStorage, &D {kMap, &DEPipeline::ParseMapOp}, {kFilter, &DEPipeline::ParseFilterOp}, {kBatch, &DEPipeline::ParseBatchOp}, + {kBarrier, &DEPipeline::ParseBarrierOp}, {kRepeat, &DEPipeline::ParseRepeatOp}, {kSkip, &DEPipeline::ParseSkipOp}, {kZip, &DEPipeline::ParseZipOp}, @@ -627,6 +628,30 @@ Status DEPipeline::ParseBatchOp(const py::dict &args, std::shared_ptr return Status::OK(); } +Status DEPipeline::ParseBarrierOp(const py::dict &args, std::shared_ptr *ptr) { + std::shared_ptr builder = std::make_shared(); + // Right now barrier should only take num_rows_per_buffer = 1 + // The reason for this is because having it otherwise can lead to blocking issues + // See barrier_op.h for more details + (void)builder->SetRowsPerBuffer(1); + for (auto arg : args) { + std::string key = py::str(arg.first); + py::handle value = arg.second; + if (!value.is_none()) { + if (key == "condition_name") { + (void)builder->SetConditionName(ToString(value)); + } else if (key == "condition_func") { + (void)builder->SetConditionFunc(value.cast()); + } + } + } + + std::shared_ptr op; + RETURN_IF_NOT_OK(builder->Build(&op)); + *ptr = op; + return Status::OK(); +} + Status DEPipeline::ParseDeviceQueueOp(const py::dict &args, std::shared_ptr *ptr) { int32_t prefetch_size = 0; if (args.contains("prefetch_size")) { diff --git a/mindspore/ccsrc/dataset/api/de_pipeline.h b/mindspore/ccsrc/dataset/api/de_pipeline.h index 25919afe58..7f9c6c459a 100644 --- a/mindspore/ccsrc/dataset/api/de_pipeline.h +++ b/mindspore/ccsrc/dataset/api/de_pipeline.h @@ -40,6 +40,7 @@ enum OpName { kShuffle, kMindrecord, kBatch, + kBarrier, kCache, kRepeat, kSkip, @@ -115,6 +116,8 @@ class DEPipeline { Status ParseBatchOp(const py::dict &args, std::shared_ptr *ptr); + Status ParseBarrierOp(const py::dict &args, std::shared_ptr *ptr); + Status ParseGeneratorOp(const py::dict &args, std::shared_ptr *ptr); Status ParseRenameOp(const py::dict &args, std::shared_ptr *ptr); diff --git a/mindspore/ccsrc/dataset/api/python_bindings.cc b/mindspore/ccsrc/dataset/api/python_bindings.cc index 9865396a7d..2b8ce4e896 100644 --- a/mindspore/ccsrc/dataset/api/python_bindings.cc +++ b/mindspore/ccsrc/dataset/api/python_bindings.cc @@ -476,6 +476,7 @@ PYBIND11_MODULE(_c_dataengine, m) { .value("STORAGE", OpName::kStorage) .value("SHUFFLE", OpName::kShuffle) .value("BATCH", OpName::kBatch) + .value("BARRIER", OpName::kBarrier) .value("MINDRECORD", OpName::kMindrecord) .value("CACHE", OpName::kCache) .value("REPEAT", OpName::kRepeat) diff --git a/mindspore/ccsrc/dataset/core/client.h b/mindspore/ccsrc/dataset/core/client.h index 15064dee6b..40de887aea 100644 --- a/mindspore/ccsrc/dataset/core/client.h +++ b/mindspore/ccsrc/dataset/core/client.h @@ -25,6 +25,7 @@ #include "dataset/core/tensor_shape.h" #include "dataset/engine/data_schema.h" #include "dataset/engine/dataset_iterator.h" +#include "dataset/engine/datasetops/barrier_op.h" #include "dataset/engine/datasetops/batch_op.h" #include "dataset/engine/datasetops/dataset_op.h" #include "dataset/engine/datasetops/device_queue_op.h" diff --git a/mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt b/mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt index 7de62d9d11..9e8272d513 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt +++ b/mindspore/ccsrc/dataset/engine/datasetops/CMakeLists.txt @@ -4,6 +4,7 @@ add_library(engine-datasetops OBJECT dataset_op.cc parallel_op.cc pipeline_op.cc + barrier_op.cc batch_op.cc device_queue_op.cc map_op.cc diff --git a/mindspore/ccsrc/dataset/engine/datasetops/barrier_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/barrier_op.cc new file mode 100644 index 0000000000..b0ea7dbd07 --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/datasetops/barrier_op.cc @@ -0,0 +1,235 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "dataset/engine/datasetops/barrier_op.h" +#include +#include "dataset/core/constants.h" +#include "dataset/engine/data_buffer.h" +#include "dataset/engine/db_connector.h" +#include "dataset/core/config_manager.h" +#include "dataset/core/global_context.h" +#include "utils/log_adapter.h" + +namespace mindspore { +namespace dataset { +BarrierOp::Builder::Builder() { + // Some arguments to the BarrierOp constructor have a default argument that is taken + // from the client config. + // The user may choose to change these values for the construction of the BarrierOp by + // using the various builder set methods. + + std::shared_ptr cfg = GlobalContext::config_manager(); + builder_rows_per_buffer_ = cfg->rows_per_buffer(); + builder_op_connector_size_ = cfg->op_connector_size(); +} + +Status BarrierOp::Builder::SanityCheck() const { return Status::OK(); } + +Status BarrierOp::Builder::Build(std::shared_ptr *ptr) { + RETURN_IF_NOT_OK(SanityCheck()); + *ptr = std::make_shared(builder_rows_per_buffer_, builder_op_connector_size_, builder_condition_name_, + builder_condition_func_); + return Status::OK(); +} + +// Construct BarrierOp here, local variables initialized in operator due to tree construction restrictions +BarrierOp::BarrierOp(int32_t rows_per_buffer, int32_t op_connector_size, const std::string &condition_name, + py::function condition_func) + : PipelineOp(op_connector_size), + rows_per_buffer_(rows_per_buffer), + buffer_id_(0), + clean_up_(false), + eof_(false), + condition_name_(condition_name), + condition_function_(condition_func) {} + +// destructor +BarrierOp::~BarrierOp() {} + +// Entry point for Barrier, called by launch() +Status BarrierOp::operator()() { + // The children_num_ parameter needs to be put here + // Synchronize with TaskManager once the thread is created. + TaskManager::FindMe()->Post(); + + // create child iterator, right now this barrier is a pipeline operator + int32_t worker_id = 0; + int32_t child_idx = 0; + child_iterator_ = std::make_unique(this, worker_id, child_idx); + + // Loop until eof is true + while (!eof_) { + // Create new table to put the new tensor rows + std::unique_ptr curr_table = std::make_unique(); + RETURN_IF_NOT_OK(prepare(curr_table.get())); + + // If an eof got picked up during the above prepare, then we're done + if (eof_) { + break; + } + + // we have to output new buffer with possibly different buffer size, possibly one row + while (!clean_up_) { + // 1. If a previous loop iteration sent the current table out, then create a new one. + + if (curr_table == nullptr) { + curr_table = std::make_unique(); + } + + // 2 fill the table. Note: clean_up mode might get turned on if epoch is finished + RETURN_IF_NOT_OK(fillBuffer(curr_table.get())); + + // 3 create and update buffer and send it to the out connector + if (!curr_table->empty()) { + std::unique_ptr curr_buffer = std::make_unique(buffer_id_, DataBuffer::kDeBFlagNone); + curr_buffer->set_tensor_table(std::move(curr_table)); + curr_buffer->set_column_name_map(col_name_id_map_); + MS_LOG(DEBUG) << "Barrier operator finished one buffer, pushing, rows " << curr_buffer->NumRows() << ", cols " + << curr_buffer->NumCols() << ", map " << col_name_id_map_.size() << "."; + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(curr_buffer))); + buffer_id_++; + } + } + + // 4 handle drain state. + if (clean_up_) { + MS_LOG(DEBUG) << "Barrier operator sending epoch ending signal."; + // Send the eoe up. + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique(0, DataBuffer::kDeBFlagEOE)))); + } + } + // 5 handle eof + // propagate eof here. + MS_LOG(INFO) << "Barrier operator got EOF, propagating."; + RETURN_IF_NOT_OK(out_connector_->Add(0, std::move(std::make_unique(0, DataBuffer::kDeBFlagEOF)))); + return Status::OK(); +} + +// Handles preprocessing of the main loop, used when starting new epoch +Status BarrierOp::prepare(TensorQTable *const table) { + MS_LOG(DEBUG) << "Barrier operator prepares for new epoch."; + clean_up_ = false; + buffer_id_ = 0; + if (table == nullptr) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "BarrierOp prepare phase requires a tensor table."); + } + // fill initial row + TensorRow new_row = {}; + // use iterator to get next row and invoke pyfunc wait + RETURN_IF_NOT_OK(getNextTensorRow(&new_row)); + + // If the first row fetching resulted in eof, then we are done. + if (eof_) { + return Status::OK(); + } + if (new_row.empty()) { + // This epoch is empty + return Status::OK(); + } + // Pack this first row into our tensor table + // first row we also have to check if we should block + RETURN_IF_NOT_OK(blockCond()); + + table->push_back(std::move(new_row)); + // At this point we have 1 row produced, we take the old column map id and use it in the new table + // Initializing col_name_id_map_ from the first data buffer. + col_name_id_map_ = child_iterator_->col_name_id_map(); + // the update code below shouldn't do anything bad if the column name already exists. + return Status::OK(); +} + +// fillBuffer always expects a new table to fill +Status BarrierOp::fillBuffer(TensorQTable *const table) { + if (table == nullptr) { + return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "BarrierOp fillBuffer null table pointer."); + } + TensorRow new_row = {}; + while (table->size() < static_cast(rows_per_buffer_)) { + RETURN_IF_NOT_OK(getNextTensorRow(&new_row)); + // Early exit the loop if we got empty row from any of our child iterations + if (new_row.empty()) { + return Status::OK(); + } + // else we got a row so pack it into the tensor table. + RETURN_IF_NOT_OK(blockCond()); + + table->push_back(std::move(new_row)); + } + return Status::OK(); +} + +// function executes a py_func and blocks until condition becomes true. +Status BarrierOp::blockCond() { + { + py::gil_scoped_acquire gil_acquire; + if (Py_IsInitialized() == 0) { + return Status(StatusCode::kPythonInterpreterFailure, "Python Interpreter is finalized"); + } + // we have condition name, however the flexibility is in python today + try { + // Invoke python function + py::object ret_py_obj = condition_function_(); + // Process the return value + if (!py::isinstance(ret_py_obj)) { + return Status(StatusCode::kPyFuncException, "Condition wait function should return true/false"); + } + } catch (const py::error_already_set &e) { + return Status(StatusCode::kPyFuncException, e.what()); + } + } + return Status::OK(); +} + +// fetches next Barrier buffer row +Status BarrierOp::getNextTensorRow(TensorRow *new_row) { + // iterate over all iterators and generate a row + RETURN_IF_NOT_OK((child_iterator_)->FetchNextTensorRow(new_row)); + // add each new row to iterator, check if row is empty, if row from iterator is empty return empty row + if (new_row->empty()) { + // If we did not get a row from any of the children, then it's the end of an epoch and we can move + // to drain state. + MS_LOG(INFO) << "Barrier operator child iterator produced empty row."; + clean_up_ = true; + // If we picked up an eof here, then we are completely done. + if ((child_iterator_)->eof_handled()) { + MS_LOG(INFO) << "Barrier operator iterator got EOF."; + eof_ = true; + } + return Status::OK(); + } + return Status::OK(); +} + +// A function that prints info about the Operator +void BarrierOp::Print(std::ostream &out, bool show_all) const { + // Call base class printer first + PipelineOp::Print(out, show_all); + out << "\nBarrierOp:\n" + << "\nCondition " << condition_name_ << "\n\n"; +} + +// overwrite function and handle eof +Status BarrierOp::EofReceived(int32_t) { + MS_LOG(DEBUG) << "Barrier operator EOF received, do nothing now."; + return Status::OK(); +} + +// overwrite function and handle eoe +Status BarrierOp::EoeReceived(int32_t) { + state_ = OpState::kDeOpIdle; + return Status::OK(); +} +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/engine/datasetops/barrier_op.h b/mindspore/ccsrc/dataset/engine/datasetops/barrier_op.h new file mode 100644 index 0000000000..8be55fba7e --- /dev/null +++ b/mindspore/ccsrc/dataset/engine/datasetops/barrier_op.h @@ -0,0 +1,172 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef DATASET_ENGINE_DATASETOPS_BARRIER_OP_H_ +#define DATASET_ENGINE_DATASETOPS_BARRIER_OP_H_ + +#include +#include +#include +#include +#include +#include "dataset/core/tensor.h" +#include "dataset/engine/dataset_iterator.h" +#include "dataset/engine/datasetops/pipeline_op.h" +#include "dataset/kernels/tensor_op.h" + +namespace mindspore { +namespace dataset { +// Forward declare +class DataBuffer; +class ExecutionTree; + +// BarrierOp class implements the Barrier operator. It will block sending of rows until a signal has +// been received. This signal is given from python layer. The current barrier design respects the +// rows per buffer design and will only output a buffer with rows once it has received rows per buffer +// signals from python. + +class BarrierOp : public PipelineOp { + public: + // The nested builder class inside of the BarrierOp is used to help manage all of + // the arguments for constructing it. Use the builder by setting each argument + // with the provided set methods, and then finally call the build method to execute + // the actual construction. + + class Builder { + public: + // Builder constructor. Creates the builder object. + // @note No default args + // @return This is a constructor. + Builder(); + + // Default destructor + ~Builder() = default; + + // Setter method. + // @return Builder setter method returns reference to the builder. + Builder &SetRowsPerBuffer(int32_t rows_per_buffer) { + builder_rows_per_buffer_ = rows_per_buffer; + return *this; + } + + // Setter method. + // @param int32_t op_connector_size + // @return Builder setter method returns reference to the builder. + Builder &SetOpConnectorSize(int32_t op_connector_size) { + builder_op_connector_size_ = op_connector_size; + return *this; + } + + // Setter method. + // @param const std::string & condition_name + // @return Builder setter method returns reference to the builder. + Builder &SetConditionName(const std::string &condition_name) { + builder_condition_name_ = condition_name; + return *this; + } + + // Setter method. + // @param py::function condition_func - blocking condition function + // @return Builder setter method returns reference to the builder. + Builder &SetConditionFunc(py::function condition_func) { + builder_condition_func_ = condition_func; + return *this; + } + + // The builder "build" method creates the BarrierOp dataset Operator. + // @return shared_ptr to the new BarrierOp object + Status Build(std::shared_ptr *); + + private: + int32_t builder_rows_per_buffer_; + int32_t builder_op_connector_size_; + std::string builder_condition_name_; + py::function builder_condition_func_; + + Status SanityCheck() const; + }; + + // Constructor for BarrierOp + // @param rows_per_buffer - number of rows in output buffer + // @param op_connector_size - connector size + // @param condition_name - the condition name associated with this operator + // @param condition_func - the blocking condition check per row + // @note - currently rows_per_buffer should = 1 for barrier. + // The reason for this is having other values would complicate how the pipeline behaves with other operators + // One example of such case is having batch after barrier. Batch would be waiting for data and having + // rows per buffer in this case can result in hanging + BarrierOp(int32_t rows_per_buffer, int32_t op_connector_size, const std::string &condition_name, + py::function condition_func); + + // Destructor + ~BarrierOp(); + + Status EofReceived(int32_t) override; + + Status EoeReceived(int32_t) override; + + // Print function for Barrier + // @param out - output stream to print to + // @param show_all - if it should print everything + void Print(std::ostream &out, bool show_all) const override; + + // Provide stream operator for displaying it + friend std::ostream &operator<<(std::ostream &out, const BarrierOp &bo) { + bo.Print(out, false); + return out; + } + + // Class functor operator () override. + // All dataset ops operate by launching a thread (see ExecutionTree). This class functor will + // provide the master loop that drives the logic for performing the work + // @return Status - The error code return + Status operator()() override; + + // Handles preprocessing of the main loop, used when starting new epoch + // @param table - a table of tensors to be moved into a buffer + Status prepare(TensorQTable *const table); + + // This function calls takes a table repeatedly adds rows to it. + // @param table - a table of tensors to be moved into a buffer + Status fillBuffer(TensorQTable *const table); + + // Gets next tensor row and sets control signals + Status getNextTensorRow(TensorRow *new_row); + + // This function runs the wait function on condition + Status blockCond(); + + private: + // clean up variable to return imcomplete buffer + bool clean_up_; + // end of file state, we stop reading data and shut down + bool eof_; + // rows per buffer + int32_t rows_per_buffer_; + // buffer_id + int32_t buffer_id_; + // local variable to keep track of the buffer information + std::unordered_map col_name_id_map_; + // iterator to pull new rows, we only have one child + std::unique_ptr child_iterator_; + // condition name, to support multiple barriers + std::string condition_name_; + // Function pointer of blocking function + py::function condition_function_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // DATASET_ENGINE_DATASETOPS_BARRIER_OP_H_ diff --git a/mindspore/ccsrc/dataset/engine/datasetops/zip_op.h b/mindspore/ccsrc/dataset/engine/datasetops/zip_op.h index f14ecba733..04d8ab0121 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/zip_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/zip_op.h @@ -34,7 +34,7 @@ class DataBuffer; class ZipOp : public PipelineOp { public: - // The nested builder class inside of the BatchOp is used to help manage all of + // The nested builder class inside of the ZipOp is used to help manage all of // the arguments for constructing it. Use the builder by setting each argument // with the provided set methods, and then finally call the build method to execute // the actual construction. @@ -76,8 +76,8 @@ class ZipOp : public PipelineOp { }; // Constructor for ZipOp - // @param rows_per_buffer number of rows in output buffer - // @param op_connector_size connector + // @param rows_per_buffer - number of rows in output buffer + // @param op_connector_size - connector size ZipOp(int32_t rows_per_buffer, int32_t op_connector_size); // Destructor @@ -88,8 +88,8 @@ class ZipOp : public PipelineOp { Status EoeReceived(int32_t) override; // Print function for Zip - // @param out output stream to print to - // @param show_all if it should print everything + // @param out - output stream to print to + // @param show_all - if it should print everything void Print(std::ostream &out, bool show_all) const override; // Provide stream operator for displaying it @@ -113,14 +113,14 @@ class ZipOp : public PipelineOp { Status fillBuffer(TensorQTable *const table); // Special handle case where an empty row has been received from child iterator - // @note we need to drain eoe signals from all children connectors. - // @details when this function is called, then we encountered eoe at child iterator + // @note - we need to drain eoe signals from all children connectors. + // @details - when this function is called, then we encountered eoe at child iterator // we have to drain rows from other child iterators until we hit eoe from all other child iterators Status drainPipeline(); // Merges 1 row from each childIterator together - // @param new_zip_row input and output, will return a non-empty row if all rows from childConnectors are non-empty - // @param updateColumnMapping generates a new column name to index mapping (mColNameIdMap) if set to true + // @param new_zip_row - input and output, will be a non-empty row if all rows from childConnectors are non-empty + // @param updateColumnMapping - generates a new column name to index mapping (mColNameIdMap) if set to true // @details merge rows from iterator together. This is the main functionality for ZipOp // this function takes one row and fills it with tensors from rows fetched // from childIterators. diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 855e4609bb..f67461eee3 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -26,6 +26,7 @@ import random import uuid from enum import Enum from importlib import import_module +import threading import numpy as np from mindspore._c_dataengine import DataType, TFReaderOp, ImageFolderOp, CifarOp, MnistOp, ManifestOp, \ @@ -38,7 +39,7 @@ from .iterators import DictIterator, TupleIterator from .validators import check, check_batch, check_shuffle, check_map, check_filter, check_repeat, check_skip, check_zip, check_rename, \ check_take, check_project, check_imagefolderdatasetv2, check_mnist_cifar_dataset, check_manifestdataset, \ check_tfrecorddataset, check_vocdataset, check_celebadataset, check_minddataset, check_generatordataset, \ - check_zip_dataset, check_add_column, check_textfiledataset + check_sync_wait, check_zip_dataset, check_add_column, check_textfiledataset from ..core.datatypes import mstype_to_detype, mstypelist_to_detypelist try: @@ -139,6 +140,7 @@ class Dataset: self._batch_size = None self._num_classes = None self._repeat_count = None + self._sync = False def get_args(self): """ @@ -196,6 +198,30 @@ class Dataset: """ return BatchDataset(self, batch_size, drop_remainder, num_parallel_workers, per_batch_map, input_columns) + @check_sync_wait + def sync_wait(self, condition_name, num_batch=1, callback=None): + ''' + Add a blocking condition to the input Dataset + + Args: + input_dataset (Dataset): Input dataset to apply flow control + num_batch (int): the number of batches without blocking at the start of each epoch + condition_name (str): The condition name that is used to toggle sending next row + callback (function): The callback funciton that will be invoked when sync_update is called + + Raises: + RuntimeError: If condition name already exists. + + Examples: + >>> import mindspore.dataset as ds + >>> # data is an instance of Dataset object. + >>> data = data.sync_wait("callback1") + >>> data = data.batch(batch_size) + >>> for batch_data in data.create_dict_iterator(): + >>> data = data.sync_update("callback1") + ''' + return SyncWaitDataset(self, condition_name, num_batch, callback) + @check_shuffle def shuffle(self, buffer_size): """ @@ -218,6 +244,9 @@ class Dataset: Returns: ShuffleDataset, dataset shuffled. + Raises: + RuntimeError: If exist sync operators before shuffle. + Examples: >>> import mindspore.dataset as ds >>> # data is an instance of Dataset object @@ -816,6 +845,9 @@ class Dataset: self._input_indexs = value def _get_pipeline_info(self): + """ + Gets pipeline information. + """ device_iter = TupleIterator(self) self._output_shapes = device_iter.get_output_shapes() self._output_types = device_iter.get_output_types() @@ -870,6 +902,30 @@ class Dataset: return self.input[0].num_classes() return None + def get_sync_notifiers(self): + if self.input: + return self.input[0].get_sync_notifiers() + return {} + + def is_sync(self): + if self.input: + return self.input[0].is_sync() + return False + + def sync_update(self, condition_name, num_batch=None, data=None): + """ + condition_name (str): The condition name that is used to toggle sending next row + step_size (int or None): The number of steps(rows) that are released + when pass_rows is None, will update the same number as sync_wait specified + data (dict or None): The data passed to the callback + """ + notifiers_dict = self.get_sync_notifiers() + if condition_name not in notifiers_dict: + raise RuntimeError("Condition name not found") + if num_batch is not None: + num_batch *= self.get_batch_size() + notifiers_dict[condition_name](num_batch, data) + def get_batch_size(self): """ Get the size of a batch. @@ -973,6 +1029,8 @@ class BatchDataset(DatasetOp): if BatchDataset._is_ancestor_of_repeat(input_dataset): logger.warning("Repeat is located before batch, data from two epochs can be batched together.") + BatchDataset._update_batch_size_for_syncwait(input_dataset, batch_size) + self.batch_size = batch_size self.drop_remainder = drop_remainder self.per_batch_map = per_batch_map @@ -1029,6 +1087,20 @@ class BatchDataset(DatasetOp): flag = flag | BatchDataset._is_ancestor_of_repeat(input_dataset) return flag + @staticmethod + def _update_batch_size_for_syncwait(dataset, batch_size): + """ + Utility function to notify batch size to sync_wait. + + Args: + dataset (Dataset): dataset to be checked + batchsize (int): batch size to notify + """ + if isinstance(dataset, SyncWaitDataset): + dataset.update_sync_batch_size(batch_size) + for input_dataset in dataset.input: + BatchDataset._update_batch_size_for_syncwait(input_dataset, batch_size) + class BatchInfo(CBatchInfo): """ @@ -1053,6 +1125,108 @@ class BatchInfo(CBatchInfo): """ return +class BlockReleasePair: + """ + The blocking condition class used by SyncWaitDataset + + Args: + init_release_rows (int): Number of lines to allow through the pipeline + callback (function): The callback funciton that will be called when release is called + """ + def __init__(self, init_release_rows, callback=None): + self.row_count = -init_release_rows + self.cv = threading.Condition() + self.callback = callback + self.default_rows = init_release_rows + + def __deepcopy__(self, memodict): + if id(self) in memodict: + return memodict[id(self)] + memodict[id(self)] = self + # condition variable and callback are the same, but reset the counter + self.reset() + return self + + def reset(self): + with self.cv: + self.row_count = -self.default_rows + self.cv.notify_all() + + def update_batched_size(self, batch_size): + # should only use before the pipeline creates + self.row_count *= batch_size + self.default_rows *= batch_size + + def block_func(self): + with self.cv: + self.cv.wait_for(lambda: self.row_count < 0) + self.row_count += 1 + return True + + def release_func(self, pass_rows=None, data=None): + with self.cv: + if pass_rows is None: + pass_rows = self.default_rows + self.row_count -= pass_rows + if self.callback is not None: + self.callback(data) + self.cv.notify_all() + +class SyncWaitDataset(DatasetOp): + """ + The result of adding a blocking condition to the input Dataset + + Args: + input_dataset (Dataset): Input dataset to apply flow control + num_batch (int): the number of batches without blocking at the start of each epoch + condition_name (str): The condition name that is used to toggle sending next row + callback (function): The callback funciton that will be invoked when sync_update is called + + Raises: + RuntimeError: If condition name already exists. + """ + + def __init__(self, input_dataset, condition_name, num_batch, callback=None): + super().__init__() + self.input.append(input_dataset) + input_dataset.output.append(self) + # set to the default value, waiting for the batch to update it + self._condition_name = condition_name + self._pair = BlockReleasePair(num_batch, callback) + if self._condition_name in self.input[0].get_sync_notifiers(): + raise RuntimeError("Condition name is already in use") + + def get_sync_notifiers(self): + return {**self.input[0].get_sync_notifiers(), **{self._condition_name: self._pair.release_func}} + + def is_sync(self): + return True + + def get_args(self): + args = super().get_args() + args["condition_name"] = self._condition_name + args["condition_func"] = self._pair.block_func + return args + + def update_sync_batch_size(self, batch_size): + self._pair.update_batched_size(batch_size) + + @staticmethod + def _is_ancestor_of_batch(dataset): + """ + Utility function to find the case where sync_wait is used before batch. + + Args: + dataset (Dataset): dataset to be checked + Return: + True or False + """ + if isinstance(dataset, BatchDataset): + return True + flag = False + for input_dataset in dataset.input: + flag = flag | SyncWaitDataset._is_ancestor_of_batch(input_dataset) + return flag class ShuffleDataset(DatasetOp): """ @@ -1061,6 +1235,9 @@ class ShuffleDataset(DatasetOp): Args: input_dataset (Dataset): Input Dataset to be shuffled. buffer_size (int): The size of the buffer. + + Raises: + RuntimeError: If exist sync operators before shuffle. """ def __init__(self, input_dataset, buffer_size): @@ -1069,6 +1246,8 @@ class ShuffleDataset(DatasetOp): self.input.append(input_dataset) input_dataset.output.append(self) self._input_indexs = input_dataset.input_indexs + if self.is_sync(): + raise RuntimeError("No shuffle after sync operators") def get_args(self): args = super().get_args() @@ -1335,6 +1514,9 @@ class ZipDataset(DatasetOp): """ return None + def is_sync(self): + return any([c.is_sync() for c in self.input]) + def get_args(self): args = super().get_args() return args diff --git a/mindspore/dataset/engine/iterators.py b/mindspore/dataset/engine/iterators.py index 6af6c7dba8..a8d20df5f3 100644 --- a/mindspore/dataset/engine/iterators.py +++ b/mindspore/dataset/engine/iterators.py @@ -125,6 +125,8 @@ class Iterator: op_type = OpName.MINDRECORD elif isinstance(dataset, de.BatchDataset): op_type = OpName.BATCH + elif isinstance(dataset, de.SyncWaitDataset): + op_type = OpName.BARRIER elif isinstance(dataset, de.ZipDataset): op_type = OpName.ZIP elif isinstance(dataset, de.MapDataset): diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index a68d723f1d..a8d18ab2c1 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -652,6 +652,22 @@ def check_batch(method): return new_method +def check_sync_wait(method): + """check the input arguments of sync_wait.""" + @wraps(method) + def new_method(*args, **kwargs): + param_dict = make_param_dict(method, args, kwargs) + + nreq_param_str = ['condition_name'] + nreq_param_int = ['step_size'] + + check_param_type(nreq_param_int, param_dict, int) + + check_param_type(nreq_param_str, param_dict, str) + + return method(*args, **kwargs) + + return new_method def check_shuffle(method): """check the input arguments of shuffle.""" diff --git a/tests/ut/python/dataset/test_config.py b/tests/ut/python/dataset/test_config.py index 8cabe81aaa..0c1e0073af 100644 --- a/tests/ut/python/dataset/test_config.py +++ b/tests/ut/python/dataset/test_config.py @@ -12,8 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +""" +Testing configuration manager +""" +import filecmp +import glob +import os + import mindspore.dataset as ds +import mindspore.dataset.transforms.vision.c_transforms as vision +DATA_DIR = ["../data/dataset/test_tf_file_3_images/train-0000-of-0001.data"] +SCHEMA_DIR = "../data/dataset/test_tf_file_3_images/datasetSchema.json" def test_basic(): ds.config.load('../data/dataset/declient.cfg') @@ -36,6 +46,34 @@ def test_basic(): assert ds.config.get_prefetch_size() == 4 assert ds.config.get_seed() == 5 +def test_pipeline(): + """ + Test that our configuration pipeline works when we set parameters at dataset interval + """ + data1 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False) + ds.config.set_num_parallel_workers(2) + data1 = data1.map(input_columns=["image"], operations=[vision.Decode(True)]) + ds.serialize(data1, "testpipeline.json") + + data2 = ds.TFRecordDataset(DATA_DIR, SCHEMA_DIR, shuffle=False) + ds.config.set_num_parallel_workers(4) + data2 = data2.map(input_columns=["image"], operations=[vision.Decode(True)]) + ds.serialize(data2, "testpipeline2.json") + + # check that the generated output is different + assert (filecmp.cmp('testpipeline.json', 'testpipeline2.json')) + + # this test passes currently because our num_parallel_workers don't get updated. + + # remove generated jason files + file_list = glob.glob('*.json') + for f in file_list: + try: + os.remove(f) + except IOError: + logger.info("Error while deleting: {}".format(f)) + if __name__ == '__main__': test_basic() + test_pipeline() diff --git a/tests/ut/python/dataset/test_sync_wait.py b/tests/ut/python/dataset/test_sync_wait.py new file mode 100644 index 0000000000..277499d9ae --- /dev/null +++ b/tests/ut/python/dataset/test_sync_wait.py @@ -0,0 +1,182 @@ +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import mindspore.dataset as ds +from mindspore import log as logger +import time +import numpy as np + + +def gen(): + for i in range(100): + yield np.array(i), + + +class Augment: + def __init__(self, loss): + self.loss = loss + + def preprocess(self, input): + return input + + def update(self, data): + self.loss = data["loss"] + + +def test_simple_sync_wait(): + """ + Test simple sync wait: test sync in dataset pipeline + """ + logger.info("test_simple_sync_wait") + batch_size = 4 + dataset = ds.GeneratorDataset(gen, column_names=["input"]) + + aug = Augment(0) + dataset = dataset.sync_wait(condition_name="policy", callback=aug.update) + dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess]) + dataset = dataset.batch(batch_size) + + count = 0 + for data in dataset.create_dict_iterator(): + assert (data["input"][0] == count) + count += batch_size + data = {"loss": count} + dataset.sync_update(condition_name="policy", data=data) + + +def test_simple_shuffle_sync(): + """ + Test simple shuffle sync: test shuffle before sync + """ + logger.info("test_simple_shuffle_sync") + shuffle_size = 4 + batch_size = 10 + + dataset = ds.GeneratorDataset(gen, column_names=["input"]) + + aug = Augment(0) + dataset = dataset.shuffle(shuffle_size) + dataset = dataset.sync_wait(condition_name="policy", callback=aug.update) + dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess]) + dataset = dataset.batch(batch_size) + + count = 0 + for data in dataset.create_dict_iterator(): + count += 1 + #time.sleep(0.5) + data = {"loss": count} + dataset.sync_update(condition_name="policy", data=data) + + +def test_two_sync(): + """ + Test two sync: dataset pipeline with with two sync_operators + """ + logger.info("test_two_sync") + batch_size = 6 + + dataset = ds.GeneratorDataset(gen, column_names=["input"]) + + aug = Augment(0) + # notice that with our design, we need to have step_size = shuffle size + dataset = dataset.sync_wait(condition_name="every batch", callback=aug.update) + + dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess]) + + dataset = dataset.sync_wait(num_batch=2, condition_name="every 2 batches") + + dataset = dataset.batch(batch_size) + + count = 0 + for data in dataset.create_dict_iterator(): + count += 1 + data = {"loss": count} + dataset.sync_update(condition_name="every batch", data=data) + if count % 2 == 0: + dataset.sync_update(condition_name="every 2 batches") + +def test_sync_epoch(): + """ + Test sync wait with epochs: test sync with epochs in dataset pipeline + """ + logger.info("test_sync_epoch") + batch_size = 30 + dataset = ds.GeneratorDataset(gen, column_names=["input"]) + + aug = Augment(0) + dataset = dataset.sync_wait(condition_name="policy", callback=aug.update) + dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess]) + dataset = dataset.batch(batch_size, drop_remainder=True) + + for epochs in range(3): + aug.update({"loss": 0}) + count = 0 + for data in dataset.create_dict_iterator(): + assert (data["input"][0] == count) + count += batch_size + data = {"loss": count} + dataset.sync_update(condition_name="policy", data=data) + + +def test_sync_exception_01(): + """ + Test sync: with shuffle in sync mode + """ + logger.info("test_sync_exception_01") + shuffle_size = 4 + batch_size = 10 + + dataset = ds.GeneratorDataset(gen, column_names=["input"]) + + aug = Augment(0) + dataset = dataset.sync_wait(condition_name="policy", callback=aug.update) + dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess]) + + try: + dataset = dataset.shuffle(shuffle_size) + except BaseException as e: + assert "shuffle" in str(e) + dataset = dataset.batch(batch_size) + + +def test_sync_exception_02(): + """ + Test sync: with duplicated condition name + """ + logger.info("test_sync_exception_02") + batch_size = 6 + + dataset = ds.GeneratorDataset(gen, column_names=["input"]) + + aug = Augment(0) + # notice that with our design, we need to have step_size = shuffle size + dataset = dataset.sync_wait(condition_name="every batch", callback=aug.update) + + dataset = dataset.map(input_columns=["input"], operations=[aug.preprocess]) + + try: + dataset = dataset.sync_wait(num_batch=2, condition_name="every batch") + except BaseException as e: + assert "name" in str(e) + dataset = dataset.batch(batch_size) + + +if __name__ == "__main__": + test_simple_sync_wait() + test_simple_shuffle_sync() + test_two_sync() + test_sync_exception_01() + test_sync_exception_02() + test_sync_epoch() \ No newline at end of file From 56e7a7deb55c653016e7cbecd90577ae4d62086e Mon Sep 17 00:00:00 2001 From: Amir Lashkari Date: Tue, 21 Apr 2020 19:54:32 -0400 Subject: [PATCH 074/142] Added UniformAugment + Python Augmentation Ops --- .../transforms/vision/py_transforms.py | 174 ++++++++++++++++++ .../transforms/vision/py_transforms_util.py | 157 ++++++++++++++++ tests/ut/python/dataset/test_autocontrast.py | 101 ++++++++++ tests/ut/python/dataset/test_equalize.py | 101 ++++++++++ tests/ut/python/dataset/test_invert.py | 100 ++++++++++ tests/ut/python/dataset/test_random_color.py | 102 ++++++++++ .../python/dataset/test_random_sharpness.py | 102 ++++++++++ .../ut/python/dataset/test_uniform_augment.py | 107 +++++++++++ 8 files changed, 944 insertions(+) create mode 100644 tests/ut/python/dataset/test_autocontrast.py create mode 100644 tests/ut/python/dataset/test_equalize.py create mode 100644 tests/ut/python/dataset/test_invert.py create mode 100644 tests/ut/python/dataset/test_random_color.py create mode 100644 tests/ut/python/dataset/test_random_sharpness.py create mode 100644 tests/ut/python/dataset/test_uniform_augment.py diff --git a/mindspore/dataset/transforms/vision/py_transforms.py b/mindspore/dataset/transforms/vision/py_transforms.py index 8d81f8f3b0..51bea80b21 100644 --- a/mindspore/dataset/transforms/vision/py_transforms.py +++ b/mindspore/dataset/transforms/vision/py_transforms.py @@ -1312,3 +1312,177 @@ class HsvToRgb: rgb_imgs (numpy.ndarray), Numpy RGB image with same shape of hsv_imgs. """ return util.hsv_to_rgbs(hsv_imgs, self.is_hwc) + + +class RandomColor: + """ + Adjust the color of the input PIL image by a random degree. + + Args: + degrees (sequence): Range of random color adjustment degrees. + It should be in (min, max) format (default=(0.1,1.9)). + + Examples: + >>> py_transforms.ComposeOp([py_transforms.Decode(), + >>> py_transforms.RandomColor(0.5,1.5), + >>> py_transforms.ToTensor()]) + """ + + def __init__(self, degrees=(0.1, 1.9)): + self.degrees = degrees + + def __call__(self, img): + """ + Call method. + + Args: + img (PIL Image): Image to be color adjusted. + + Returns: + img (PIL Image), Color adjusted image. + """ + + return util.random_color(img, self.degrees) + +class RandomSharpness: + """ + Adjust the sharpness of the input PIL image by a random degree. + + Args: + degrees (sequence): Range of random sharpness adjustment degrees. + It should be in (min, max) format (default=(0.1,1.9)). + + Examples: + >>> py_transforms.ComposeOp([py_transforms.Decode(), + >>> py_transforms.RandomColor(0.5,1.5), + >>> py_transforms.ToTensor()]) + + """ + + def __init__(self, degrees=(0.1, 1.9)): + self.degrees = degrees + + def __call__(self, img): + """ + Call method. + + Args: + img (PIL Image): Image to be sharpness adjusted. + + Returns: + img (PIL Image), Color adjusted image. + """ + + return util.random_sharpness(img, self.degrees) + + +class AutoContrast: + """ + Automatically maximize the contrast of the input PIL image. + + Examples: + >>> py_transforms.ComposeOp([py_transforms.Decode(), + >>> py_transforms.AutoContrast(), + >>> py_transforms.ToTensor()]) + + """ + + def __call__(self, img): + """ + Call method. + + Args: + img (PIL Image): Image to be augmented with AutoContrast. + + Returns: + img (PIL Image), Augmented image. + """ + + return util.auto_contrast(img) + + +class Invert: + """ + Invert colors of input PIL image. + + Examples: + >>> py_transforms.ComposeOp([py_transforms.Decode(), + >>> py_transforms.Invert(), + >>> py_transforms.ToTensor()]) + + """ + + def __call__(self, img): + """ + Call method. + + Args: + img (PIL Image): Image to be color Inverted. + + Returns: + img (PIL Image), Color inverted image. + """ + + return util.invert_color(img) + + +class Equalize: + """ + Equalize the histogram of input PIL image. + + Examples: + >>> py_transforms.ComposeOp([py_transforms.Decode(), + >>> py_transforms.Equalize(), + >>> py_transforms.ToTensor()]) + + """ + + def __call__(self, img): + """ + Call method. + + Args: + img (PIL Image): Image to be equalized. + + Returns: + img (PIL Image), Equalized image. + """ + + return util.equalize(img) + + +class UniformAugment: + """ + Uniformly select and apply a number of transforms sequentially from + a list of transforms. Randomly assigns a probability to each transform for + each image to decide whether apply it or not. + + Args: + transforms (list): List of transformations to be chosen from to apply. + num_ops (int, optional): number of transforms to sequentially apply (default=2). + + Examples: + >>> transforms_list = [py_transforms.CenterCrop(64), + >>> py_transforms.RandomColor(), + >>> py_transforms.RandomSharpness(), + >>> py_transforms.RandomRotation(30)] + >>> py_transforms.ComposeOp([py_transforms.Decode(), + >>> py_transforms.UniformAugment(transforms_list), + >>> py_transforms.ToTensor()]) + """ + + def __init__(self, transforms, num_ops=2): + self.transforms = transforms + self.num_ops = num_ops + + def __call__(self, img): + """ + Call method. + + Args: + img (PIL Image): Image to be applied transformation. + + Returns: + img (PIL Image), Transformed image. + """ + return util.uniform_augment(img, self.transforms, self.num_ops) diff --git a/mindspore/dataset/transforms/vision/py_transforms_util.py b/mindspore/dataset/transforms/vision/py_transforms_util.py index 10c71bbe38..54fb4c8274 100644 --- a/mindspore/dataset/transforms/vision/py_transforms_util.py +++ b/mindspore/dataset/transforms/vision/py_transforms_util.py @@ -1408,3 +1408,160 @@ def hsv_to_rgbs(np_hsv_imgs, is_hwc): if batch_size == 0: return hsv_to_rgb(np_hsv_imgs, is_hwc) return np.array([hsv_to_rgb(img, is_hwc) for img in np_hsv_imgs]) + + +def random_color(img, degrees): + + """ + Adjust the color of the input PIL image by a random degree. + + Args: + img (PIL Image): Image to be color adjusted. + degrees (sequence): Range of random color adjustment degrees. + It should be in (min, max) format (default=(0.1,1.9)). + + Returns: + img (PIL Image), Color adjusted image. + """ + + if not is_pil(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + if isinstance(degrees, (list, tuple)): + if len(degrees) != 2: + raise ValueError("Degrees must be a sequence length 2.") + if degrees[0] < 0: + raise ValueError("Degree value must be non-negative.") + if degrees[0] > degrees[1]: + raise ValueError("Degrees should be in (min,max) format. Got (max,min).") + + else: + raise TypeError("Degrees must be a sequence in (min,max) format.") + + v = (degrees[1] - degrees[0]) * random.random() + degrees[0] + return ImageEnhance.Color(img).enhance(v) + + +def random_sharpness(img, degrees): + + """ + Adjust the sharpness of the input PIL image by a random degree. + + Args: + img (PIL Image): Image to be sharpness adjusted. + degrees (sequence): Range of random sharpness adjustment degrees. + It should be in (min, max) format (default=(0.1,1.9)). + + Returns: + img (PIL Image), Sharpness adjusted image. + """ + + if not is_pil(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + if isinstance(degrees, (list, tuple)): + if len(degrees) != 2: + raise ValueError("Degrees must be a sequence length 2.") + if degrees[0] < 0: + raise ValueError("Degree value must be non-negative.") + if degrees[0] > degrees[1]: + raise ValueError("Degrees should be in (min,max) format. Got (max,min).") + + else: + raise TypeError("Degrees must be a sequence in (min,max) format.") + + v = (degrees[1] - degrees[0]) * random.random() + degrees[0] + return ImageEnhance.Sharpness(img).enhance(v) + + +def auto_contrast(img): + + """ + Automatically maximize the contrast of the input PIL image. + + Args: + img (PIL Image): Image to be augmented with AutoContrast. + + Returns: + img (PIL Image), Augmented image. + + """ + + if not is_pil(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + return ImageOps.autocontrast(img) + + +def invert_color(img): + + """ + Invert colors of input PIL image. + + Args: + img (PIL Image): Image to be color inverted. + + Returns: + img (PIL Image), Color inverted image. + + """ + + if not is_pil(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + return ImageOps.invert(img) + + +def equalize(img): + + """ + Equalize the histogram of input PIL image. + + Args: + img (PIL Image): Image to be equalized + + Returns: + img (PIL Image), Equalized image. + + """ + + if not is_pil(img): + raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + + return ImageOps.equalize(img) + + +def uniform_augment(img, transforms, num_ops): + + """ + Uniformly select and apply a number of transforms sequentially from + a list of transforms. Randomly assigns a probability to each transform for + each image to decide whether apply it or not. + + Args: + img: Image to be applied transformation. + transforms (list): List of transformations to be chosen from to apply. + num_ops (int): number of transforms to sequentially aaply. + + Returns: + img, Transformed image. + """ + + if transforms is None: + raise ValueError("transforms is not provided.") + if not isinstance(transforms, list): + raise ValueError("The transforms needs to be a list.") + + if not isinstance(num_ops, int): + raise ValueError("Number of operations should be a positive integer.") + if num_ops < 1: + raise ValueError("Number of operators should equal or greater than one.") + + for _ in range(num_ops): + AugmentOp = random.choice(transforms) + pr = random.random() + if random.random() < pr: + img = AugmentOp(img.copy()) + transforms.remove(AugmentOp) + + return img diff --git a/tests/ut/python/dataset/test_autocontrast.py b/tests/ut/python/dataset/test_autocontrast.py new file mode 100644 index 0000000000..7dba2f21f6 --- /dev/null +++ b/tests/ut/python/dataset/test_autocontrast.py @@ -0,0 +1,101 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import numpy as np +import matplotlib.pyplot as plt +from mindspore import log as logger +import mindspore.dataset.engine as de +import mindspore.dataset.transforms.vision.py_transforms as F + +DATA_DIR = "../data/dataset/testImageNetData/train/" + + +def visualize(image_original, image_auto_contrast): + """ + visualizes the image using DE op and Numpy op + """ + num = len(image_auto_contrast) + for i in range(num): + plt.subplot(2, num, i + 1) + plt.imshow(image_original[i]) + plt.title("Original image") + + plt.subplot(2, num, i + num + 1) + plt.imshow(image_auto_contrast[i]) + plt.title("DE AutoContrast image") + + plt.show() + + +def test_auto_contrast(plot=False): + """ + Test AutoContrast + """ + logger.info("Test AutoContrast") + + # Original Images + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + + transforms_original = F.ComposeOp([F.Decode(), + F.Resize((224,224)), + F.ToTensor()]) + + ds_original = ds.map(input_columns="image", + operations=transforms_original()) + + ds_original = ds_original.batch(512) + + for idx, (image,label) in enumerate(ds_original): + if idx == 0: + images_original = np.transpose(image, (0, 2,3,1)) + else: + images_original = np.append(images_original, + np.transpose(image, (0, 2,3,1)), + axis=0) + + # AutoContrast Images + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + + transforms_auto_contrast = F.ComposeOp([F.Decode(), + F.Resize((224,224)), + F.AutoContrast(), + F.ToTensor()]) + + ds_auto_contrast = ds.map(input_columns="image", + operations=transforms_auto_contrast()) + + ds_auto_contrast = ds_auto_contrast.batch(512) + + for idx, (image,label) in enumerate(ds_auto_contrast): + if idx == 0: + images_auto_contrast = np.transpose(image, (0, 2,3,1)) + else: + images_auto_contrast = np.append(images_auto_contrast, + np.transpose(image, (0, 2,3,1)), + axis=0) + + num_samples = images_original.shape[0] + mse = np.zeros(num_samples) + for i in range(num_samples): + mse[i] = np.mean((images_auto_contrast[i]-images_original[i])**2) + logger.info("MSE= {}".format(str(np.mean(mse)))) + + if plot: + visualize(images_original, images_auto_contrast) + + +if __name__ == "__main__": + test_auto_contrast(plot=True) + diff --git a/tests/ut/python/dataset/test_equalize.py b/tests/ut/python/dataset/test_equalize.py new file mode 100644 index 0000000000..077c316d67 --- /dev/null +++ b/tests/ut/python/dataset/test_equalize.py @@ -0,0 +1,101 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import numpy as np +import matplotlib.pyplot as plt +from mindspore import log as logger +import mindspore.dataset.engine as de +import mindspore.dataset.transforms.vision.py_transforms as F + +DATA_DIR = "../data/dataset/testImageNetData/train/" + + +def visualize(image_original, image_equalize): + """ + visualizes the image using DE op and Numpy op + """ + num = len(image_equalize) + for i in range(num): + plt.subplot(2, num, i + 1) + plt.imshow(image_original[i]) + plt.title("Original image") + + plt.subplot(2, num, i + num + 1) + plt.imshow(image_equalize[i]) + plt.title("DE Color Equalized image") + + plt.show() + + +def test_equalize(plot=False): + """ + Test Equalize + """ + logger.info("Test Equalize") + + # Original Images + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + + transforms_original = F.ComposeOp([F.Decode(), + F.Resize((224,224)), + F.ToTensor()]) + + ds_original = ds.map(input_columns="image", + operations=transforms_original()) + + ds_original = ds_original.batch(512) + + for idx, (image,label) in enumerate(ds_original): + if idx == 0: + images_original = np.transpose(image, (0, 2,3,1)) + else: + images_original = np.append(images_original, + np.transpose(image, (0, 2,3,1)), + axis=0) + + # Color Equalized Images + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + + transforms_equalize = F.ComposeOp([F.Decode(), + F.Resize((224,224)), + F.Equalize(), + F.ToTensor()]) + + ds_equalize = ds.map(input_columns="image", + operations=transforms_equalize()) + + ds_equalize = ds_equalize.batch(512) + + for idx, (image,label) in enumerate(ds_equalize): + if idx == 0: + images_equalize = np.transpose(image, (0, 2,3,1)) + else: + images_equalize = np.append(images_equalize, + np.transpose(image, (0, 2,3,1)), + axis=0) + + num_samples = images_original.shape[0] + mse = np.zeros(num_samples) + for i in range(num_samples): + mse[i] = np.mean((images_equalize[i]-images_original[i])**2) + logger.info("MSE= {}".format(str(np.mean(mse)))) + + if plot: + visualize(images_original, images_equalize) + + +if __name__ == "__main__": + test_equalize(plot=True) + diff --git a/tests/ut/python/dataset/test_invert.py b/tests/ut/python/dataset/test_invert.py new file mode 100644 index 0000000000..a1bfd63431 --- /dev/null +++ b/tests/ut/python/dataset/test_invert.py @@ -0,0 +1,100 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import numpy as np +import matplotlib.pyplot as plt +from mindspore import log as logger +import mindspore.dataset.engine as de +import mindspore.dataset.transforms.vision.py_transforms as F + +DATA_DIR = "../data/dataset/testImageNetData/train/" + +def visualize(image_original, image_invert): + """ + visualizes the image using DE op and Numpy op + """ + num = len(image_invert) + for i in range(num): + plt.subplot(2, num, i + 1) + plt.imshow(image_original[i]) + plt.title("Original image") + + plt.subplot(2, num, i + num + 1) + plt.imshow(image_invert[i]) + plt.title("DE Color Inverted image") + + plt.show() + + +def test_invert(plot=False): + """ + Test Invert + """ + logger.info("Test Invert") + + # Original Images + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + + transforms_original = F.ComposeOp([F.Decode(), + F.Resize((224,224)), + F.ToTensor()]) + + ds_original = ds.map(input_columns="image", + operations=transforms_original()) + + ds_original = ds_original.batch(512) + + for idx, (image,label) in enumerate(ds_original): + if idx == 0: + images_original = np.transpose(image, (0, 2,3,1)) + else: + images_original = np.append(images_original, + np.transpose(image, (0, 2,3,1)), + axis=0) + + # Color Inverted Images + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + + transforms_invert = F.ComposeOp([F.Decode(), + F.Resize((224,224)), + F.Invert(), + F.ToTensor()]) + + ds_invert = ds.map(input_columns="image", + operations=transforms_invert()) + + ds_invert = ds_invert.batch(512) + + for idx, (image,label) in enumerate(ds_invert): + if idx == 0: + images_invert = np.transpose(image, (0, 2,3,1)) + else: + images_invert = np.append(images_invert, + np.transpose(image, (0, 2,3,1)), + axis=0) + + num_samples = images_original.shape[0] + mse = np.zeros(num_samples) + for i in range(num_samples): + mse[i] = np.mean((images_invert[i]-images_original[i])**2) + logger.info("MSE= {}".format(str(np.mean(mse)))) + + if plot: + visualize(images_original, images_invert) + + +if __name__ == "__main__": + test_invert(plot=True) + diff --git a/tests/ut/python/dataset/test_random_color.py b/tests/ut/python/dataset/test_random_color.py new file mode 100644 index 0000000000..9472b7e35a --- /dev/null +++ b/tests/ut/python/dataset/test_random_color.py @@ -0,0 +1,102 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import numpy as np +import matplotlib.pyplot as plt +from mindspore import log as logger +import mindspore.dataset.engine as de +import mindspore.dataset.transforms.vision.py_transforms as F + +DATA_DIR = "../data/dataset/testImageNetData/train/" + + +def visualize(image_original, image_random_color): + """ + visualizes the image using DE op and Numpy op + """ + num = len(image_random_color) + for i in range(num): + plt.subplot(2, num, i + 1) + plt.imshow(image_original[i]) + plt.title("Original image") + + plt.subplot(2, num, i + num + 1) + plt.imshow(image_random_color[i]) + plt.title("DE Random Color image") + + plt.show() + + +def test_random_color(degrees=(0.1,1.9), plot=False): + """ + Test RandomColor + """ + logger.info("Test RandomColor") + + # Original Images + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + + transforms_original = F.ComposeOp([F.Decode(), + F.Resize((224,224)), + F.ToTensor()]) + + ds_original = ds.map(input_columns="image", + operations=transforms_original()) + + ds_original = ds_original.batch(512) + + for idx, (image,label) in enumerate(ds_original): + if idx == 0: + images_original = np.transpose(image, (0, 2,3,1)) + else: + images_original = np.append(images_original, + np.transpose(image, (0, 2,3,1)), + axis=0) + + # Random Color Adjusted Images + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + + transforms_random_color = F.ComposeOp([F.Decode(), + F.Resize((224,224)), + F.RandomColor(degrees=degrees), + F.ToTensor()]) + + ds_random_color = ds.map(input_columns="image", + operations=transforms_random_color()) + + ds_random_color = ds_random_color.batch(512) + + for idx, (image,label) in enumerate(ds_random_color): + if idx == 0: + images_random_color = np.transpose(image, (0, 2,3,1)) + else: + images_random_color = np.append(images_random_color, + np.transpose(image, (0, 2,3,1)), + axis=0) + + num_samples = images_original.shape[0] + mse = np.zeros(num_samples) + for i in range(num_samples): + mse[i] = np.mean((images_random_color[i]-images_original[i])**2) + logger.info("MSE= {}".format(str(np.mean(mse)))) + + if plot: + visualize(images_original, images_random_color) + + +if __name__ == "__main__": + test_random_color() + test_random_color(plot=True) + test_random_color(degrees=(0.5,1.5), plot=True) diff --git a/tests/ut/python/dataset/test_random_sharpness.py b/tests/ut/python/dataset/test_random_sharpness.py new file mode 100644 index 0000000000..949a658597 --- /dev/null +++ b/tests/ut/python/dataset/test_random_sharpness.py @@ -0,0 +1,102 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import numpy as np +import matplotlib.pyplot as plt +from mindspore import log as logger +import mindspore.dataset.engine as de +import mindspore.dataset.transforms.vision.py_transforms as F + +DATA_DIR = "../data/dataset/testImageNetData/train/" + + +def visualize(image_original, image_random_sharpness): + """ + visualizes the image using DE op and Numpy op + """ + num = len(image_random_sharpness) + for i in range(num): + plt.subplot(2, num, i + 1) + plt.imshow(image_original[i]) + plt.title("Original image") + + plt.subplot(2, num, i + num + 1) + plt.imshow(image_random_sharpness[i]) + plt.title("DE Random Sharpness image") + + plt.show() + + +def test_random_sharpness(degrees=(0.1,1.9), plot=False): + """ + Test RandomSharpness + """ + logger.info("Test RandomSharpness") + + # Original Images + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + + transforms_original = F.ComposeOp([F.Decode(), + F.Resize((224,224)), + F.ToTensor()]) + + ds_original = ds.map(input_columns="image", + operations=transforms_original()) + + ds_original = ds_original.batch(512) + + for idx, (image,label) in enumerate(ds_original): + if idx == 0: + images_original = np.transpose(image, (0, 2,3,1)) + else: + images_original = np.append(images_original, + np.transpose(image, (0, 2,3,1)), + axis=0) + + # Random Sharpness Adjusted Images + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + + transforms_random_sharpness = F.ComposeOp([F.Decode(), + F.Resize((224,224)), + F.RandomSharpness(degrees=degrees), + F.ToTensor()]) + + ds_random_sharpness = ds.map(input_columns="image", + operations=transforms_random_sharpness()) + + ds_random_sharpness = ds_random_sharpness.batch(512) + + for idx, (image,label) in enumerate(ds_random_sharpness): + if idx == 0: + images_random_sharpness = np.transpose(image, (0, 2,3,1)) + else: + images_random_sharpness = np.append(images_random_sharpness, + np.transpose(image, (0, 2,3,1)), + axis=0) + + num_samples = images_original.shape[0] + mse = np.zeros(num_samples) + for i in range(num_samples): + mse[i] = np.mean((images_random_sharpness[i]-images_original[i])**2) + logger.info("MSE= {}".format(str(np.mean(mse)))) + + if plot: + visualize(images_original, images_random_sharpness) + + +if __name__ == "__main__": + test_random_sharpness() + test_random_sharpness(plot=True) + test_random_sharpness(degrees=(0.5,1.5), plot=True) diff --git a/tests/ut/python/dataset/test_uniform_augment.py b/tests/ut/python/dataset/test_uniform_augment.py new file mode 100644 index 0000000000..ce0490336e --- /dev/null +++ b/tests/ut/python/dataset/test_uniform_augment.py @@ -0,0 +1,107 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import numpy as np +import matplotlib.pyplot as plt +from mindspore import log as logger +import mindspore.dataset.engine as de +import mindspore.dataset.transforms.vision.py_transforms as F + +DATA_DIR = "../data/dataset/testImageNetData/train/" + +def visualize(image_original, image_ua): + """ + visualizes the image using DE op and Numpy op + """ + num = len(image_ua) + for i in range(num): + plt.subplot(2, num, i + 1) + plt.imshow(image_original[i]) + plt.title("Original image") + + plt.subplot(2, num, i + num + 1) + plt.imshow(image_ua[i]) + plt.title("DE UniformAugment image") + + plt.show() + + +def test_uniform_augment(plot=False, num_ops=2): + """ + Test UniformAugment + """ + logger.info("Test UniformAugment") + + # Original Images + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + + transforms_original = F.ComposeOp([F.Decode(), + F.Resize((224,224)), + F.ToTensor()]) + + ds_original = ds.map(input_columns="image", + operations=transforms_original()) + + ds_original = ds_original.batch(512) + + for idx, (image,label) in enumerate(ds_original): + if idx == 0: + images_original = np.transpose(image, (0, 2,3,1)) + else: + images_original = np.append(images_original, + np.transpose(image, (0, 2,3,1)), + axis=0) + + # UniformAugment Images + ds = de.ImageFolderDatasetV2(dataset_dir=DATA_DIR, shuffle=False) + + transform_list = [F.RandomRotation(45), + F.RandomColor(), + F.RandomSharpness(), + F.Invert(), + F.AutoContrast(), + F.Equalize()] + + transforms_ua = F.ComposeOp([F.Decode(), + F.Resize((224,224)), + F.UniformAugment(transforms=transform_list, num_ops=num_ops), + F.ToTensor()]) + + ds_ua = ds.map(input_columns="image", + operations=transforms_ua()) + + ds_ua = ds_ua.batch(512) + + for idx, (image,label) in enumerate(ds_ua): + if idx == 0: + images_ua = np.transpose(image, (0, 2,3,1)) + else: + images_ua = np.append(images_ua, + np.transpose(image, (0, 2,3,1)), + axis=0) + + num_samples = images_original.shape[0] + mse = np.zeros(num_samples) + for i in range(num_samples): + mse[i] = np.mean((images_ua[i]-images_original[i])**2) + logger.info("MSE= {}".format(str(np.mean(mse)))) + + if plot: + visualize(images_original, images_ua) + + +if __name__ == "__main__": + test_uniform_augment(num_ops=1) + From 4b19409a6f61354f5abf9bc8b260987dfb1d4d79 Mon Sep 17 00:00:00 2001 From: jiangjinsheng Date: Wed, 22 Apr 2020 09:13:17 +0800 Subject: [PATCH 075/142] add example for maxpool --- mindspore/ops/operations/nn_ops.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 1fb65e3b76..23ddd9f021 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -913,6 +913,11 @@ class MaxPool(_Pool): Outputs: Tensor, with shape :math:`(N, C_{out}, H_{out}, W_{out})`. + + Examples: + >>> input_tensor = Tensor(np.arange(1*3*3*4).reshape((1,3,3,4)),mindspore.float32) + >>> maxpool_op = P.MaxPool(padding="VALID", ksize=2, strides=1) + >>> output_tensor = maxpool_op(input_tensor) """ @prim_attr_register @@ -959,6 +964,11 @@ class MaxPoolWithArgmax(_Pool): - **output** (Tensor) - Maxpooling result, with shape :math:`(N, C_{out}, H_{out}, W_{out})`. - **mask** (Tensor) - Max values' index represented by the mask. + + Examples: + >>> input_tensor = Tensor(np.arange(1*3*3*4).reshape((1,3,3,4)),mindspore.float32) + >>> maxpool_arg_op = P.MaxPoolWithArgmax(padding="VALID", ksize=2, strides=1) + >>> output_tensor, argmax = maxpool_arg_op(input_tensor) """ def __init__(self, ksize=1, strides=1, padding="valid"): super(MaxPoolWithArgmax, self).__init__(ksize, strides, padding) From 5072ae8f4c5c8fe9f07efbd50d6cac363099ba98 Mon Sep 17 00:00:00 2001 From: dengwentao Date: Fri, 17 Apr 2020 15:57:01 +0800 Subject: [PATCH 076/142] modify error log for opinfo --- mindspore/ccsrc/kernel/oplib/oplib.cc | 38 +++++++++++++-------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/mindspore/ccsrc/kernel/oplib/oplib.cc b/mindspore/ccsrc/kernel/oplib/oplib.cc index cd0f843867..f5f2e1601b 100644 --- a/mindspore/ccsrc/kernel/oplib/oplib.cc +++ b/mindspore/ccsrc/kernel/oplib/oplib.cc @@ -83,13 +83,13 @@ bool OpLib::RegOp(const std::string &json_string, const std::string &impl_path) OpImplyType imply_type = kAICPU; ret = DecodeOpInfo(op_json, imply_type, impl_path); } else { - MS_LOG(DEBUG) << "Not support imply_type"; + MS_LOG(ERROR) << "Not support imply_type"; } if (!ret) { - MS_LOG(DEBUG) << "RegOp failed: opname:" << op_name << "imply_type" << imply_type_string; + MS_LOG(ERROR) << "RegOp failed: op_name: " << op_name << " imply_type " << imply_type_string; } } catch (const std::exception &e) { - MS_LOG(DEBUG) << "get op_json elements failed:" << e.what(); + MS_LOG(ERROR) << "get op json elements failed: " << e.what(); } return ret; } @@ -122,7 +122,7 @@ bool OpLib::DecodeOpInfo(const nlohmann::json &obj, const mindspore::kernel::OpI auto attrs = obj.at(kAttr); for (const auto &attr : attrs) { if (!DecodeAttr(attr, imply_type, op_info)) { - MS_LOG(DEBUG) << "DecodeAttr Failed"; + MS_LOG(ERROR) << "DecodeAttr Failed"; return false; } } @@ -133,23 +133,23 @@ bool OpLib::DecodeOpInfo(const nlohmann::json &obj, const mindspore::kernel::OpI auto inputs = obj.at(kIputs); for (const auto &input : inputs) { if (!DecodeInputOutput(input, imply_type, kInput, op_info, dtype_format)) { - MS_LOG(DEBUG) << "DecodeInputOutput Failed"; + MS_LOG(ERROR) << "DecodeInputOutput Failed"; return false; } } auto outputs = obj.at(kOutputs); for (const auto &output : outputs) { if (!DecodeInputOutput(output, imply_type, kOutput, op_info, dtype_format)) { - MS_LOG(DEBUG) << "DecodeInputOutput Failed"; + MS_LOG(ERROR) << "DecodeInputOutput Failed"; return false; } } if (!GetRefInfo(op_info)) { - MS_LOG(DEBUG) << "GetRefInfo Failed"; + MS_LOG(ERROR) << "GetRefInfo Failed"; return false; } if (!CheckRepetition(op_info)) { - MS_LOG(DEBUG) << "CheckRepetition Failed"; + MS_LOG(ERROR) << "CheckRepetition Failed"; return false; } op_info_.push_back(op_info); @@ -176,7 +176,7 @@ bool OpLib::DecodeAttr(const nlohmann::json &obj, const OpImplyType imply_type, } op_info->add_attrs_ptr(op_attr); } catch (const std::exception &e) { - MS_LOG(DEBUG) << "DecodeAttr failed:" << e.what(); + MS_LOG(ERROR) << "DecodeAttr failed:" << e.what(); ret = false; } return ret; @@ -219,8 +219,8 @@ bool OpLib::DecodeInputOutput(const nlohmann::json &obj, const OpImplyType imply op_io->set_formats(obj.at(kFormat)); } if (op_io->dtypes().size() != op_io->formats().size()) { - MS_LOG(DEBUG) << "op" << op_io->name() << "dtype size:" << op_io->dtypes() - << "is not equal to format size:" << op_io->formats(); + MS_LOG(ERROR) << "op " << op_io->name() << " dtype size: " << op_io->dtypes() + << " is not equal to format size: " << op_io->formats(); return false; } if (obj.find(kParamType) != obj.end()) { @@ -244,7 +244,7 @@ bool OpLib::DecodeInputOutput(const nlohmann::json &obj, const OpImplyType imply op_info->add_outputs_ptr(op_io); } } catch (const std::exception &e) { - MS_LOG(DEBUG) << "DecodeInputOutput failed" << e.what(); + MS_LOG(ERROR) << "DecodeInputOutput failed" << e.what(); ret = false; } return ret; @@ -256,8 +256,8 @@ std::shared_ptr OpLib::FindOp(const std::string &op_name, OpImplyType im bool is_gpu = (context->device_target() == kGPUDevice); if ((is_gpu && (imply_type == kTBE || imply_type == kAICPU)) || (!is_gpu && (imply_type != kTBE && imply_type != kAICPU))) { - MS_LOG(ERROR) << "FindOp failed: opname:" << op_name << ", imply_type:" << ImplTypeToStr(imply_type) - << ", current op num:" << op_info_.size(); + MS_LOG(ERROR) << "FindOp failed: opname: " << op_name << ", imply_type: " << ImplTypeToStr(imply_type) + << ", current op num: " << op_info_.size(); return nullptr; } for (const auto &op_info : op_info_) { @@ -266,8 +266,8 @@ std::shared_ptr OpLib::FindOp(const std::string &op_name, OpImplyType im return op_info; } } - MS_LOG(DEBUG) << "FindOp failed: opname:" << op_name << ", imply_type:" << ImplTypeToStr(imply_type) - << ", current op num:" << op_info_.size(); + MS_LOG(DEBUG) << "FindOp failed: opname: " << op_name << ", imply_type: " << ImplTypeToStr(imply_type) + << ", current op num: " << op_info_.size(); return nullptr; } @@ -281,7 +281,7 @@ bool OpLib::GetRefInfo(const std::shared_ptr &op_info) { const auto &in_name = input_infos[in_index]->name(); if (out_name == in_name) { if (op_info->has_ref_index(out_index)) { - MS_LOG(DEBUG) << "The out_index" << out_index << "is already in ref_info"; + MS_LOG(ERROR) << "The out_index " << out_index << " is already in ref_info"; return false; } op_info->add_ref_pair(out_index, in_index); @@ -299,8 +299,8 @@ bool OpLib::CheckRepetition(const std::shared_ptr &op_info) { MS_EXCEPTION_IF_NULL(exist_op_info); if (exist_op_info->op_name() == op_info->op_name() && exist_op_info->imply_type() == op_info->imply_type() && exist_op_info->impl_path() != op_info->impl_path()) { - MS_LOG(DEBUG) << "Has already exist, drop the latter one, op name:" << op_info->op_name() - << "op type:" << ImplTypeToStr(op_info->imply_type()); + MS_LOG(ERROR) << "Op has already exist, please use other name, op name: " << op_info->op_name() + << " op type: " << ImplTypeToStr(op_info->imply_type()); return false; } } From c874e2d484007b833b8dafb23881a43f90a7dd5c Mon Sep 17 00:00:00 2001 From: liuxiao Date: Fri, 17 Apr 2020 16:50:53 +0800 Subject: [PATCH 077/142] Add L2Loss op for VM --- mindspore/ops/_grad/grad_nn_ops.py | 11 +++++++ mindspore/ops/_op_impl/tbe/__init__.py | 1 + mindspore/ops/_op_impl/tbe/l2_loss.py | 44 ++++++++++++++++++++++++++ mindspore/ops/operations/__init__.py | 3 +- mindspore/ops/operations/nn_ops.py | 35 ++++++++++++++++++++ tests/ut/python/ops/test_ops.py | 8 +++++ 6 files changed, 101 insertions(+), 1 deletion(-) create mode 100644 mindspore/ops/_op_impl/tbe/l2_loss.py diff --git a/mindspore/ops/_grad/grad_nn_ops.py b/mindspore/ops/_grad/grad_nn_ops.py index ae730d78a7..887c2a7528 100755 --- a/mindspore/ops/_grad/grad_nn_ops.py +++ b/mindspore/ops/_grad/grad_nn_ops.py @@ -456,6 +456,17 @@ def get_bprop_smooth_l1_loss(self): return bprop +@bprop_getters.register(P.L2Loss) +def get_bprop_l2_loss(self): + """Grad definition for `L2Loss` operation.""" + + def bprop(x, out, dout): + dx = x * dout + return (dx,) + + return bprop + + @bprop_getters.register(P.PReLU) def get_bprop_prelu(self): """Grad definition for `PReLU` operation.""" diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py index 2cffc37491..37da184869 100644 --- a/mindspore/ops/_op_impl/tbe/__init__.py +++ b/mindspore/ops/_op_impl/tbe/__init__.py @@ -117,6 +117,7 @@ from .layer_norm_beta_gamma_backprop import _layer_norm_beta_gamma_backprop_tbe from .layer_norm import _layer_norm_tbe from .layer_norm_grad import _layer_norm_grad_tbe from .layer_norm_x_backprop import _layer_norm_x_backprop_tbe +from .l2_loss import _l2_loss_tbe from .square_sum_v1 import _square_sum_v1_tbe from .square_sum_v2 import _square_sum_v2_tbe from .confusion_transpose_d import _confusion_transpose_d_tbe diff --git a/mindspore/ops/_op_impl/tbe/l2_loss.py b/mindspore/ops/_op_impl/tbe/l2_loss.py new file mode 100644 index 0000000000..7d1394ad64 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/l2_loss.py @@ -0,0 +1,44 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""L2Loss op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +l2_loss_op_info = TBERegOp("L2Loss") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("l2_loss.so") \ + .compute_cost(10) \ + .kernel_name("l2_loss") \ + .partial_flag(True) \ + .input(0, "x", None, "required", None) \ + .output(0, "y", True, "required", "all") \ + .dtype_format(DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F16_FracZ, DataType.F16_Default) \ + .dtype_format(DataType.F16_FracNZ, DataType.F16_Default) \ + .dtype_format(DataType.F16_5HD, DataType.F16_Default) \ + .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.F32_FracZ, DataType.F32_Default) \ + .dtype_format(DataType.F32_FracNZ, DataType.F32_Default) \ + .dtype_format(DataType.F32_5HD, DataType.F32_Default) \ + .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_Default) \ + .get_op_info() + + +@op_info_register(l2_loss_op_info) +def _l2_loss_tbe(): + """L2Loss TBE register""" + return diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index 1f0ee8a04d..2860690b91 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -55,7 +55,7 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm, DropoutDoMask, DropoutGenMask, Flatten, FusedBatchNorm, Gelu, Elu, - GetNext, L2Normalize, LayerNorm, + GetNext, L2Normalize, LayerNorm, L2Loss, LogSoftmax, MaxPool, ExtractImagePatches, AvgPool, Conv2DBackpropInput, @@ -167,6 +167,7 @@ __all__ = [ 'FloatStatus', 'Reciprocal', 'SmoothL1Loss', + 'L2Loss', 'ReduceAll', 'ScalarToArray', 'ScalarToTensor', diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index ed9f0742e8..6f39fdd2ae 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -1332,6 +1332,41 @@ class SmoothL1Loss(PrimitiveWithInfer): return prediction +class L2Loss(PrimitiveWithInfer): + """ + Calculates half of the L2 norm of a tensor without using the `sqrt`. + + Set `input_x` as x and output as loss. + + .. math:: + loss = sum(x ** 2) / 2 + + Inputs: + - **input_x** (Tensor) - A input Tensor. + + Outputs: + Tensor. Has the same dtype as `input_x`. The output tensor is the value of loss which is a scalar tensor. + + Examples + >>> input_x = Tensor(np.array([1, 2, 3]), mindspore.float16) + >>> l2_loss = P.L2Loss() + >>> l2_loss(input_x) + 7.0 + """ + @prim_attr_register + def __init__(self): + """init L2Loss""" + + def infer_shape(self, input_x): + loss_shape = [] + return loss_shape + + def infer_dtype(self, x_type): + validator.check_subclass("x_type", x_type, mstype.tensor, self.name) + validator.check_tensor_type_same({'x_type': x_type}, [mstype.double, mstype.float_, mstype.float16], self.name) + return x_type + + class SGD(PrimitiveWithInfer): """ Computes stochastic gradient descent (optionally with momentum). diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index 1a79935467..1bd3a2e438 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -871,6 +871,14 @@ test_case_nn_ops = [ 'desc_inputs': [[3, 3], [3, 3], [3, 3], [3, 3], [3, 3]], 'desc_bprop': [3, 3], 'skip': ['backward']}), + ('L2Loss_1', { + 'block': P.L2Loss(), + 'desc_inputs': [Tensor(np.array([1, 2, 3, 4]), mstype.float16)], + 'desc_bprop': []}), + ('L2Loss_2', { + 'block': P.L2Loss(), + 'desc_inputs': [Tensor(np.array([[1, 1], [2, 2], [3, 3], [4, 4]]), mstype.float16)], + 'desc_bprop': []}), ] test_case_array_ops = [ From 94a455dacef300853bcfce266b05710aceabe8ff Mon Sep 17 00:00:00 2001 From: caifubi Date: Tue, 21 Apr 2020 21:35:08 +0800 Subject: [PATCH 078/142] insert profiling kernel for hccl automaticly --- .../ascend/profiling/profiling_utils.cc | 66 +++++++++++++++++-- .../device/ascend/profiling/profiling_utils.h | 6 +- 2 files changed, 63 insertions(+), 9 deletions(-) diff --git a/mindspore/ccsrc/device/ascend/profiling/profiling_utils.cc b/mindspore/ccsrc/device/ascend/profiling/profiling_utils.cc index aa71aa0566..7960a08938 100644 --- a/mindspore/ccsrc/device/ascend/profiling/profiling_utils.cc +++ b/mindspore/ccsrc/device/ascend/profiling/profiling_utils.cc @@ -39,13 +39,9 @@ ProfilingTraceInfo ProfilingUtils::GetProfilingTraceFromEnv(NotNullexecution_order(); ProfilingTraceInfo profiling_trace; profiling_trace.trace_begin = GetTraceBegin(cnode_exec_order); - profiling_trace.trace_bp_end = GetTraceBpEnd(); + profiling_trace.trace_bp_end = GetTraceBpEnd(cnode_exec_order); profiling_trace.trace_netoutput = GetTraceNetoutput(cnode_exec_order); - MS_LOG(INFO) << "[profiling] trace_begin:" << profiling_trace.trace_begin - << " trace_bp_end:" << profiling_trace.trace_bp_end - << " trace_netoutput:" << profiling_trace.trace_netoutput; - for (uint32_t i = 1; i <= kMaxProfilingNodeNum; ++i) { std::string env_str = std::string(kCustomNode) + std::to_string(i); const char *node_full_name = std::getenv(env_str.c_str()); @@ -56,9 +52,25 @@ ProfilingTraceInfo ProfilingUtils::GetProfilingTraceFromEnv(NotNull &cnode_exec_order, + NotNull profiling_trace) { + for (const auto &node : cnode_exec_order) { + if (AnfAlgo::IsCommunicationOp(node)) { + MS_EXCEPTION_IF_NULL(node); + profiling_trace->trace_custom_node.insert(node->fullname_with_scope()); + MS_LOG(INFO) << "[profiling]Get hccl node:" << node->fullname_with_scope(); + } + } +} + std::string ProfilingUtils::GetTraceBegin(const std::vector &cnode_exec_order) { const char *trace_begin = std::getenv(kFpStartNode); auto &first_cnode = cnode_exec_order.front(); @@ -66,9 +78,45 @@ std::string ProfilingUtils::GetTraceBegin(const std::vector &cnode_exe return trace_begin == nullptr ? first_cnode->fullname_with_scope() : std::string(trace_begin); } -std::string ProfilingUtils::GetTraceBpEnd() { +std::string ProfilingUtils::GetTraceBpEnd(const std::vector &cnode_exec_order) { const char *trace_bp_end = std::getenv(kBpEndNode); - return trace_bp_end == nullptr ? "" : std::string(trace_bp_end); + + if (trace_bp_end != nullptr) { + return std::string(trace_bp_end); + } + std::string bp_end_str = ""; + // Contain hccl kernel + auto iter = cnode_exec_order.rbegin(); + while (iter != cnode_exec_order.rend()) { + if (AnfAlgo::IsCommunicationOp(*iter)) { + // store communication op input nodes' name + std::set ar_input_node_names; + for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(*iter); ++i) { + auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(*iter, i); + auto input_node = input_node_with_index.first; + ar_input_node_names.insert(input_node->fullname_with_scope()); + } + // start from previous node + ++iter; + // find input names in previous node + while (iter != cnode_exec_order.rend()) { + if (ar_input_node_names.find((*iter)->fullname_with_scope()) != ar_input_node_names.end()) { + bp_end_str = (*iter)->fullname_with_scope(); + break; + } + ++iter; + } + break; + } + ++iter; + } + + if (bp_end_str.empty()) { + auto last_cnode = cnode_exec_order.back(); + MS_EXCEPTION_IF_NULL(last_cnode); + bp_end_str = last_cnode->fullname_with_scope(); + } + return bp_end_str; } std::string ProfilingUtils::GetTraceNetoutput(const std::vector &cnode_exec_order) { @@ -109,6 +157,7 @@ void ProfilingUtils::ProfilingTraceFpStart(const mindspore::AnfNodePtr &anf_node NotNull graph_ptr, NotNull *> kernel_list) { if (profiling_trace_info.trace_begin == anf_node->fullname_with_scope()) { + MS_LOG(INFO) << "Profiling Match FpStart:" << profiling_trace_info.trace_begin; auto job_id = ProfilingManager::GetInstance().GetJobId(); ProfilingContent job_profiling_context = {false, job_id, 0}; auto job_profiling_node = CreateProfilingCNodeWithStream(anf_node, job_profiling_context, graph_ptr); @@ -137,6 +186,7 @@ void ProfilingUtils::ProfilingCustomOp(const AnfNodePtr &anf_node, const Profili if (iter == profiling_trace_info.trace_custom_node.end()) { return; } + MS_LOG(INFO) << "Profiling Match CustomOp:" << anf_node->fullname_with_scope(); // custom op profiling job start from 3. ProfilingContent front_profiling_content = {false, 2 * custom_node_index_ + 1, 0}; CNodePtr front_node = CreateProfilingCNodeWithStream(anf_node, front_profiling_content, graph_ptr); @@ -153,6 +203,7 @@ void ProfilingUtils::ProfilingTraceBpEnd(const AnfNodePtr &anf_node, const Profi NotNull *> kernel_list) { MS_EXCEPTION_IF_NULL(anf_node); if (profiling_trace_info.trace_bp_end == anf_node->fullname_with_scope()) { + MS_LOG(INFO) << "Profiling Match BpEnd:" << profiling_trace_info.trace_bp_end; ProfilingContent bp_end_profiling_content = {false, kProfilingBpEndLogId, 0}; CNodePtr bp_end_node = CreateProfilingCNodeWithStream(anf_node, bp_end_profiling_content, graph_ptr); kernel_list->emplace_back(bp_end_node); @@ -165,6 +216,7 @@ void ProfilingUtils::ProfilingTraceEnd(const AnfNodePtr &anf_node, const Profili MS_EXCEPTION_IF_NULL(anf_node); auto full_scope_name = anf_node->fullname_with_scope(); if (profiling_trace_info.trace_netoutput == full_scope_name) { + MS_LOG(INFO) << "Profiling Match IterEnd:" << profiling_trace_info.trace_netoutput; ProfilingContent bp_end_profiling_content = {true, kProfilingIterEndLogId, 0}; CNodePtr bp_kernel_ptr = CreateProfilingCNodeWithStream(anf_node, bp_end_profiling_content, graph_ptr); kernel_list->emplace_back(bp_kernel_ptr); diff --git a/mindspore/ccsrc/device/ascend/profiling/profiling_utils.h b/mindspore/ccsrc/device/ascend/profiling/profiling_utils.h index c59e856249..f9f08c9d3f 100644 --- a/mindspore/ccsrc/device/ascend/profiling/profiling_utils.h +++ b/mindspore/ccsrc/device/ascend/profiling/profiling_utils.h @@ -43,7 +43,7 @@ struct ProfilingTraceInfo { // 3. insert profiling_trace_bp_end. // 4. insert profiling_trace_net_output if profiling_trace_bp_end is not empty. - bool IsValid() const { return !(trace_begin.empty() || trace_bp_end.empty() || trace_netoutput.empty()); } + bool IsValid() const { return !(trace_begin.empty() || trace_netoutput.empty()); } }; struct ProfilingContent { @@ -109,8 +109,10 @@ class ProfilingUtils { static CNodePtr CreateProfilingCNodeWithStream(const AnfNodePtr &anf_node, const ProfilingContent &profiling_content, NotNull graph_ptr); static std::string GetTraceBegin(const std::vector &cnode_exec_order); - static std::string GetTraceBpEnd(); + static std::string GetTraceBpEnd(const std::vector &cnode_exec_order); static std::string GetTraceNetoutput(const std::vector &cnode_exec_order); + static void GetTraceHccl(const std::vector &cnode_exec_order, + NotNull profiling_trace); // graph id --> (kernel name list) static std::unordered_map> graph_kernel_name_; From 18d8a1d2d3711e7cb4c0012f240e6f51a358c780 Mon Sep 17 00:00:00 2001 From: chenjianping Date: Wed, 22 Apr 2020 02:24:48 +0000 Subject: [PATCH 079/142] support default log level on windows 10 --- build.bat | 4 +++- mindspore/ccsrc/device/cpu/cpu_kernel_factory.cc | 2 ++ mindspore/ccsrc/utils/log_adapter.cc | 6 ++++++ 3 files changed, 11 insertions(+), 1 deletion(-) diff --git a/build.bat b/build.bat index 76d7f19262..ddb2e8affe 100644 --- a/build.bat +++ b/build.bat @@ -31,6 +31,7 @@ cd %CD%/mindspore cmake -DCMAKE_BUILD_TYPE=Release -DENABLE_CPU=ON -DENABLE_MINDDATA=ON -DUSE_GLOG=ON -G "CodeBlocks - MinGW Makefiles" ../.. IF NOT %errorlevel% == 0 ( + echo "cmake fail." goto run_fail ) @@ -40,6 +41,7 @@ IF "%1%" == "" ( cmake --build . --target package -- -j%1% ) IF NOT %errorlevel% == 0 ( + echo "build fail." goto run_fail ) @@ -49,6 +51,6 @@ goto run_eof :run_fail cd %BASEPATH% - echo "build fail." + set errorlevel=1 :run_eof diff --git a/mindspore/ccsrc/device/cpu/cpu_kernel_factory.cc b/mindspore/ccsrc/device/cpu/cpu_kernel_factory.cc index 5aba329e12..77a3345344 100644 --- a/mindspore/ccsrc/device/cpu/cpu_kernel_factory.cc +++ b/mindspore/ccsrc/device/cpu/cpu_kernel_factory.cc @@ -31,7 +31,9 @@ CPUKernelFactory &CPUKernelFactory::Get() { void CPUKernelFactory::Register(const std::string &kernel_name, CPUKernelCreator &&kernel_creator) { if (kernel_creators_.find(kernel_name) == kernel_creators_.end()) { (void)kernel_creators_.emplace(kernel_name, kernel_creator); +#if !defined(_WIN32) && !defined(_WIN64) MS_LOG(DEBUG) << "CPUKernelFactory register operator: " << kernel_name; +#endif } } diff --git a/mindspore/ccsrc/utils/log_adapter.cc b/mindspore/ccsrc/utils/log_adapter.cc index 4c197a0bdf..0cd9b64a9b 100644 --- a/mindspore/ccsrc/utils/log_adapter.cc +++ b/mindspore/ccsrc/utils/log_adapter.cc @@ -229,13 +229,19 @@ static void InitMsLogLevel() { extern "C" { // shared lib init hook +#if defined(_WIN32) || defined(_WIN64) +__attribute__((constructor)) void mindspore_log_init(void) { +#else void mindspore_log_init(void) { +#endif #ifdef USE_GLOG // do not use glog predefined log prefix FLAGS_log_prefix = false; static bool is_glog_initialzed = false; if (!is_glog_initialzed) { +#if !defined(_WIN32) && !defined(_WIN64) google::InitGoogleLogging("mindspore"); +#endif is_glog_initialzed = true; } // set default log level to WARNING From 2ed4ad0f2aac1b29cae2e1ad35b119da39896ddc Mon Sep 17 00:00:00 2001 From: huanghui Date: Wed, 22 Apr 2020 09:52:45 +0800 Subject: [PATCH 080/142] optimize_dependece pass enhance --- .../pre_activate/pass/optimize_dependence.cc | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/mindspore/ccsrc/pre_activate/pass/optimize_dependence.cc b/mindspore/ccsrc/pre_activate/pass/optimize_dependence.cc index db32354abf..86a90a4dfe 100644 --- a/mindspore/ccsrc/pre_activate/pass/optimize_dependence.cc +++ b/mindspore/ccsrc/pre_activate/pass/optimize_dependence.cc @@ -28,8 +28,7 @@ namespace mindspore { namespace opt { constexpr auto kSingleInputIndex = 1; namespace { -AnfNodePtr GetReplaceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(func_graph); +AnfNodePtr GetReplaceNode(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); if (!node->isa()) { return nullptr; @@ -41,15 +40,6 @@ AnfNodePtr GetReplaceNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node if (op_name != kTransDataOpName && op_name != prim::kPrimCast->name()) { return nullptr; } - auto manager = func_graph->manager(); - MS_EXCEPTION_IF_NULL(manager); - // Check whether the node has only one output node. - if (manager->node_users().find(cnode) == manager->node_users().end()) { - MS_LOG(EXCEPTION) << "The node should be used by at least another node's input"; - } - if (manager->node_users()[cnode].size() > 1) { - return nullptr; - } CheckCNodeInputSize(cnode, kSingleInputIndex + 1); return cnode->input(kSingleInputIndex); } @@ -63,7 +53,7 @@ bool ReplaceMakeTuple(const FuncGraphPtr &func_graph, const CNodePtr &cnode) { std::vector new_make_tuple_inputs; bool need_update = false; for (const auto &input : cnode->inputs()) { - AnfNodePtr replace_input = GetReplaceNode(func_graph, input); + AnfNodePtr replace_input = GetReplaceNode(input); // If replace input is not null, it will be the input of the TransData or Cast. if (replace_input == nullptr) { new_make_tuple_inputs.push_back(input); @@ -119,7 +109,7 @@ const AnfNodePtr OptimizeDependence::Process(const FuncGraphPtr &func_graph, con if (ReplaceMakeTuple(func_graph, replacing_cnode)) { return nullptr; } - AnfNodePtr replace_node = GetReplaceNode(func_graph, replacing_cnode); + AnfNodePtr replace_node = GetReplaceNode(replacing_cnode); if (replace_node == nullptr) { MS_LOG(DEBUG) << "Can not find the TransData or Cast with single output node. Depend node: " << node->DebugString(); return nullptr; From 92ab989a85132da12d2dce6c0dbd102b825de3f7 Mon Sep 17 00:00:00 2001 From: zhangz0911gm Date: Tue, 14 Apr 2020 21:58:04 -0400 Subject: [PATCH 081/142] Getting Some simple issues fixed --- mindspore/ops/operations/math_ops.py | 4 ++-- mindspore/ops/operations/other_ops.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py index 78d813b9cc..33351a3ca1 100644 --- a/mindspore/ops/operations/math_ops.py +++ b/mindspore/ops/operations/math_ops.py @@ -773,8 +773,8 @@ class Mul(_MathBinaryOp): Tensor, the shape is same as the shape after broadcasting, and the data type is same as 'input_x'. Examples: - >>> input_x = Tensor(np.array([1, 2, 3]), mindspore.int32) - >>> input_y = Tensor(np.array([4, 5, 6]), mindspore.int32) + >>> input_x = Tensor(np.array([1.0, 2.0, 3.0]), mindspore.float32) + >>> input_y = Tensor(np.array([4.0, 5.0, 6.0]), mindspore.float32) >>> mul = P.Mul() >>> mul(input_x, input_y) [4, 10, 18] diff --git a/mindspore/ops/operations/other_ops.py b/mindspore/ops/operations/other_ops.py index 2ece6b7088..003395e9d9 100644 --- a/mindspore/ops/operations/other_ops.py +++ b/mindspore/ops/operations/other_ops.py @@ -209,8 +209,8 @@ class IOU(PrimitiveWithInfer): Examples: >>> iou = P.IOU() - >>> anchor_boxes = Tensor(np.random.randint(1,5, [10, 4])) - >>> gt_boxes = Tensor(np.random.randint(1,5, [3, 4])) + >>> anchor_boxes = Tensor(np.random.randint(1.0, 5.0, [3, 4]), mindspore.float32) + >>> gt_boxes = Tensor(np.random.randint(1.0, 5.0, [3, 4]), mindspore.float32) >>> iou(anchor_boxes, gt_boxes) """ From f24065806d16574392c290dc62a05b8234883719 Mon Sep 17 00:00:00 2001 From: zhaojichen Date: Tue, 21 Apr 2020 22:52:36 -0400 Subject: [PATCH 082/142] add AvgPooling layer --- mindspore/nn/layer/pooling.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/mindspore/nn/layer/pooling.py b/mindspore/nn/layer/pooling.py index 28826c88bb..c7841b804b 100644 --- a/mindspore/nn/layer/pooling.py +++ b/mindspore/nn/layer/pooling.py @@ -255,9 +255,26 @@ class AvgPool1d(_PoolNd): Examples: >>> pool = nn.AvgPool1d(kernel_size=3, strides=1) >>> x = Tensor(np.random.randint(0, 10, [1, 2, 4, 4]), mindspore.float32) + [[[[8. 8. 7. 4.] + [8. 4. 0. 9.] + [6. 4. 6. 1.] + [6. 8. 8. 5.] + [[4. 8. 5. 4.] + [8. 4. 7. 5.] + [3. 5. 3. 9.] + [7. 5. 4. 7.]]]] >>> output = pool(x) >>> output.shape() + (1, 2, 4, 2) >>> output + [[[[7.6640625 6.3320312] + [4. 4.3320312] + [5.3320312 3.6660156] + [7.3320312 7 ] + [[5.6640625 5.6640625] + [6.3320312 5.3320312] + [3.6660156 5.6640625] + [5.3320312 5.3320312]]]] """ def __init__(self, From beca38a3993baf0c137500d7f261c610f22f8b0a Mon Sep 17 00:00:00 2001 From: zhaojichen Date: Tue, 21 Apr 2020 22:57:18 -0400 Subject: [PATCH 083/142] add AvgPooling layer --- mindspore/nn/layer/pooling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mindspore/nn/layer/pooling.py b/mindspore/nn/layer/pooling.py index c7841b804b..9880763041 100644 --- a/mindspore/nn/layer/pooling.py +++ b/mindspore/nn/layer/pooling.py @@ -258,7 +258,7 @@ class AvgPool1d(_PoolNd): [[[[8. 8. 7. 4.] [8. 4. 0. 9.] [6. 4. 6. 1.] - [6. 8. 8. 5.] + [6. 8. 8. 5.]] [[4. 8. 5. 4.] [8. 4. 7. 5.] [3. 5. 3. 9.] @@ -270,7 +270,7 @@ class AvgPool1d(_PoolNd): [[[[7.6640625 6.3320312] [4. 4.3320312] [5.3320312 3.6660156] - [7.3320312 7 ] + [7.3320312 7 ]] [[5.6640625 5.6640625] [6.3320312 5.3320312] [3.6660156 5.6640625] From 5e7e34f54af3c90815109d043f65a8db302130ee Mon Sep 17 00:00:00 2001 From: jjfeing Date: Wed, 22 Apr 2020 11:02:39 +0800 Subject: [PATCH 084/142] support buffer fusion --- .../ccsrc/kernel/tbe/tbe_kernel_build.cc | 112 ++++++++++++------ mindspore/ccsrc/kernel/tbe/tbe_kernel_build.h | 13 +- 2 files changed, 87 insertions(+), 38 deletions(-) diff --git a/mindspore/ccsrc/kernel/tbe/tbe_kernel_build.cc b/mindspore/ccsrc/kernel/tbe/tbe_kernel_build.cc index 496f99df1c..7a521eb1cd 100644 --- a/mindspore/ccsrc/kernel/tbe/tbe_kernel_build.cc +++ b/mindspore/ccsrc/kernel/tbe/tbe_kernel_build.cc @@ -513,36 +513,36 @@ bool TbeKernelBuild::GenFusionScopeJson(const vector &inp return true; } -void TbeKernelBuild::GenDescJson(const shared_ptr &anf_node, size_t out_idx, - nlohmann::json *output_desc) { +void TbeKernelBuild::GenDescJson(const std::shared_ptr &anf_node, size_t node_out_idx, + size_t desc_output_idx, nlohmann::json *output_desc) { std::string output_desc_name = anf_node->fullname_with_scope(); - if (out_idx > 0) { - output_desc_name = output_desc_name + "_" + std::to_string(out_idx); + if (node_out_idx > 0) { + output_desc_name = output_desc_name + "_" + std::to_string(node_out_idx); } (*output_desc)["name"] = NormalizeFullScopeName(output_desc_name); - auto type_id = AnfAlgo::GetOutputDeviceDataType(anf_node, out_idx); + auto type_id = AnfAlgo::GetOutputDeviceDataType(anf_node, node_out_idx); (*output_desc)["data_type"] = tbe::TypeIdToString(type_id); - auto ori_shape = AnfAlgo::GetOutputInferShape(anf_node, out_idx); + auto ori_shape = AnfAlgo::GetOutputInferShape(anf_node, node_out_idx); if (ori_shape.empty()) { ori_shape.emplace_back(1); } (*output_desc)["ori_shape"] = ori_shape; - auto shape = AnfAlgo::GetOutputDeviceShape(anf_node, out_idx); + auto shape = AnfAlgo::GetOutputDeviceShape(anf_node, node_out_idx); if (shape.empty()) { shape.emplace_back(1); } (*output_desc)["shape"] = shape; - auto format = AnfAlgo::GetOutputFormat(anf_node, out_idx); + auto format = AnfAlgo::GetOutputFormat(anf_node, node_out_idx); if (format == kOpFormat_DEFAULT) { if (ori_shape.size() == 4) { format = kOpFormat_NCHW; } else { - format = "ND"; + format = kOpFormat_ND; } } (*output_desc)["format"] = format; (*output_desc)["ori_format"] = kOpFormat_NCHW; - (*output_desc)["output_index"] = out_idx; + (*output_desc)["output_index"] = desc_output_idx; } void TbeKernelBuild::GenReusedOutputDesc(const shared_ptr &anf_node, size_t index, @@ -605,7 +605,7 @@ bool TbeKernelBuild::GenFusionDataInputJson(const shared_ptr MS_LOG(INFO) << "real name " << real_node->fullname_with_scope() << " index:" << real_idx; // "output_desc" nlohmann::json output_desc; - GenDescJson(real_node, real_idx, &output_desc); + GenDescJson(real_node, real_idx, real_idx, &output_desc); output_desc_list.push_back(output_desc); (*data_str)["name"] = NormalizeFullScopeName(real_node->fullname_with_scope()); } @@ -653,9 +653,9 @@ size_t TbeKernelBuild::GetOptionalInput(const mindspore::CNodePtr &cnode, bool i return (op_info->inputs_ptr().size() + 1 - cnode->inputs().size()); } -bool TbeKernelBuild::GenFusionComputeInputeJson(const mindspore::CNodePtr &cnode, - std::vector>::iterator *layer_iter, - std::vector *input_desc_list, size_t *index) { +bool TbeKernelBuild::GenFusionComputeInputJson(const mindspore::CNodePtr &cnode, + std::vector>::iterator *layer_iter, + std::vector *input_desc_list, size_t *index) { MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(input_desc_list); bool is_dynamic_input = IsDynamicInput(cnode); @@ -666,7 +666,7 @@ bool TbeKernelBuild::GenFusionComputeInputeJson(const mindspore::CNodePtr &cnode size_t real_idx = kernel_idx.second; MS_LOG(INFO) << "real name" << real_node->fullname_with_scope() << "index:" << real_idx; nlohmann::json input_desc; - GenDescJson(real_node, real_idx, &input_desc); + GenDescJson(real_node, real_idx, real_idx, &input_desc); if (is_dynamic_input) { MS_LOG(INFO) << "node has dynamic input."; input_desc["dyn_index"] = (i - 1); @@ -687,6 +687,66 @@ bool TbeKernelBuild::GenFusionComputeInputeJson(const mindspore::CNodePtr &cnode return true; } +std::vector TbeKernelBuild::GetDescOutputIndex(const std::vector &output_used_nums) { + std::vector desc_output_index = {}; + bool find_reused = false; + size_t reused_num = 0; + for (size_t idx = 0; idx < output_used_nums.size(); ++idx) { + auto output_use_num_item = output_used_nums[idx]; + MS_LOG(INFO) << "output used num[" << idx << "] = " << output_use_num_item; + if (output_use_num_item == 1 || output_use_num_item == 0) { + desc_output_index.emplace_back(idx); + } else { + if (!find_reused) { + desc_output_index.emplace_back(idx); + } else { + desc_output_index.emplace_back(output_used_nums[idx - 1]); + } + reused_num += (output_use_num_item - 1); + find_reused = true; + } + } + auto pad_value = output_used_nums.size() == 1 ? 0 : desc_output_index[desc_output_index.size() - 1] + 1; + for (size_t i = 0; i < reused_num; ++i) { + desc_output_index.emplace_back(pad_value); + } + return desc_output_index; +} + +bool TbeKernelBuild::GenFusionComputeOutputJson(const mindspore::CNodePtr &cnode, + std::vector *output_desc_list) { + auto output_size = AnfAlgo::GetOutputTensorNum(cnode); + if (AnfAlgo::HasNodeAttr(kAttrOutputUsedNum, cnode)) { + auto output_used_nums = AnfAlgo::GetNodeAttr>(cnode, kAttrOutputUsedNum); + MS_LOG(INFO) << "This node's output has been reused, node name: " << cnode->fullname_with_scope(); + if (output_used_nums.size() != output_size) { + MS_LOG(INFO) << "Fusion error: output tenor num(" << output_size << ")" + << " is not match output used num(" << output_used_nums.size() << ")"; + return false; + } + auto desc_output_index = GetDescOutputIndex(output_used_nums); + for (size_t i = 0; i < output_size; ++i) { + MS_LOG(INFO) << "Fusion index: " << i << ", desc_output_index: " << desc_output_index[i]; + nlohmann::json output_desc; + GenDescJson(cnode, i, desc_output_index[i], &output_desc); + output_desc_list->emplace_back(output_desc); + } + for (size_t j = output_size; j < desc_output_index.size(); ++j) { + MS_LOG(INFO) << "Fusion index: " << j << ", desc_output_index: " << desc_output_index[j]; + nlohmann::json output_desc; + GenReusedOutputDesc(cnode, j, desc_output_index[j], &output_desc); + output_desc_list->emplace_back(output_desc); + } + } else { + for (size_t i = 0; i < output_size; ++i) { + nlohmann::json output_desc; + GenDescJson(cnode, i, i, &output_desc); + output_desc_list->push_back(output_desc); + } + } + return true; +} + bool TbeKernelBuild::GenFusionComputeJson(const mindspore::AnfNodePtr &compute_node, std::vector>::iterator *layer_iter, nlohmann::json *compute_op_str, std::string *fusion_kernel_name, @@ -696,28 +756,14 @@ bool TbeKernelBuild::GenFusionComputeJson(const mindspore::AnfNodePtr &compute_n MS_EXCEPTION_IF_NULL(cnode); // gen input desc std::vector input_desc_list; - (void)GenFusionComputeInputeJson(cnode, layer_iter, &input_desc_list, index); + (void)GenFusionComputeInputJson(cnode, layer_iter, &input_desc_list, index); (*compute_op_str)["input_desc"] = input_desc_list; // gen output desc std::vector output_desc_list; - auto output_size = AnfAlgo::GetOutputTensorNum(cnode); - for (size_t i = 0; i < output_size; ++i) { - nlohmann::json output_desc; - GenDescJson(cnode, i, &output_desc); - output_desc_list.push_back(output_desc); - } - - if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimConv2D->name()) { - if (AnfAlgo::HasNodeAttr(kAttrOutputUsedNum, compute_node)) { - auto output_used_num = AnfAlgo::GetNodeAttr(compute_node, kAttrOutputUsedNum); - for (size_t i = output_size; i < output_used_num; ++i) { - nlohmann::json output_desc; - GenReusedOutputDesc(cnode, i, 0, &output_desc); - output_desc_list.push_back(output_desc); - } - } + if (!GenFusionComputeOutputJson(cnode, &output_desc_list)) { + MS_LOG(INFO) << "Fusion Error: gen fusion output desc faild, node full name: " << cnode->fullname_with_scope(); + return false; } - (*compute_op_str)["output_desc"] = output_desc_list; // gen others auto type = AnfAlgo::GetCNodeName(cnode); diff --git a/mindspore/ccsrc/kernel/tbe/tbe_kernel_build.h b/mindspore/ccsrc/kernel/tbe/tbe_kernel_build.h index de5ed84e41..1a3eee7fd9 100644 --- a/mindspore/ccsrc/kernel/tbe/tbe_kernel_build.h +++ b/mindspore/ccsrc/kernel/tbe/tbe_kernel_build.h @@ -53,11 +53,14 @@ class TbeKernelBuild { static bool GenFusionComputeJson(const mindspore::AnfNodePtr &compute_node, std::vector>::iterator *layer_iter, nlohmann::json *compute_op_str, std::string *fusion_kernel_name, size_t *index); - static bool GenFusionComputeInputeJson(const mindspore::CNodePtr &cnode, - std::vector>::iterator *layer_iter, - std::vector *input_desc_list, size_t *index); - static void GenDescJson(const std::shared_ptr &anf_node, size_t out_idx, - nlohmann::json *output_desc); + static bool GenFusionComputeInputJson(const mindspore::CNodePtr &cnode, + std::vector>::iterator *layer_iter, + std::vector *input_desc_list, size_t *index); + static std::vector GetDescOutputIndex(const std::vector &output_used_nums); + static bool GenFusionComputeOutputJson(const mindspore::CNodePtr &cnode, + std::vector *output_desc_list); + static void GenDescJson(const std::shared_ptr &anf_node, size_t node_out_idx, + size_t desc_output_idx, nlohmann::json *output_desc); static void GenReusedOutputDesc(const std::shared_ptr &anf_node, size_t index, size_t output_index, nlohmann::json *output_desc); static size_t GetIOSizeImpl(const nlohmann::json &desc); From 5105e951600527fb045d1e13681846ba9dcb0748 Mon Sep 17 00:00:00 2001 From: wukesong Date: Tue, 21 Apr 2020 17:36:11 +0800 Subject: [PATCH 085/142] add lenet&alexnet readme in example --- example/alexnet_cifar10/README.md | 58 ++++++++++++++++++++++++++++ example/lenet_mnist/README.md | 63 +++++++++++++++++++++++++++++++ 2 files changed, 121 insertions(+) create mode 100644 example/alexnet_cifar10/README.md create mode 100644 example/lenet_mnist/README.md diff --git a/example/alexnet_cifar10/README.md b/example/alexnet_cifar10/README.md new file mode 100644 index 0000000000..0efd3ca1bf --- /dev/null +++ b/example/alexnet_cifar10/README.md @@ -0,0 +1,58 @@ +# AlexNet Example + +## Description + +Training AlexNet with CIFAR-10 dataset in MindSpore. + +This is the simple tutorial for training AlexNet in MindSpore. + +## Requirements + +- Install [MindSpore](https://www.mindspore.cn/install/en). + +- Download the CIFAR-10 dataset at . The directory structure is as follows: + +``` +├─cifar-10-batches-bin +│ +└─cifar-10-verify-bin +``` + +## Running the example + +```python +# train AlexNet, hyperparameter setting in config.py +python train.py --data_path cifar-10-batches-bin +``` + +You can get loss with each step similar to this: + +```bash +epoch: 1 step: 1, loss is 2.2791853 +... +epoch: 1 step: 1536, loss is 1.9366643 +epoch: 1 step: 1537, loss is 1.6983616 +epoch: 1 step: 1538, loss is 1.0221305 +... +``` + +Then, test AlexNet according to network model +```python +# test AlexNet, 1 epoch training accuracy is up to 51.1%; 10 epoch training accuracy is up to 81.2% +python eval.py --data_path cifar-10-verify-bin --mode test --ckpt_path checkpoint_alexnet-1_1562.ckpt +``` + +## Note +There are some optional arguments: + +```bash +-h, --help show this help message and exit +--device_target {Ascend,GPU} + device where the code will be implemented (default: Ascend) +--data_path DATA_PATH + path where the dataset is saved +--dataset_sink_mode DATASET_SINK_MODE + dataset_sink_mode is False or True +``` + +You can run ```python train.py -h``` or ```python eval.py -h``` to get more information. diff --git a/example/lenet_mnist/README.md b/example/lenet_mnist/README.md new file mode 100644 index 0000000000..fea92883c6 --- /dev/null +++ b/example/lenet_mnist/README.md @@ -0,0 +1,63 @@ +# LeNet Example + +## Description + +Training LeNet with MNIST dataset in MindSpore. + +This is the simple and basic tutorial for constructing a network in MindSpore. + +## Requirements + +- Install [MindSpore](https://www.mindspore.cn/install/en). + +- Download the MNIST dataset at . The directory structure is as follows: + +``` +└─MNIST_Data + ├─test + │ t10k-images.idx3-ubyte + │ t10k-labels.idx1-ubyte + │ + └─train + train-images.idx3-ubyte + train-labels.idx1-ubyte +``` + +## Running the example + +```python +# train LeNet, hyperparameter setting in config.py +python train.py --data_path MNIST_Data +``` + +You can get loss with each step similar to this: + +```bash +epoch: 1 step: 1, loss is 2.3040335 +... +epoch: 1 step: 1739, loss is 0.06952668 +epoch: 1 step: 1740, loss is 0.05038793 +epoch: 1 step: 1741, loss is 0.05018193 +... +``` + +Then, test LeNet according to network model +```python +# test LeNet, after 1 epoch training, the accuracy is up to 96.5% +python eval.py --data_path MNIST_Data --mode test --ckpt_path checkpoint_lenet-1_1875.ckpt +``` + +## Note +There are some optional arguments: + +```bash +-h, --help show this help message and exit +--device_target {Ascend,GPU,CPU} + device where the code will be implemented (default: Ascend) +--data_path DATA_PATH + path where the dataset is saved +--dataset_sink_mode DATASET_SINK_MODE + dataset_sink_mode is False or True +``` + +You can run ```python train.py -h``` or ```python eval.py -h``` to get more information. From b5b506c4c3d5005935bc8723da68b49729dd492f Mon Sep 17 00:00:00 2001 From: xulei2020 <“xulei83@huawei.com”> Date: Wed, 22 Apr 2020 11:10:15 +0800 Subject: [PATCH 086/142] add code --- mindspore/ccsrc/dataset/engine/datasetops/filter_op.cc | 4 ++-- mindspore/ccsrc/dataset/engine/datasetops/filter_op.h | 7 +++++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/mindspore/ccsrc/dataset/engine/datasetops/filter_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/filter_op.cc index e6662dea0f..ce312ce3d9 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/filter_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/filter_op.cc @@ -75,7 +75,7 @@ Status FilterOp::EoeReceived(int32_t) { return Status::OK(); } // Validating if each of the input_columns exists in the DataBuffer. Status FilterOp::ValidateInColumns(const std::unordered_map &col_name_id_map, - std::vector *input_columns) { + const std::vector *input_columns) { for (const auto &inCol : *input_columns) { bool found = col_name_id_map.find(inCol) != col_name_id_map.end() ? true : false; if (!found) { @@ -202,7 +202,7 @@ Status FilterOp::Collector() { } // Private function for checking the column legality. -Status FilterOp::CheckColumns(const DataBuffer *in_buf, std::vector *input_columns) { +Status FilterOp::CheckColumns(const DataBuffer *in_buf, const std::vector *input_columns) { int32_t num_rows = in_buf->NumRows(); int32_t num_cols = in_buf->NumCols(); if (num_rows == 0 || num_cols == 0) { diff --git a/mindspore/ccsrc/dataset/engine/datasetops/filter_op.h b/mindspore/ccsrc/dataset/engine/datasetops/filter_op.h index b182bf8ce6..92312e0843 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/filter_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/filter_op.h @@ -100,6 +100,9 @@ class FilterOp : public ParallelOp { FilterOp(const std::vector &in_col_names, int32_t num_workers, int32_t op_queue_size, py::function predicate_func); + // Destructor + ~FilterOp() = default; + // Class functor operator () override. // All dataset ops operate by launching a thread (see ExecutionTree),This class functor will // provide the master loop that drives the logic for performing the work. @@ -163,14 +166,14 @@ class FilterOp : public ParallelOp { // @param input_columns The vector of input column names used in the current thread. // @return Status The error code return. Status ValidateInColumns(const std::unordered_map &col_name_id_map, - std::vector *input_columns); + const std::vector *input_columns); // Private function for checking the column legality // @param in_buf A raw pointer to the DataBuffer. A raw pointer is fine because this function does not manage memory // and is not shared with other threads. // @param[out] to_process_indices Indices of columns that will feed to predicate. // @param input_columns The vector of input column names used in the current thread. - Status CheckColumns(const DataBuffer *in_buf, std::vector *input_columns); + Status CheckColumns(const DataBuffer *in_buf, const std::vector *input_columns); }; } // namespace dataset From ac2d5df2a1a69ac55c4309a3138c519b49ac8d79 Mon Sep 17 00:00:00 2001 From: liubuyu Date: Mon, 20 Apr 2020 16:44:47 +0800 Subject: [PATCH 087/142] add dtype trans template --- mindspore/ccsrc/common/trans.cc | 44 ++++++++++++++----- mindspore/ccsrc/common/trans.h | 9 ++-- .../device/ascend/ascend_device_address.cc | 32 +++++++------- 3 files changed, 56 insertions(+), 29 deletions(-) diff --git a/mindspore/ccsrc/common/trans.cc b/mindspore/ccsrc/common/trans.cc index b4e02c8fe6..1174be1f48 100644 --- a/mindspore/ccsrc/common/trans.cc +++ b/mindspore/ccsrc/common/trans.cc @@ -103,17 +103,39 @@ const std::map, DataTypeTransMode> mode_map{ template void TransDataSrc2Dst(const TypeIdArgs &args, void *dst, const size_t data_size) { + auto src_id = TypeIdSize(args.src_type); + auto dst_id = TypeIdSize(args.dst_type); + if (args.src_size / src_id != args.src_shape_size || args.dst_size / dst_id != args.dst_shape_size) { + MS_LOG(EXCEPTION) << "Invalid src or dst data size."; + } for (size_t idx = 0; idx != data_size; idx++) { SrcT src_data = static_cast(args.data)[idx]; static_cast(dst)[idx] = static_cast(src_data); } } +template +void TransDataSrc2Fp16(const TypeIdArgs &args, void *dst, const size_t data_size) { + auto src_id = TypeIdSize(args.src_type); + auto dst_id = TypeIdSize(args.dst_type); + if (args.src_size / src_id != args.src_shape_size || args.dst_size / dst_id != args.dst_shape_size) { + MS_LOG(EXCEPTION) << "Invalid src or dst data size."; + } + auto src_data = static_cast(args.data); + auto half_data = static_cast(dst); + for (size_t i = 0; i < data_size; i++) { + half_data[i] = Eigen::half(src_data[i]); + } +} + bool CastKernel(const TypeIdArgs &args, void *dst, const size_t data_size, const DataTypeTransMode mode) { switch (mode) { case FROM_FLOAT_TO_FLOAT16: device::FloatToHalf(dst, args.data, data_size); break; + case FROM_INT32_TO_FLOAT16: + TransDataSrc2Fp16(args, dst, data_size); + break; case FROM_FLOAT16_TO_FLOAT: device::HalfToFloat(dst, args.data, data_size); break; @@ -372,27 +394,27 @@ bool CheckArgs(const FormatArgs &args, size_t *size, size_t *total_size) { } bool TransDataType(const TypeIdArgs &args, void *result) { - MS_LOG(DEBUG) << "Begin trans datatype from " << TypeIdLabel(args.host_data_type) << " to " - << TypeIdLabel(args.device_data_type); + MS_LOG(DEBUG) << "Begin trans datatype from " << TypeIdLabel(args.src_type) << " to " << TypeIdLabel(args.dst_type); MS_EXCEPTION_IF_NULL(result); - std::pair type_info(args.host_data_type, args.device_data_type); + std::pair type_info(args.src_type, args.dst_type); auto iter = mode_map.find(type_info); if (iter == mode_map.end()) { - MS_LOG(ERROR) << "Unsupported datatype trans. src_type :" << TypeIdLabel(args.host_data_type) - << ", dst_type:" << TypeIdLabel(args.device_data_type); + MS_LOG(ERROR) << "Unsupported datatype trans. src_type :" << TypeIdLabel(args.src_type) + << ", dst_type:" << TypeIdLabel(args.dst_type); return false; } auto trans_mode = iter->second; - auto type_size = TypeIdSize(args.device_data_type); - if (type_size < 1) { - MS_LOG(ERROR) << "Invalid host data type."; + auto src_id = TypeIdSize(args.src_type); + auto dst_id = TypeIdSize(args.dst_type); + if (src_id < 1 || dst_id < 1) { + MS_LOG(ERROR) << "Invalid src or dst data type."; return false; } - if (args.host_shape_size < 1) { - MS_LOG(ERROR) << "Invalid host data size."; + if (args.src_size / src_id != args.src_shape_size || args.dst_size / dst_id != args.dst_shape_size) { + MS_LOG(ERROR) << "Invalid src or dst data size."; return false; } - if (!CastKernel(args, result, args.host_shape_size, trans_mode)) { + if (!CastKernel(args, result, args.dst_shape_size, trans_mode)) { MS_LOG(ERROR) << "Failed to trans datatype.."; return false; } diff --git a/mindspore/ccsrc/common/trans.h b/mindspore/ccsrc/common/trans.h index 054fa89a06..e6e81ed359 100644 --- a/mindspore/ccsrc/common/trans.h +++ b/mindspore/ccsrc/common/trans.h @@ -31,9 +31,12 @@ namespace mindspore { namespace trans { struct TypeIdArgs { const void *data; - size_t host_shape_size; // Multiply each dimension elements. [a, b, c, d] => a*b*c*d - TypeId host_data_type; - TypeId device_data_type; + size_t src_size; + size_t dst_size; + TypeId src_type; + TypeId dst_type; + size_t src_shape_size; + size_t dst_shape_size; }; struct FormatArgs { diff --git a/mindspore/ccsrc/device/ascend/ascend_device_address.cc b/mindspore/ccsrc/device/ascend/ascend_device_address.cc index 79241df612..df49400341 100644 --- a/mindspore/ccsrc/device/ascend/ascend_device_address.cc +++ b/mindspore/ccsrc/device/ascend/ascend_device_address.cc @@ -104,10 +104,10 @@ bool AscendDeviceAddress::SyncDeviceToHost(const std::vector &shape, size_t } else if (type_id_ == kNumberTypeFloat32 && type == kNumberTypeFloat64) { sync_ok = SyncDeviceToHostAndFloatToFloat64(host_ptr, size, ptr_, size_); } else { - auto shape_size = trans::ShapeSize(host_shape); + auto host_size = trans::ShapeSize(host_shape); auto host = std::vector(size_); SyncMemory(host.data(), ptr_, size_, RT_MEMCPY_DEVICE_TO_HOST); - const trans::TypeIdArgs type_args{host.data(), shape_size, type_id_, type}; + const trans::TypeIdArgs type_args{host.data(), size_, size, type_id_, type, host_size, host_size}; sync_ok = trans::TransDataType(type_args, host_ptr); if (!sync_ok) { MS_LOG(ERROR) << "trans data type failed."; @@ -153,14 +153,15 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormat(const std::vector(size_); sync_ok = trans::TransFormatFromDeviceToHost(format_args, host.data()); if (!sync_ok) { - MS_LOG(ERROR) << "trans format failed."; + MS_LOG(ERROR) << "Trans format failed."; return false; } - auto shape_size = trans::ShapeSize(host_shape); - const trans::TypeIdArgs type_args{host.data(), shape_size, type_id_, type}; + auto host_size = trans::ShapeSize(host_shape); + auto device_size = trans::ShapeSize(device_shape); + const trans::TypeIdArgs type_args{host.data(), size_, size, type_id_, type, device_size, host_size}; sync_ok = trans::TransDataType(type_args, host_ptr); if (!sync_ok) { - MS_LOG(ERROR) << "trans format failed."; + MS_LOG(ERROR) << "Trans format failed."; return false; } } else { @@ -168,7 +169,7 @@ bool AscendDeviceAddress::SyncDeviceToHostAndConvertFormat(const std::vector &shape, size_t } else if (type_id_ == kNumberTypeFloat32 && type == kNumberTypeFloat64) { sync_ok = Float64ToFloatAndSyncHostToDevice(ptr_, size_, host_ptr, size); } else { - auto shape_size = trans::ShapeSize(host_shape); - const trans::TypeIdArgs type_args{host_ptr, shape_size, type, type_id_}; + auto host_size = trans::ShapeSize(host_shape); + const trans::TypeIdArgs type_args{host_ptr, size, size_, type, type_id_, host_size, host_size}; auto host_tmp = std::vector(size_); sync_ok = trans::TransDataType(type_args, host_tmp.data()); if (!sync_ok) { - MS_LOG(ERROR) << "trans data type failed."; + MS_LOG(ERROR) << "Trans data type failed."; return false; } SyncMemory(ptr_, host_tmp.data(), size_, RT_MEMCPY_HOST_TO_DEVICE); @@ -234,12 +235,13 @@ bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const std::vector(size_); sync_ok = trans::TransDataType(type_args, host_tmp.data()); if (!sync_ok) { - MS_LOG(ERROR) << "trans datatype failed."; + MS_LOG(ERROR) << "Trans datatype failed."; return false; } const trans::FormatArgs format_args{host_tmp.data(), size_, kOpFormat_NCHW, format_, @@ -247,7 +249,7 @@ bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const std::vector(size_); sync_ok = trans::TransFormat(format_args, dst_tmp.data()); if (!sync_ok) { - MS_LOG(ERROR) << "trans format failed."; + MS_LOG(ERROR) << "Trans format failed."; return false; } SyncMemory(ptr_, dst_tmp.data(), size_, RT_MEMCPY_HOST_TO_DEVICE); @@ -256,7 +258,7 @@ bool AscendDeviceAddress::ConvertFormatAndSyncHostToDevice(const std::vector(size_); sync_ok = trans::TransFormat(format_args, host_tmp.data()); if (!sync_ok) { - MS_LOG(ERROR) << "trans format failed."; + MS_LOG(ERROR) << "Trans format failed."; return false; } SyncMemory(ptr_, host_tmp.data(), size_, RT_MEMCPY_HOST_TO_DEVICE); From 06ca226da4976ba413b1bb09ceb02003bc828d47 Mon Sep 17 00:00:00 2001 From: zhaojichen Date: Tue, 21 Apr 2020 23:50:45 -0400 Subject: [PATCH 088/142] add AvgPooling layer --- mindspore/nn/layer/pooling.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/mindspore/nn/layer/pooling.py b/mindspore/nn/layer/pooling.py index 9880763041..84ec2414a3 100644 --- a/mindspore/nn/layer/pooling.py +++ b/mindspore/nn/layer/pooling.py @@ -255,14 +255,6 @@ class AvgPool1d(_PoolNd): Examples: >>> pool = nn.AvgPool1d(kernel_size=3, strides=1) >>> x = Tensor(np.random.randint(0, 10, [1, 2, 4, 4]), mindspore.float32) - [[[[8. 8. 7. 4.] - [8. 4. 0. 9.] - [6. 4. 6. 1.] - [6. 8. 8. 5.]] - [[4. 8. 5. 4.] - [8. 4. 7. 5.] - [3. 5. 3. 9.] - [7. 5. 4. 7.]]]] >>> output = pool(x) >>> output.shape() (1, 2, 4, 2) From faf95e40ebf6a6c8aca1481844016e7f0934a370 Mon Sep 17 00:00:00 2001 From: zhaojichen Date: Tue, 21 Apr 2020 23:52:06 -0400 Subject: [PATCH 089/142] add AvgPooling layer --- mindspore/nn/layer/pooling.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/mindspore/nn/layer/pooling.py b/mindspore/nn/layer/pooling.py index 84ec2414a3..6cf06de029 100644 --- a/mindspore/nn/layer/pooling.py +++ b/mindspore/nn/layer/pooling.py @@ -258,15 +258,6 @@ class AvgPool1d(_PoolNd): >>> output = pool(x) >>> output.shape() (1, 2, 4, 2) - >>> output - [[[[7.6640625 6.3320312] - [4. 4.3320312] - [5.3320312 3.6660156] - [7.3320312 7 ]] - [[5.6640625 5.6640625] - [6.3320312 5.3320312] - [3.6660156 5.6640625] - [5.3320312 5.3320312]]]] """ def __init__(self, From b8d7cd9775cdf90a55f5fbf216f0f78628ac8e6d Mon Sep 17 00:00:00 2001 From: VectorSL Date: Wed, 22 Apr 2020 12:46:56 +0800 Subject: [PATCH 090/142] gpu change compute capacity strategy --- mindspore/ccsrc/device/gpu/cuda_common.h | 3 ++- mindspore/ccsrc/kernel/gpu/gpu_kernel_factory.cc | 10 +++++++--- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/mindspore/ccsrc/device/gpu/cuda_common.h b/mindspore/ccsrc/device/gpu/cuda_common.h index 5a5b6416ce..b79ba8bc28 100644 --- a/mindspore/ccsrc/device/gpu/cuda_common.h +++ b/mindspore/ccsrc/device/gpu/cuda_common.h @@ -56,7 +56,8 @@ class CudaCommon { #define GET_BLOCKS(total_threads) mindspore::device::gpu::CudaCommon::GetInstance().blocks_num(total_threads) #define GET_THREADS mindspore::device::gpu::CudaCommon::GetInstance().threads_num() #define GET_MAJOR_SM mindspore::device::gpu::CudaCommon::GetInstance().major_sm() -#define MINIUM_SM 7 +#define MINIUM_SM 6 +#define RECOMMEND_SM 7 } // namespace gpu } // namespace device } // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/gpu_kernel_factory.cc b/mindspore/ccsrc/kernel/gpu/gpu_kernel_factory.cc index fba2b24512..e38cc02e23 100644 --- a/mindspore/ccsrc/kernel/gpu/gpu_kernel_factory.cc +++ b/mindspore/ccsrc/kernel/gpu/gpu_kernel_factory.cc @@ -96,9 +96,13 @@ std::pair GpuKernelFactory::GpuKernelAttrCheck(const std::string & bool flag = true; // data type matching check of all input parameters of kernel for (size_t input_index = 0; input_index < kernel_info->GetInputNum(); input_index++) { - if (marjor_sm < MINIUM_SM && kernel_info->GetInputDeviceType(input_index) == kNumberTypeFloat16) { - MS_LOG(EXCEPTION) << "Half precision op can be used on Devices which compute capacity is above " << MINIUM_SM - << ", but your device's compute capacity is " << marjor_sm; + if (marjor_sm < RECOMMEND_SM && kernel_info->GetInputDeviceType(input_index) == kNumberTypeFloat16) { + if (marjor_sm < MINIUM_SM) { + MS_LOG(EXCEPTION) << "Half precision ops can be used on Devices which computing capacity is >= " << MINIUM_SM + << ", but the current device's computing capacity is " << marjor_sm; + } + MS_LOG(WARNING) << "It is recommended to use devices with a computing capacity >= " << RECOMMEND_SM + << ", but the current device's computing capacity is " << marjor_sm; } if (kernel_info->GetInputDeviceType(input_index) != (iter->second)[attr_index].first.GetInputAttr(input_index).first) { From 437bb8c27c79cdb7352f30a967c0f5ee68d46d56 Mon Sep 17 00:00:00 2001 From: buxue Date: Tue, 21 Apr 2020 15:39:59 +0800 Subject: [PATCH 091/142] support ellipsis and bool for tensor slice --- mindspore/ccsrc/ir/dtype.cc | 2 + mindspore/ccsrc/ir/dtype/empty.cc | 1 - mindspore/ccsrc/ir/dtype/empty.h | 14 +++++- mindspore/ccsrc/ir/dtype/type.h | 1 + mindspore/ccsrc/ir/named.cc | 5 +- mindspore/ccsrc/ir/named.h | 12 +++-- .../ccsrc/operator/cc_implementations.cc | 6 +-- .../ccsrc/operator/composite/composite.cc | 50 +++++++++++++++---- .../ccsrc/operator/composite/composite.h | 2 + mindspore/ccsrc/pipeline/parse/parse.cc | 8 ++- mindspore/ccsrc/pipeline/parse/parse.h | 2 + .../static_analysis/abstract_value.cc | 25 ++++++++-- .../pipeline/static_analysis/abstract_value.h | 16 +++++- .../composite/multitype_ops/getitem_impl.py | 19 ++++++- mindspore/ops/functional.py | 1 + tests/ut/python/ops/test_tensor_slice.py | 39 +++++++++++---- .../ut/python/pipeline/parse/test_operator.py | 7 +-- 17 files changed, 170 insertions(+), 40 deletions(-) diff --git a/mindspore/ccsrc/ir/dtype.cc b/mindspore/ccsrc/ir/dtype.cc index a6ef99177c..97291a3dc0 100644 --- a/mindspore/ccsrc/ir/dtype.cc +++ b/mindspore/ccsrc/ir/dtype.cc @@ -495,6 +495,8 @@ TypePtr StringToType(const std::string &type_name) { TypePtr type = nullptr; if (type_name.compare("None") == 0) { type = std::make_shared(); + } else if (type_name.compare("Ellipsis") == 0) { + type = std::make_shared(); } else if (type_name.compare("TypeType") == 0) { type = std::make_shared(); } else if (type_name.compare("SymbolicKeyType") == 0) { diff --git a/mindspore/ccsrc/ir/dtype/empty.cc b/mindspore/ccsrc/ir/dtype/empty.cc index 3d4f74bf31..5cb3a91806 100644 --- a/mindspore/ccsrc/ir/dtype/empty.cc +++ b/mindspore/ccsrc/ir/dtype/empty.cc @@ -18,6 +18,5 @@ namespace mindspore { const TypePtr kTypeNone = std::make_shared(); -const TypePtr kTypeAnything = std::make_shared(); const TypePtr kAnyType = std::make_shared(); } // namespace mindspore diff --git a/mindspore/ccsrc/ir/dtype/empty.h b/mindspore/ccsrc/ir/dtype/empty.h index a13dc084ca..76cf8ea0eb 100644 --- a/mindspore/ccsrc/ir/dtype/empty.h +++ b/mindspore/ccsrc/ir/dtype/empty.h @@ -71,8 +71,20 @@ class TypeNull : public Type { }; using TypeNullPtr = std::shared_ptr; +class Ellipsis : public Type { + public: + Ellipsis() : Type(kMetaTypeEllipsis) {} + ~Ellipsis() override {} + MS_DECLARE_PARENT(Ellipsis, Type) + + TypeId generic_type_id() const override { return kMetaTypeEllipsis; } + TypePtr DeepCopy() const override { return std::make_shared(); } + std::string ToReprString() const override { return "Ellipsis"; } + std::string DumpText() const override { return "Ellipsis"; } +}; +using EllipsisPtr = std::shared_ptr; + extern const TypePtr kTypeNone; -extern const TypePtr kTypeAnything; extern const TypePtr kAnyType; } // namespace mindspore diff --git a/mindspore/ccsrc/ir/dtype/type.h b/mindspore/ccsrc/ir/dtype/type.h index 0528bccf03..1c67b6a855 100644 --- a/mindspore/ccsrc/ir/dtype/type.h +++ b/mindspore/ccsrc/ir/dtype/type.h @@ -49,6 +49,7 @@ enum TypeId : int { kMetaTypeExternal, kMetaTypeNone, kMetaTypeNull, + kMetaTypeEllipsis, kMetaTypeEnd, // // Object types diff --git a/mindspore/ccsrc/ir/named.cc b/mindspore/ccsrc/ir/named.cc index 67e11c64d3..0a679e6011 100644 --- a/mindspore/ccsrc/ir/named.cc +++ b/mindspore/ccsrc/ir/named.cc @@ -31,5 +31,8 @@ abstract::AbstractBasePtr None::ToAbstract() { return std::make_shared(); abstract::AbstractBasePtr NullObj::ToAbstract() { return std::make_shared(); } -const NamedPtr kNullObj = std::make_shared(); +const NamedPtr kNull = std::make_shared(); + +abstract::AbstractBasePtr EllipsisObj::ToAbstract() { return std::make_shared(); } +const NamedPtr kEllipsis = std::make_shared(); } // namespace mindspore diff --git a/mindspore/ccsrc/ir/named.h b/mindspore/ccsrc/ir/named.h index 76136fb298..2d679c58b1 100644 --- a/mindspore/ccsrc/ir/named.h +++ b/mindspore/ccsrc/ir/named.h @@ -61,7 +61,6 @@ class Named : public Value { std::string name_; std::size_t hash_id_; }; - using NamedPtr = std::shared_ptr; class None : public Named { @@ -71,7 +70,6 @@ class None : public Named { MS_DECLARE_PARENT(None, Named); abstract::AbstractBasePtr ToAbstract() override; }; - extern const NamedPtr kNone; class NullObj : public Named { @@ -81,7 +79,15 @@ class NullObj : public Named { MS_DECLARE_PARENT(NullObj, Named); abstract::AbstractBasePtr ToAbstract() override; }; +extern const NamedPtr kNull; -extern const NamedPtr kNullObj; +class EllipsisObj : public Named { + public: + EllipsisObj() : Named("Ellipsis") {} + ~EllipsisObj() override = default; + MS_DECLARE_PARENT(EllipsisObj, Named); + abstract::AbstractBasePtr ToAbstract() override; +}; +extern const NamedPtr kEllipsis; } // namespace mindspore #endif // MINDSPORE_CCSRC_IR_NAMED_H_ diff --git a/mindspore/ccsrc/operator/cc_implementations.cc b/mindspore/ccsrc/operator/cc_implementations.cc index 2a3429ca52..52b71f410f 100644 --- a/mindspore/ccsrc/operator/cc_implementations.cc +++ b/mindspore/ccsrc/operator/cc_implementations.cc @@ -135,9 +135,9 @@ T InnerScalarMod(T x, T y) { if (std::is_integral::value) { return static_cast(x) % static_cast(y); } - float x_int = std::floor(x); - float y_int = std::ceil(y); - float max = x_int / y_int; + int x_int = std::floor(x); + int y_int = std::ceil(y); + int max = x_int / y_int; float ret = x - y * max; return ret; } diff --git a/mindspore/ccsrc/operator/composite/composite.cc b/mindspore/ccsrc/operator/composite/composite.cc index bf0dcf37d4..11ab31a292 100644 --- a/mindspore/ccsrc/operator/composite/composite.cc +++ b/mindspore/ccsrc/operator/composite/composite.cc @@ -46,6 +46,8 @@ using mindspore::abstract::AbstractBase; using mindspore::abstract::AbstractClass; using mindspore::abstract::AbstractDictionary; using mindspore::abstract::AbstractDictionaryPtr; +using mindspore::abstract::AbstractEllipsis; +using mindspore::abstract::AbstractEllipsisPtr; using mindspore::abstract::AbstractFunction; using mindspore::abstract::AbstractFunctionPtr; using mindspore::abstract::AbstractList; @@ -1081,6 +1083,7 @@ int GenerateStridedSliceParametersFromTuple(const AbstractTuplePtr &slice_tuple, std::vector shrink; auto slice_tuple_eles = slice_tuple->elements(); + size_t ellipsis_num = 0; for (size_t index = 0; index < slice_tuple_size; index++) { if (slice_tuple_eles[index]->isa()) { AbstractSlicePtr slice = dyn_cast(slice_tuple_eles[index]); @@ -1098,7 +1101,20 @@ int GenerateStridedSliceParametersFromTuple(const AbstractTuplePtr &slice_tuple, continue; } - MS_LOG(EXCEPTION) << "Slice tuple only could contain slice or int number, but got " + if (slice_tuple_eles[index]->isa()) { + ellipsis_num++; + if (ellipsis_num > 1) { + MS_LOG(EXCEPTION) << "Tensor slice supports at most one ellipsis"; + } + size_t ellipsis_len = shape_size - (slice_tuple_size - 1); + begin->insert(begin->end(), ellipsis_len, 0); + end->insert(end->end(), shape.begin() + index, shape.begin() + index + ellipsis_len); + strides->insert(strides->end(), ellipsis_len, 1); + shrink.insert(shrink.end(), ellipsis_len, 0); + continue; + } + + MS_LOG(EXCEPTION) << "Slice tuple only could contain slice, int number or ellipsis, but got " << slice_tuple_eles[index]->ToString(); } @@ -1160,6 +1176,11 @@ FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec abstract::CheckArgsSize(op_name, args_spec_list, 2); AbstractTensorPtr tensorPtr = abstract::CheckArg(op_name, args_spec_list, 0); + FuncGraphPtr ret_graph = std::make_shared(); + ret_graph->set_flags(FUNC_GRAPH_FLAG_CORE, true); + AnfNodePtr tensor_node = ret_graph->add_parameter(); + (void)ret_graph->add_parameter(); + auto shape = tensorPtr->shape()->shape(); std::vector begin; std::vector end; @@ -1174,23 +1195,28 @@ FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec shrink_axis_mask = GenerateStridedSliceParametersFromSlice(slice_ptr, shape, &begin, &end, &strides); } else if (args_spec_list[1]->isa()) { AbstractScalarPtr scalar_ptr = dyn_cast(args_spec_list[1]); + if (scalar_ptr->BuildValue()->isa()) { + if (scalar_ptr->BuildValue()->cast()->value()) { + return ExpandADim(ret_graph, tensor_node); + } + } shrink_axis_mask = GenerateStridedSliceParametersFromNumber(scalar_ptr, shape, &begin, &end, &strides); + } else if (args_spec_list[1]->isa()) { + ret_graph->set_output(tensor_node); + return ret_graph; + } else if (args_spec_list[1]->isa()) { + return ExpandADim(ret_graph, tensor_node); } else { std::ostringstream args_info; for (const auto &arg : args_spec_list) { MS_EXCEPTION_IF_NULL(arg); args_info << arg->ToString() << "\n"; } - MS_LOG(EXCEPTION) << "TensorSlice requires to input a tensor and a slice or slice tuple, but got " - << args_info.str(); + MS_LOG(EXCEPTION) + << "TensorSlice requires the input should be one of [slice, ellipsis, int number, bool, none, tuple] , but got " + << args_info.str(); } - FuncGraphPtr ret_graph = std::make_shared(); - ret_graph->set_flags(FUNC_GRAPH_FLAG_CORE, true); - - AnfNodePtr tensor_node = ret_graph->add_parameter(); - (void)ret_graph->add_parameter(); - auto PrimStridedSliceClass = prim::GetPythonOps("StridedSlice", "mindspore.ops.operations"); auto PrimStridedSlice = ret_graph->NewCNode({NewValueNode(PrimStridedSliceClass), NewValueNode(0), NewValueNode(0), NewValueNode(0), NewValueNode(0), NewValueNode(shrink_axis_mask)}); @@ -1199,6 +1225,12 @@ FuncGraphPtr TensorSlice::GenerateFuncGraph(const AbstractBasePtrList &args_spec return ret_graph; } +FuncGraphPtr TensorSlice::ExpandADim(const FuncGraphPtr &ret_graph, const AnfNodePtr &tensor_node) const { + auto PrimExpandDims = GetPythonOps("expand_dims", "mindspore.ops.functional"); + ret_graph->set_output(NewCNode({NewValueNode(PrimExpandDims), tensor_node, NewValueNode(0)}, ret_graph)); + return ret_graph; +} + REGISTER_PYBIND_DEFINE(TupleAdd_, ([](const py::module *m) { (void)py::class_>(*m, "TupleAdd_") .def(py::init()); diff --git a/mindspore/ccsrc/operator/composite/composite.h b/mindspore/ccsrc/operator/composite/composite.h index 1dad2e08cf..429cf5341a 100644 --- a/mindspore/ccsrc/operator/composite/composite.h +++ b/mindspore/ccsrc/operator/composite/composite.h @@ -206,6 +206,8 @@ class TensorSlice : public MetaFuncGraph { MS_DECLARE_PARENT(TensorSlice, MetaFuncGraph) FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; friend bool operator==(const TensorSlice &lhs, const TensorSlice &rhs) { return lhs.name_ == rhs.name_; } + + FuncGraphPtr ExpandADim(const FuncGraphPtr &ret_graph, const AnfNodePtr &tensor_node) const; }; using TensorSlicePtr = std::shared_ptr; diff --git a/mindspore/ccsrc/pipeline/parse/parse.cc b/mindspore/ccsrc/pipeline/parse/parse.cc index 51c4fc17ec..22d6fc9049 100644 --- a/mindspore/ccsrc/pipeline/parse/parse.cc +++ b/mindspore/ccsrc/pipeline/parse/parse.cc @@ -109,6 +109,7 @@ void Parser::BuildMethodMap() { expr_method_map_["Index"] = &Parser::ParseIndex; expr_method_map_["UnaryOp"] = &Parser::ParseUnaryOp; expr_method_map_["Dict"] = &Parser::ParseDict; + expr_method_map_["Ellipsis"] = &Parser::ParseEllipsis; } void Parser::UpdateTopFuncGraph(const FuncGraphPtr &func_graph) { top_func_graph_ = FuncGraphWeakPtr(func_graph); } @@ -187,7 +188,7 @@ void Parser::GenerateArgsDefaultValueForFunction(const FunctionBlockPtr &block, namelist_for_default_value.push_back(arg_name); if (py::isinstance(defaults[i])) { - default_values.push_back(NewValueNode(kNullObj)); + default_values.push_back(NewValueNode(kNull)); } else { default_values.push_back(ParseExprNode(block, defaults[i])); } @@ -437,6 +438,11 @@ AnfNodePtr Parser::ParseNone(const FunctionBlockPtr &, const py::object &) { return NewValueNode(kNone); } +AnfNodePtr Parser::ParseEllipsis(const FunctionBlockPtr &, const py::object &) { + MS_LOG(DEBUG) << "Process ast Ellipsis"; + return NewValueNode(kEllipsis); +} + AnfNodePtr Parser::ParseNum(const FunctionBlockPtr &, const py::object &node) { MS_LOG(DEBUG) << "Process ast Num"; py::object obj = python_adapter::GetPyObjAttr(node, "n"); diff --git a/mindspore/ccsrc/pipeline/parse/parse.h b/mindspore/ccsrc/pipeline/parse/parse.h index 4dd1bc62aa..be6b09600c 100644 --- a/mindspore/ccsrc/pipeline/parse/parse.h +++ b/mindspore/ccsrc/pipeline/parse/parse.h @@ -92,6 +92,8 @@ class Parser { AnfNodePtr ParseName(const FunctionBlockPtr &block, const py::object &node); // process NoneType AnfNodePtr ParseNone(const FunctionBlockPtr &block, const py::object &node); + // process Ellipsis + AnfNodePtr ParseEllipsis(const FunctionBlockPtr &block, const py::object &node); // process a integer or float number AnfNodePtr ParseNum(const FunctionBlockPtr &block, const py::object &node); // process a string variable diff --git a/mindspore/ccsrc/pipeline/static_analysis/abstract_value.cc b/mindspore/ccsrc/pipeline/static_analysis/abstract_value.cc index 555a6d87c0..210257ea53 100644 --- a/mindspore/ccsrc/pipeline/static_analysis/abstract_value.cc +++ b/mindspore/ccsrc/pipeline/static_analysis/abstract_value.cc @@ -892,10 +892,27 @@ bool AbstractNull::operator==(const AbstractBase &other) const { std::string AbstractNull::ToString() const { std::ostringstream buffer; - buffer << type_name() << "(" - << "Value: " - << "Null" - << ")"; + buffer << type_name() << "(Value: Null)"; + return buffer.str(); +} + +bool AbstractEllipsis::operator==(const AbstractEllipsis &) const { return true; } + +bool AbstractEllipsis::operator==(const AbstractBase &other) const { + if (&other == this) { + return true; + } + if (other.isa()) { + auto other_none = static_cast(&other); + return *this == *other_none; + } else { + return false; + } +} + +std::string AbstractEllipsis::ToString() const { + std::ostringstream buffer; + buffer << type_name() << "(Value: Ellipsis)"; return buffer.str(); } diff --git a/mindspore/ccsrc/pipeline/static_analysis/abstract_value.h b/mindspore/ccsrc/pipeline/static_analysis/abstract_value.h index 9e0dd82003..7608d0bec7 100644 --- a/mindspore/ccsrc/pipeline/static_analysis/abstract_value.h +++ b/mindspore/ccsrc/pipeline/static_analysis/abstract_value.h @@ -498,7 +498,7 @@ using AbstractNonePtr = std::shared_ptr; // the un assigned state value for variable, which means the variable is not assigned class AbstractNull : public AbstractBase { public: - AbstractNull() : AbstractBase(kNullObj) { set_type(std::make_shared()); } + AbstractNull() : AbstractBase(kNull) { set_type(std::make_shared()); } ~AbstractNull() override = default; MS_DECLARE_PARENT(AbstractNull, AbstractBase) @@ -510,6 +510,20 @@ class AbstractNull : public AbstractBase { }; using AbstractNullPtr = std::shared_ptr; +class AbstractEllipsis : public AbstractBase { + public: + AbstractEllipsis() : AbstractBase(kEllipsis) { set_type(std::make_shared()); } + ~AbstractEllipsis() override = default; + MS_DECLARE_PARENT(AbstractEllipsis, AbstractBase) + + TypePtr BuildType() const override { return std::make_shared(); } + bool operator==(const AbstractEllipsis &other) const; + bool operator==(const AbstractBase &other) const override; + AbstractBasePtr Clone() const override { return std::make_shared(); } + std::string ToString() const override; +}; +using AbstractEllipsisPtr = std::shared_ptr; + class AbstractRefKey : public AbstractBase { public: AbstractRefKey() : AbstractBase() { set_type(std::make_shared()); } diff --git a/mindspore/ops/composite/multitype_ops/getitem_impl.py b/mindspore/ops/composite/multitype_ops/getitem_impl.py index b2b46ebbb1..56617c06a8 100644 --- a/mindspore/ops/composite/multitype_ops/getitem_impl.py +++ b/mindspore/ops/composite/multitype_ops/getitem_impl.py @@ -150,7 +150,7 @@ def _tensor_getitem_by_number(data, number_index): @getitem.register("Tensor", "Slice") def _tensor_getitem_by_slice(data, slice_index): """ - Getting item of tensor by slice index. + Getting item of tensor by slice. Inputs: data (Tensor): A tensor. @@ -165,7 +165,7 @@ def _tensor_getitem_by_slice(data, slice_index): @getitem.register("Tensor", "Tuple") def _tensor_getitem_by_slice_tuple(data, slice_tuple_index): """ - Getting item of tensor by slice tuple index. + Getting item of tensor by slice tuple. Inputs: data (Tensor): A tensor. @@ -175,3 +175,18 @@ def _tensor_getitem_by_slice_tuple(data, slice_tuple_index): Tensor, element type is same as the element type of data. """ return _tensor_slice(data, slice_tuple_index) + + +@getitem.register("Tensor", "Ellipsis") +def _tensor_getitem_by_ellipsis(data, ellipsis_index): + """ + Getting item of tensor by Ellipsis. + + Inputs: + data (Tensor): A tensor. + ellipsis (Ellipsis): A Ellipsis object. + + Outputs: + Tensor, same as data. + """ + return _tensor_slice(data, ellipsis_index) diff --git a/mindspore/ops/functional.py b/mindspore/ops/functional.py index d94ef3a11c..c5b8752ae2 100644 --- a/mindspore/ops/functional.py +++ b/mindspore/ops/functional.py @@ -67,6 +67,7 @@ scalar_to_tensor = P.ScalarToTensor() tuple_to_array = P.TupleToArray() scalar_cast = P.ScalarCast() print_ = P.Print() +expand_dims = P.ExpandDims() tuple_setitem = Primitive('tuple_setitem') tuple_getitem = Primitive('tuple_getitem') diff --git a/tests/ut/python/ops/test_tensor_slice.py b/tests/ut/python/ops/test_tensor_slice.py index a88a2d8322..ddd1fb46a1 100644 --- a/tests/ut/python/ops/test_tensor_slice.py +++ b/tests/ut/python/ops/test_tensor_slice.py @@ -42,6 +42,20 @@ class NetWorkSlicePositive(Cell): return ret0, ret1, ret2, ret3 +class NetWorkSliceEllipsis(Cell): + def __init__(self): + super(NetWorkSliceEllipsis, self).__init__() + self.tensor_ret0 = Tensor(np.ones([2, 7, 8], np.int32)) + self.tensor_ret1 = Tensor(np.ones([6, 7, 8, 9], np.int32)) + self.tensor_ret2 = Tensor(np.ones([1, 6, 7, 8, 9], np.int32)) + + def construct(self, tensor): + ret0 = tensor[0:4:2, ..., 1] + self.tensor_ret0 + ret1 = tensor[...] + self.tensor_ret1 + ret2 = tensor[True] + self.tensor_ret2 + return ret0, ret1, ret2 + + class NetWorkReduceDimension(Cell): def __init__(self): super(NetWorkReduceDimension, self).__init__() @@ -83,7 +97,7 @@ class NetWorkReduceToScalar(Cell): class TensorAssignWithBoolTensorIndex(Cell): def __init__(self): super(TensorAssignWithBoolTensorIndex, self).__init__() - self.t = Tensor(np.arange(6).reshape([2,3]), dtype = mstype.float64) + self.t = Tensor(np.arange(6).reshape([2, 3]), dtype=mstype.float64) def construct(self, a, b, c, u_tensor, _scalar): a[c] = u_scalar @@ -104,14 +118,14 @@ class TensorAssignWithBoolTensorIndexError(Cell): class TensorAssignWithBoolTensorIndex2(Cell): def __init__(self): super(TensorAssignWithBoolTensorIndex2, self).__init__() - self.t = Tensor(np.arange(6).reshape([2,3]), dtype = mstype.float64) + self.t = Tensor(np.arange(6).reshape([2, 3]), dtype=mstype.float64) def construct(self, a, u_tensor, _scalar): - a[a>8] = u_tensor - a[a>=6] = u_scalar - a[a<3] = u_scalar - a[a<=5] = u_tensor - a[a==5] = u_scalar + a[a > 8] = u_tensor + a[a >= 6] = u_scalar + a[a < 3] = u_scalar + a[a <= 5] = u_tensor + a[a == 5] = u_scalar z = a + self.t return z @@ -121,11 +135,11 @@ class TensorAssignWithBoolTensorIndex2Error(Cell): super(TensorAssignWithBoolTensorIndex2Error, self).__init__() def construct(self, a, u_tensor): - a[a>8][a>5] = u_tensor + a[a > 8][a > 5] = u_tensor return a -a = np.random.uniform(1,10,[2,3]) +a = np.random.uniform(1, 10, [2, 3]) b = a > 5 c = a < 3 Ta = Tensor(a) @@ -152,7 +166,7 @@ def test_tensor_assign_bool_index(): net1(Ta, Tb, Ta, u_tensor, u_scalar) with pytest.raises(ValueError): net1(Ta, Tb, Tc, u_tensor_error, u_scalar) - #net1(Ta, u_tensor, Tc, u_tensor_error, u_scalar) + # net1(Ta, u_tensor, Tc, u_tensor_error, u_scalar) with pytest.raises(ValueError): net2(Ta, u_tensor_error, u_scalar) net3 = TensorAssignWithBoolTensorIndexError() @@ -192,7 +206,10 @@ test_cases = [ 'block': NetWorkReduceToScalar(), 'desc_inputs': [Tensor(np.ones([6, 8, 10], np.int32))], }), - + ('NetWorkSliceEllipsis', { + 'block': NetWorkSliceEllipsis(), + 'desc_inputs': [Tensor(np.ones([6, 7, 8, 9], np.int32))], + }), ] diff --git a/tests/ut/python/pipeline/parse/test_operator.py b/tests/ut/python/pipeline/parse/test_operator.py index 6ae02fa96b..19d70b20a1 100644 --- a/tests/ut/python/pipeline/parse/test_operator.py +++ b/tests/ut/python/pipeline/parse/test_operator.py @@ -162,14 +162,15 @@ def test_ops(): if self.int > self.float: if [1, 2, 3] != None: if self.str_a + self.str_b == "helloworld": - print("hello world") - return ret + if q == 86: + print("hello world") + return ret return x net = OpsNet(9, 2) x = Tensor(np.random.randint(low=1, high=10, size=(2, 3, 4), dtype=np.int32)) y = Tensor(np.random.randint(low=10, high=20, size=(2, 3, 4), dtype=np.int32)) - context.set_context(mode=context.GRAPH_MODE, save_graphs=True) + context.set_context(mode=context.GRAPH_MODE) net(x, y) From d902ec8a2e9c0b9b6486d91e3c2843ad1a079f25 Mon Sep 17 00:00:00 2001 From: zhoufeng Date: Wed, 22 Apr 2020 11:50:24 +0800 Subject: [PATCH 092/142] modify compilation option description --- build.sh | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/build.sh b/build.sh index 7550d76c8f..b48014ed93 100755 --- a/build.sh +++ b/build.sh @@ -23,30 +23,30 @@ export BUILD_PATH="${BASEPATH}/build/" usage() { echo "Usage:" - echo "bash build.sh [-d] [-r] [-v] [-c on|off] [-t on|off] [-g on|off] [-h] [-b ge|cpu] [-m infer|train] \\" - echo " [-a on|off] [-g on|off] [-p on|off] [-i] [-L] [-R] [-D on|off] [-j[n]] [-e gpu|d|cpu] \\" + echo "bash build.sh [-d] [-r] [-v] [-c on|off] [-t on|off] [-g on|off] [-h] [-b ge] [-m infer|train] \\" + echo " [-a on|off] [-Q on|off] [-p on|off] [-i] [-L] [-R] [-D on|off] [-j[n]] [-e gpu|d|cpu] \\" echo " [-P on|off] [-z [on|off]] [-M on|off] [-V 9.2|10.1] [-I] [-K]" echo "" echo "Options:" echo " -d Debug mode" echo " -r Release mode, default mode" echo " -v Display build command" - echo " -c Enable code coverage switch, default off" - echo " -t Run testcases switch, default on" + echo " -c Enable code coverage, default off" + echo " -t Run testcases, default on" echo " -g Use glog to output log, default on" echo " -h Print usage" echo " -b Select other backend, available: \\" - echo " ge:graph engine, cpu" - echo " -m Select mode, available: infer, train, default is infer " + echo " ge:graph engine" + echo " -m Select graph engine backend mode, available: infer, train, default is infer" echo " -a Enable ASAN, default off" - echo " -p Enable pipeline profile, default off" + echo " -p Enable pipeline profile, print to stdout, default off" + echo " -R Enable pipeline profile, record to json, default off" echo " -i Enable increment building, default off" echo " -L Enable load ANF-IR as input of 'infer', default off" - echo " -R Enable the time_line record, default off" echo " -j[n] Set the threads when building (Default: -j8)" echo " -e Use gpu, d or cpu" echo " -P Enable dump anf graph to file in ProtoBuffer format, default on" - echo " -Q Enable dump end to end, default off" + echo " -Q Enable dump memory, default off" echo " -D Enable dumping of function graph ir, default on" echo " -z Compile dataset & mindrecord, default on" echo " -M Enable MPI and NCCL for GPU training, default on" From 174bfadcb2dcec87290ff0ded7ebccd8783dfb7a Mon Sep 17 00:00:00 2001 From: buxue Date: Wed, 22 Apr 2020 15:28:56 +0800 Subject: [PATCH 093/142] modify ReduceAll to dock ReduceAllD in GE --- mindspore/ccsrc/transform/convert.cc | 2 +- mindspore/ccsrc/transform/op_declare.cc | 15 ++++++--------- mindspore/ccsrc/transform/op_declare.h | 7 +++---- 3 files changed, 10 insertions(+), 14 deletions(-) diff --git a/mindspore/ccsrc/transform/convert.cc b/mindspore/ccsrc/transform/convert.cc index 417989247e..3b05dbf3ec 100755 --- a/mindspore/ccsrc/transform/convert.cc +++ b/mindspore/ccsrc/transform/convert.cc @@ -269,7 +269,7 @@ std::unordered_map &DfGraphConvertor::get_adpt_ma {string(kNameArgMinWithValue), ADPT_DESC(ArgMinWithValue)}, {prim::kPrimReduceSum->name(), ADPT_DESC(ReduceSumD)}, {prim::kPrimReduceMean->name(), ADPT_DESC(ReduceMeanD)}, - {prim::kPrimReduceAll->name(), ADPT_DESC(ReduceAll)}, + {prim::kPrimReduceAll->name(), ADPT_DESC(ReduceAllD)}, {prim::kPrimReduceMin->name(), ADPT_DESC(ReduceMinD)}, {prim::kPrimReduceMax->name(), ADPT_DESC(ReduceMaxD)}, {string(kNameLARSUpdate), ADPT_DESC(LarsV2Update)}, diff --git a/mindspore/ccsrc/transform/op_declare.cc b/mindspore/ccsrc/transform/op_declare.cc index 420edc685a..f39d7e4223 100644 --- a/mindspore/ccsrc/transform/op_declare.cc +++ b/mindspore/ccsrc/transform/op_declare.cc @@ -268,11 +268,6 @@ INPUT_MAP(GatherV2) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(indices)}, {3, INPUT_D ATTR_MAP(GatherV2) = EMPTY_ATTR_MAP; OUTPUT_MAP(GatherV2) = {{0, OUTPUT_DESC(y)}}; -// ReduceSum -INPUT_MAP(ReduceSum) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(axis)}}; -ATTR_MAP(ReduceSum) = {{"keep_dims", ATTR_DESC(keep_dims, AnyTraits())}}; -OUTPUT_MAP(ReduceSum) = {{0, OUTPUT_DESC(y)}}; - // ReduceSumD INPUT_MAP(ReduceSumD) = {{1, INPUT_DESC(x)}}; INPUT_ATTR_MAP(ReduceSumD) = { @@ -653,10 +648,12 @@ ATTR_MAP(ArgMinWithValue) = {{"axis", ATTR_DESC(dimension, AnyTraits())}, {"keep_dims", ATTR_DESC(keep_dims, AnyTraits())}}; OUTPUT_MAP(ArgMinWithValue) = {{0, OUTPUT_DESC(indice)}, {1, OUTPUT_DESC(values)}}; -// ReduceAll -INPUT_MAP(ReduceAll) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(axis)}}; -ATTR_MAP(ReduceAll) = {{"keep_dims", ATTR_DESC(keep_dims, AnyTraits())}}; -OUTPUT_MAP(ReduceAll) = {{0, OUTPUT_DESC(y)}}; +// ReduceAllD +INPUT_MAP(ReduceAllD) = {{1, INPUT_DESC(x)}}; +INPUT_ATTR_MAP(ReduceAllD) = { + {2, ATTR_DESC(axis, AnyTraits>(), AnyTraits>())}}; +ATTR_MAP(ReduceAllD) = {{"keep_dims", ATTR_DESC(keep_dims, AnyTraits())}}; +OUTPUT_MAP(ReduceAllD) = {{0, OUTPUT_DESC(y)}}; // ReduceMeanD INPUT_MAP(ReduceMeanD) = {{1, INPUT_DESC(x)}}; diff --git a/mindspore/ccsrc/transform/op_declare.h b/mindspore/ccsrc/transform/op_declare.h index 8b32e16b35..3be3546455 100755 --- a/mindspore/ccsrc/transform/op_declare.h +++ b/mindspore/ccsrc/transform/op_declare.h @@ -346,10 +346,9 @@ DECLARE_OP_USE_OUTPUT(Sin) DECLARE_OP_ADAPTER(Exp) DECLARE_OP_USE_OUTPUT(Exp) -DECLARE_OP_ADAPTER(ReduceAll) -DECLARE_OP_USE_OUTPUT(ReduceAll) -DECLARE_OP_ADAPTER(ReduceSum) -DECLARE_OP_USE_OUTPUT(ReduceSum) +DECLARE_OP_ADAPTER(ReduceAllD) +DECLARE_OP_USE_INPUT_ATTR(ReduceAllD) +DECLARE_OP_USE_OUTPUT(ReduceAllD) DECLARE_OP_ADAPTER(ReduceSumD) DECLARE_OP_USE_INPUT_ATTR(ReduceSumD) DECLARE_OP_USE_OUTPUT(ReduceSumD) From 5b6c8fade8e723973be598d5686f350659c6ea06 Mon Sep 17 00:00:00 2001 From: liuxiao Date: Tue, 21 Apr 2020 09:31:39 +0800 Subject: [PATCH 094/142] Add Erf\FillD operator for VM --- mindspore/ccsrc/kernel/tbe/tbe_adapter.cc | 1 + .../pass/const_input_to_attr_registry.cc | 1 + mindspore/ccsrc/utils/utils.h | 1 + mindspore/ops/_grad/grad_math_ops.py | 18 ++++++ mindspore/ops/_op_impl/tbe/__init__.py | 2 + mindspore/ops/_op_impl/tbe/erf.py | 39 +++++++++++++ mindspore/ops/_op_impl/tbe/fill_d.py | 55 +++++++++++++++++++ mindspore/ops/operations/__init__.py | 3 +- mindspore/ops/operations/math_ops.py | 30 ++++++++++ tests/ut/python/ops/test_ops.py | 4 ++ 10 files changed, 153 insertions(+), 1 deletion(-) create mode 100644 mindspore/ops/_op_impl/tbe/erf.py create mode 100644 mindspore/ops/_op_impl/tbe/fill_d.py diff --git a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc index 3fda554759..17ac8742f9 100644 --- a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc +++ b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc @@ -57,6 +57,7 @@ static std::map tbe_func_adapter_map = { {"strided_slice", "strided_slice_d"}, {"strided_slice_grad", "strided_slice_grad_d"}, {"transpose", "transpose_d"}, + {"fill", "fill_d"}, {"unsorted_segment_sum", "unsorted_segment_sum_d"}, {"concat", "concat_d"}, {"slice", "slice_d"}, diff --git a/mindspore/ccsrc/pre_activate/pass/const_input_to_attr_registry.cc b/mindspore/ccsrc/pre_activate/pass/const_input_to_attr_registry.cc index c2f96e54c6..fb47c9fc2a 100644 --- a/mindspore/ccsrc/pre_activate/pass/const_input_to_attr_registry.cc +++ b/mindspore/ccsrc/pre_activate/pass/const_input_to_attr_registry.cc @@ -53,6 +53,7 @@ ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() { Register(kExpandDimsOpName, {1}); Register(kSplitOpName, {0}); Register(kTopKOpName, {1}); + Register(kErfOpName, {1}); Register(kSparseApplyAdagradOpName, {2}); Register(kResizeNearestNeighborGrad, {1}); } diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index f05eda69bf..6829a7e888 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -92,6 +92,7 @@ constexpr auto kClipByNormNoDivSumOpName = "ClipByNormNoDivSum"; constexpr auto kGreaterOpName = "Greater"; constexpr auto kSqrtOpName = "Sqrt"; constexpr auto kRsqrtOpName = "Rsqrt"; +constexpr auto kErfOpName = "Erf"; constexpr auto kRealDivOpName = "RealDiv"; constexpr auto kLambUpdateWithLROpName = "LambUpdateWithLR"; constexpr auto kLambNextMVWithDecayOpName = "LambNextMVWithDecay"; diff --git a/mindspore/ops/_grad/grad_math_ops.py b/mindspore/ops/_grad/grad_math_ops.py index 2d819718c8..c334050218 100755 --- a/mindspore/ops/_grad/grad_math_ops.py +++ b/mindspore/ops/_grad/grad_math_ops.py @@ -17,6 +17,7 @@ from functools import reduce +import numpy as np from .. import functional as F from .. import operations as P from ..operations import _grad_ops as G @@ -333,6 +334,23 @@ def get_bprop_log(self): return bprop +@bprop_getters.register(P.Erf) +def get_bprop_erf(self): + """Grad definition for `Erf` operation.""" + exp = P.Exp() + square = P.Square() + sqrt = P.Sqrt() + cast = P.Cast() + dtype = P.DType() + + def bprop(x, out, dout): + half_root_pi = cast(2 / sqrt(F.scalar_to_tensor(np.pi)), dtype(x)) + x_square = square(x) + dx = dout * half_root_pi * exp(-x_square) + return (dx,) + return bprop + + @bprop_getters.register(P.Pow) def get_bprop_pow(self): """Grad definition for `Pow` operation.""" diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py index 37da184869..18ef92ca6e 100644 --- a/mindspore/ops/_op_impl/tbe/__init__.py +++ b/mindspore/ops/_op_impl/tbe/__init__.py @@ -139,6 +139,8 @@ from .smooth_l1_loss_grad import _smooth_l1_loss_grad_tbe from .fused_mul_add import _fused_mul_add_tbe from .fused_mul_add_n import _fused_mul_add_n_tbe from .fused_mul_apply_momentum import _fused_mul_apply_momentum_tbe +from .fill_d import _fill_d_op_tbe +from .erf import _erf_op_tbe from .depthwise_conv2d import _depthwise_conv2d_tbe from .depthwise_conv2d_backprop_filter import _depthwise_conv2d_backprop_filter_tbe from .depthwise_conv2d_backprop_input import _depthwise_conv2d_backprop_input_tbe diff --git a/mindspore/ops/_op_impl/tbe/erf.py b/mindspore/ops/_op_impl/tbe/erf.py new file mode 100644 index 0000000000..2247197c4e --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/erf.py @@ -0,0 +1,39 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Erf op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +erf_op_info = TBERegOp("Erf") \ + .fusion_type("ELEMWISE") \ + .async_flag(False) \ + .binfile_name("erf.so") \ + .compute_cost(10) \ + .kernel_name("erf") \ + .partial_flag(True) \ + .op_pattern("formatAgnostic") \ + .input(0, "x", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \ + .dtype_format(DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default) \ + .get_op_info() + + +@op_info_register(erf_op_info) +def _erf_op_tbe(): + """Erf TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/fill_d.py b/mindspore/ops/_op_impl/tbe/fill_d.py new file mode 100644 index 0000000000..97c6b73cf5 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/fill_d.py @@ -0,0 +1,55 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""FillD op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +fill_d_op_info = TBERegOp("FillD") \ + .fusion_type("ELEMWISE") \ + .async_flag(False) \ + .binfile_name("fill_d.so") \ + .compute_cost(10) \ + .kernel_name("fill_d") \ + .partial_flag(True) \ + .attr("dims", "required", "listInt", "all") \ + .input(0, "value", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \ + .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ) \ + .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \ + .dtype_format(DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_5HD, DataType.F32_5HD) \ + .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ) \ + .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default) \ + .dtype_format(DataType.I32_5HD, DataType.I32_5HD) \ + .dtype_format(DataType.I32_FracZ, DataType.I32_FracZ) \ + .dtype_format(DataType.I32_C1HWNCoC0, DataType.I32_C1HWNCoC0) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default) \ + .dtype_format(DataType.I8_5HD, DataType.I8_5HD) \ + .dtype_format(DataType.I8_FracZ, DataType.I8_FracZ) \ + .dtype_format(DataType.I8_C1HWNCoC0, DataType.I8_C1HWNCoC0) \ + .dtype_format(DataType.I8_Default, DataType.I8_Default) \ + .dtype_format(DataType.U8_5HD, DataType.U8_5HD) \ + .dtype_format(DataType.U8_FracZ, DataType.U8_FracZ) \ + .dtype_format(DataType.U8_C1HWNCoC0, DataType.U8_C1HWNCoC0) \ + .dtype_format(DataType.U8_Default, DataType.U8_Default) \ + .get_op_info() + + +@op_info_register(fill_d_op_info) +def _fill_d_op_tbe(): + """FillD TBE register""" + return diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index 2860690b91..80b03a04e1 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -39,7 +39,7 @@ from .control_ops import ControlDepend, GeSwitch, Merge from .inner_ops import ScalarCast from .math_ops import (Abs, ACos, AddN, AssignAdd, AssignSub, Atan2, BatchMatMul, ReduceMax, ReduceMin, ReduceMean, ReduceSum, ReduceAll, ReduceProd, CumProd, - Cos, Div, Equal, EqualCount, Exp, Floor, FloorDiv, FloorMod, Acosh, + Cos, Div, Equal, EqualCount, Exp, Erf, Floor, FloorDiv, FloorMod, Acosh, Greater, GreaterEqual, Less, LessEqual, Log, LogicalAnd, LogicalNot, LogicalOr, MatMul, Maximum, Minimum, Mul, Neg, NMSWithMask, NotEqual, @@ -139,6 +139,7 @@ __all__ = [ 'ReLU', 'ReLU6', 'Elu', + 'Erf', 'Sigmoid', 'HSwish', 'HSigmoid', diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py index 33351a3ca1..8de4108435 100644 --- a/mindspore/ops/operations/math_ops.py +++ b/mindspore/ops/operations/math_ops.py @@ -1007,6 +1007,36 @@ class Log(PrimitiveWithInfer): return x +class Erf(PrimitiveWithInfer): + r""" + Computes the Gauss error function of `input_x` element-wise. + + Inputs: + - **input_x** (Tensor) - The input tensor. + + Outputs: + Tensor, has the same shape and dtype as the `input_x`. + + Examples: + >>> input_x = Tensor(np.array([-1, 0, 1, 2, 3]), mindspore.float32) + >>> erf = P.Erf() + >>> erf(input_x) + [-0.8427168, 0., 0.8427168, 0.99530876, 0.99997765] + """ + + @prim_attr_register + def __init__(self): + """init Erf""" + self.init_prim_io_names(inputs=['x'], outputs=['y']) + + def infer_shape(self, x_shape): + return x_shape + + def infer_dtype(self, x_type): + validator.check_tensor_type_same({"x": x_type}, [mstype.float16, mstype.float32], self.name) + return x_type + + class Minimum(_MathBinaryOp): """ Computes the element-wise minimum of input tensors. diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index 1bd3a2e438..442c8bdec6 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -250,6 +250,10 @@ test_case_math_ops = [ 'block': P.Exp(), 'desc_inputs': [[2, 3]], 'desc_bprop': [[2, 3]]}), + ('Erf', { + 'block': P.Erf(), + 'desc_inputs': [Tensor(np.array([-2, -1, 0, 1, 2]).astype(np.float16))], + 'desc_bprop': [Tensor(np.array([-2, -1, 0, 1, 2]).astype(np.float16))]}), ('Floor', { 'block': P.Floor(), 'desc_inputs': [[2, 512, 56, 56]], From 6770c66ed9f8496b5f89b165731c64deb15b062e Mon Sep 17 00:00:00 2001 From: fary86 Date: Thu, 16 Apr 2020 04:06:32 +0800 Subject: [PATCH 095/142] Add prim name to error message for other operators left --- mindspore/ops/operations/_quant_ops.py | 297 ++++++++++-------------- mindspore/ops/operations/comm_ops.py | 38 ++- mindspore/ops/operations/control_ops.py | 7 +- mindspore/ops/operations/debug_ops.py | 4 +- mindspore/ops/operations/other_ops.py | 71 +++--- mindspore/ops/operations/random_ops.py | 15 +- 6 files changed, 179 insertions(+), 253 deletions(-) diff --git a/mindspore/ops/operations/_quant_ops.py b/mindspore/ops/operations/_quant_ops.py index 14d1bc9234..4c7d64b581 100644 --- a/mindspore/ops/operations/_quant_ops.py +++ b/mindspore/ops/operations/_quant_ops.py @@ -15,8 +15,8 @@ """Operators for quantization.""" -from ..._checkparam import ParamValidator as validator -from ..._checkparam import Rel, check_bool, check_int_positive, check_int +from ..._checkparam import Validator as validator +from ..._checkparam import Rel from ..primitive import PrimitiveWithInfer, prim_attr_register from ...common import dtype as mstype @@ -69,36 +69,31 @@ class FakeQuantWithMinMax(PrimitiveWithInfer): training=True): """init FakeQuantWithMinMax OP""" if num_bits not in self.support_quant_bit: - raise ValueError("Attr \'num_bits\' is not support.") + raise ValueError(f"For '{self.name}' attr \'num_bits\' is not support.") if ema and not ema_decay: - raise ValueError( - "Attr \'ema\' and \'ema_decay\' should set together.") - - self.ema = check_bool(ema) - self.symmetric = check_bool(symmetric) - self.narrow_range = check_bool(narrow_range) - self.training = check_bool(training) - self.ema_decay = validator.check_number_range( - 'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH) - self.num_bits = check_int_positive(num_bits) - self.quant_delay = check_int(quant_delay) + raise ValueError(f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.") + + self.ema = validator.check_value_type('ema', ema, (bool,), self.name) + self.symmetric = validator.check_value_type('symmetric', symmetric, (bool,), self.name) + self.narrow_range = validator.check_value_type('narrow_range', narrow_range, (bool,), self.name) + self.training = validator.check_value_type('training', training, (bool,), self.name) + self.ema_decay = validator.check_number_range('ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name) + self.num_bits = validator.check_integer('num_bits', num_bits, 0, Rel.GT, self.name) + self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name) self.init_prim_io_names(inputs=['x', 'min', 'max'], outputs=['out']) def infer_shape(self, x_shape, min_shape, max_shape): - validator.check_integer("x shape", len(x_shape), 1, Rel.GT) - validator.check("min shape", min_shape, "max shape", max_shape) - validator.check_integer("min shape", len(min_shape), 1, Rel.EQ) - validator.check_integer("max shape", len(min_shape), 1, Rel.EQ) + validator.check_integer("x rank", len(x_shape), 1, Rel.GT, self.name) + validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name) + validator.check_integer("min rank", len(min_shape), 1, Rel.EQ, self.name) return x_shape def infer_dtype(self, x_type, min_type, max_type): - validator.check_typename( - "x type", x_type, (mstype.float16, mstype.float32)) - validator.check_typename("min type", min_type, - (mstype.float16, mstype.float32)) - validator.check_typename("max type", max_type, - (mstype.float16, mstype.float32)) + valid_types = (mstype.float16, mstype.float32) + validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) + validator.check_tensor_type_same({"min": min_type}, valid_types, self.name) + validator.check_tensor_type_same({"max": max_type}, valid_types, self.name) return x_type @@ -109,29 +104,24 @@ class FakeQuantWithMinMaxGrad(PrimitiveWithInfer): @prim_attr_register def __init__(self, num_bits=8, quant_delay=0): if num_bits not in self.support_quant_bit: - raise ValueError("Attr \'num_bits\' is not support.") + raise ValueError(f"For '{self.name}' attr \'num_bits\' is not support.") - self.quant_delay = check_int(quant_delay) - self.num_bits = check_int_positive(num_bits) - self.init_prim_io_names(inputs=['dout', 'x', 'min', 'max'], - outputs=['dx']) + self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name) + self.num_bits = validator.check_integer('num_bits', num_bits, 0, Rel.GT, self.name) + self.init_prim_io_names(inputs=['dout', 'x', 'min', 'max'], outputs=['dx']) def infer_shape(self, dout_shape, x_shape, min_shape, max_shape): - validator.check("dout shape", dout_shape, "x shape", x_shape) - validator.check("min shape", min_shape, "max shape", max_shape) - validator.check_integer("min shape", len(min_shape), 1, Rel.EQ) - validator.check_integer("max shape", len(min_shape), 1, Rel.EQ) + validator.check("dout shape", dout_shape, "x shape", x_shape, Rel.EQ, self.name) + validator.check("min shape", min_shape, "max shape", max_shape, Rel.EQ, self.name) + validator.check_integer("min rank", len(min_shape), 1, Rel.EQ, self.name) return dout_shape def infer_dtype(self, dout_type, x_type, min_type, max_type): - validator.check_typename( - "dout type", dout_type, (mstype.float16, mstype.float32)) - validator.check_typename( - "x type", x_type, (mstype.float16, mstype.float32)) - validator.check_typename("min type", min_type, - (mstype.float16, mstype.float32)) - validator.check_typename("max type", max_type, - (mstype.float16, mstype.float32)) + valid_types = (mstype.float16, mstype.float32) + validator.check_tensor_type_same({"dout": dout_type}, valid_types, self.name) + validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) + validator.check_tensor_type_same({"min": min_type}, valid_types, self.name) + validator.check_tensor_type_same({"max": max_type}, valid_types, self.name) return dout_type @@ -172,37 +162,30 @@ class FakeQuantWithMinMaxPerChannel(PrimitiveWithInfer): training=True): """init FakeQuantWithMinMaxPerChannel OP""" if num_bits not in self.support_quant_bit: - raise ValueError("Attr \'num_bits\' is not support.") + raise ValueError(f"For '{self.name}' Attr \'num_bits\' is not support.") if ema and not ema_decay: - raise ValueError( - "Attr \'ema\' and \'ema_decay\' should set together.") - - self.ema = check_bool(ema) - self.symmetric = check_bool(symmetric) - self.narrow_range = check_bool(narrow_range) - self.training = check_bool(training) - self.ema_decay = validator.check_number_range( - 'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH) - self.num_bits = check_int_positive(num_bits) - self.quant_delay = check_int(quant_delay) - self.init_prim_io_names(inputs=['x', 'min', 'max'], - outputs=['out']) + raise ValueError(f"For '{self.name}' attr \'ema\' and \'ema_decay\' should set together.") + + self.ema = validator.check_value_type('ema', ema, (bool,), self.name) + self.symmetric = validator.check_value_type('symmetric', symmetric, (bool,), self.name) + self.narrow_range = validator.check_value_type('narrow_range', narrow_range, (bool,), self.name) + self.training = validator.check_value_type('training', training, (bool,), self.name) + self.ema_decay = validator.check_number_range('ema_decay', ema_decay, 0, 1, Rel.INC_BOTH, self.name) + self.num_bits = validator.check_integer('num_bits', num_bits, 0, Rel.GT, self.name) + self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name) + self.init_prim_io_names(inputs=['x', 'min', 'max'], outputs=['out']) def infer_shape(self, x_shape, min_shape, max_shape): - validator.check_integer("x shape", len(x_shape), 1, Rel.GT) - validator.check_integer( - "min len", min_shape[0], x_shape[self.channel_idx], Rel.EQ) - validator.check_integer( - "max len", max_shape[0], x_shape[self.channel_idx], Rel.EQ) + validator.check_integer("x rank", len(x_shape), 1, Rel.GT, self.name) + validator.check_integer("min shape[0]", min_shape[0], x_shape[self.channel_idx], Rel.EQ, self.name) + validator.check_integer("max shape[0]", max_shape[0], x_shape[self.channel_idx], Rel.EQ, self.name) return x_shape def infer_dtype(self, x_type, min_type, max_type): - validator.check_typename( - "x type", x_type, (mstype.float16, mstype.float32)) - validator.check_typename("min type", min_type, - (mstype.float16, mstype.float32)) - validator.check_typename("max type", max_type, - (mstype.float16, mstype.float32)) + valid_types = (mstype.float16, mstype.float32) + validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) + validator.check_tensor_type_same({"min": min_type}, valid_types, self.name) + validator.check_tensor_type_same({"max": max_type}, valid_types, self.name) return x_type @@ -214,12 +197,11 @@ class FakeQuantWithMinMaxPerChannelGrad(PrimitiveWithInfer): def __init__(self, num_bits=8, quant_delay=0): """init FakeQuantWithMinMaxPerChannel Fill""" if num_bits not in self.support_quant_bit: - raise ValueError("Attr \'num_bits\' is not support.") + raise ValueError(f"For '{self.name}' attr \'num_bits\' is not support.") - self.quant_delay = check_int(quant_delay) - self.num_bits = check_int_positive(num_bits) - self.init_prim_io_names(inputs=['dout', 'x', 'min', 'max'], - outputs=['dx']) + self.quant_delay = validator.check_value_type('quant_delay', quant_delay, (int,), self.name) + self.num_bits = validator.check_integer('num_bits', num_bits, 0, Rel.GT, self.name) + self.init_prim_io_names(inputs=['dout', 'x', 'min', 'max'], outputs=['dx']) def infer_shape(self, dout_shape, x_shape, min_shape, max_shape): validator.check("dout shape", dout_shape, "x shape", x_shape) @@ -227,13 +209,11 @@ class FakeQuantWithMinMaxPerChannelGrad(PrimitiveWithInfer): return dout_shape def infer_dtype(self, dout_type, x_type, min_type, max_type): - validator.check_typename( - "dout", dout_type, (mstype.float16, mstype.float32)) - validator.check_typename("x", x_type, (mstype.float16, mstype.float32)) - validator.check_typename( - "min", min_type, (mstype.float16, mstype.float32)) - validator.check_typename( - "max", max_type, (mstype.float16, mstype.float32)) + valid_types = (mstype.float16, mstype.float32) + validator.check_tensor_type_same({"dout": dout_type}, valid_types, self.name) + validator.check_tensor_type_same({"x": x_type}, valid_types, self.name) + validator.check_tensor_type_same({"min": min_type}, valid_types, self.name) + validator.check_tensor_type_same({"max": max_type}, valid_types, self.name) return dout_type @@ -269,31 +249,26 @@ class BatchNormFold(PrimitiveWithInfer): @prim_attr_register def __init__(self, momentum=0.1, epsilon=1e-12, is_training=True, freeze_bn=0): """init batch norm fold layer""" - self.momentum = validator.check_number_range( - 'momentum', momentum, 0, 1, Rel.INC_BOTH) - self.epsilon = validator.check_float_positive('epsilon', epsilon) - self.is_training = check_bool(is_training) - self.freeze_bn = check_int(freeze_bn) + self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name) + self.epsilon = validator.check_float_positive('epsilon', epsilon, self.name) + self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name) + self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name) self.init_prim_io_names(inputs=['x', 'mean', 'variance', 'global_step'], outputs=['batch_mean', 'batch_std', 'running_mean', 'running_std']) def infer_shape(self, x_shape, mean_shape, variance_shape, global_step_shape): - validator.check("mean shape", mean_shape, - "gamma_shape", variance_shape) - validator.check("mean_shape size", - mean_shape[0], "input channel", x_shape[self.channel]) - validator.check_integer("global_step shape", - len(global_step_shape), 1, Rel.EQ) + validator.check("mean shape", mean_shape, "gamma_shape", variance_shape, Rel.EQ, self.name) + validator.check("mean_shape[0]", mean_shape[0], "input channel", x_shape[self.channel], Rel.EQ, self.name) + validator.check_integer("global_step rank", len(global_step_shape), 1, Rel.EQ, self.name) return mean_shape, mean_shape, mean_shape, mean_shape def infer_dtype(self, x_type, mean_type, variance_type, global_step_type): validator.check("input type", x_type, "mean type", mean_type) validator.check("input type", x_type, "variance type", variance_type) - validator.check_typename("input type", x_type, - (mstype.float16, mstype.float32)) - validator.check_typename( - "global_step type", global_step_type, (mstype.int32,)) + args = {"x": x_type, "mean": mean_type, "variance": variance_type} + validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) + validator.check_tensor_type_same({"global_step": global_step_type}, (mstype.int32,), self.name) return x_type, x_type, x_type, x_type @@ -304,39 +279,31 @@ class BatchNormFoldGrad(PrimitiveWithInfer): @prim_attr_register def __init__(self, epsilon=1e-12, is_training=True, freeze_bn=0): """init BatchNormGrad layer""" - self.is_training = check_bool(is_training) - self.freeze_bn = check_int(freeze_bn) - self.epsilon = validator.check_float_positive('epsilon', epsilon) + self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name) + self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name) + self.epsilon = validator.check_float_positive('epsilon', epsilon, self.name) self.init_prim_io_names(inputs=['d_batch_mean', 'd_batch_std', 'x', 'batch_mean', 'batch_std', 'global_step'], outputs=['dx']) def infer_shape(self, d_batch_mean_shape, d_batch_std_shape, x_shape, batch_mean_shape, batch_std_shape, global_step_shape): validator.check("d_batch_mean shape", d_batch_mean_shape, - "d_batch_std shape", d_batch_std_shape) + "d_batch_std shape", d_batch_std_shape, Rel.EQ, self.name) validator.check("d_batch_mean shape", d_batch_mean_shape, - "batch_mean shape", batch_mean_shape) + "batch_mean shape", batch_mean_shape, Rel.EQ, self.name) validator.check("d_batch_mean shape", d_batch_mean_shape, - "batch_std shape", batch_std_shape) - validator.check( - "x_shape shape", d_batch_mean_shape[0], "input channel", x_shape[self.channel]) - validator.check_integer("global_step shape", - len(global_step_shape), 1, Rel.EQ) + "batch_std shape", batch_std_shape, Rel.EQ, self.name) + validator.check("d_batch_mean_shape[0]", d_batch_mean_shape[0], "input channel", x_shape[self.channel], Rel.EQ, + self.name) + validator.check_integer("global_step rank", len(global_step_shape), 1, Rel.EQ, self.name) return x_shape def infer_dtype(self, d_batch_mean_type, d_batch_std_type, x_type, batch_mean_type, batch_std_type, global_step_type): - validator.check("input type", x_type, - "d_batch_mean type", d_batch_mean_type) - validator.check("input type", x_type, - "d_batch_std type", d_batch_std_type) - validator.check("input type", x_type, - "batch_mean type", batch_mean_type) - validator.check("input type", x_type, "batch_std type", batch_std_type) - validator.check_typename("input type", x_type, - (mstype.float16, mstype.float32)) - validator.check_typename( - "global_step type", global_step_type, (mstype.int32,)) + args = {"input": x_type, "d_batch_mean": d_batch_mean_type, "d_batch_std": d_batch_std_type, + "batch_mean": batch_mean_type, "batch_std": batch_std_type} + validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) + validator.check_tensor_type_same({"global_step": global_step_type}, (mstype.int32,), self.name) return x_type @@ -364,18 +331,14 @@ class CorrectionMul(PrimitiveWithInfer): outputs=['out']) def infer_shape(self, x_shape, batch_std_shape, running_std_shape): - validator.check("batch_std shape", batch_std_shape, - "running_std shape", running_std_shape) - validator.check( - "batch_std size", batch_std_shape[0], "x_shape channel size", x_shape[self.channel]) + validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name) + validator.check("batch_std_shape[0]", batch_std_shape[0], "x_shape channel size", x_shape[self.channel], + Rel.EQ, self.name) return x_shape def infer_dtype(self, x_type, batch_std_type, running_std_type): - validator.check("batch_std type", batch_std_type, - "running_std type", running_std_type) - validator.check("batch_std_type", batch_std_type, "x_type", x_type) - validator.check_typename( - "batch_std type", batch_std_type, (mstype.float16, mstype.float32)) + args = {"x": x_type, "batch_std": batch_std_type, "running_std": running_std_type} + validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) return x_type @@ -390,20 +353,16 @@ class CorrectionMulGrad(PrimitiveWithInfer): outputs=['dx', 'd_gamma']) def infer_shape(self, dout_shape, x_shape, gamma_shape, running_std_shape): - validator.check("dout shape", dout_shape, "x_shape x", x_shape) - validator.check( - "gamma size", gamma_shape[0], "dout channel size", dout_shape[self.channel]) - validator.check( - "running_std size", running_std_shape[0], "dout channel size", dout_shape[self.channel]) + validator.check("dout shape", dout_shape, "x_shape x", x_shape, Rel.EQ, self.name) + validator.check("gamma_shape[0]", gamma_shape[0], "dout channel size", dout_shape[self.channel], + Rel.EQ, self.name) + validator.check("running_std_shape[0]", running_std_shape[0], "dout channel size", dout_shape[self.channel], + Rel.EQ, self.name) return x_shape, gamma_shape def infer_dtype(self, dout_type, x_type, gamma_type, running_std_type): - validator.check("x type", x_type, "dout type", dout_type) - validator.check("gamma type", gamma_type, "dout type", dout_type) - validator.check("running_std type", running_std_type, - "dout type", dout_type) - validator.check_typename( - "dout type", dout_type, (mstype.float16, mstype.float32)) + args = {"dout": dout_type, "x": x_type, "gamma": gamma_type, "running_std": running_std_type} + validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) return x_type, x_type @@ -432,46 +391,29 @@ class BatchNormFold2(PrimitiveWithInfer): @prim_attr_register def __init__(self, freeze_bn=0): """init conv2d fold layer""" - self.freeze_bn = check_int(freeze_bn) + self.freeze_bn = validator.check_value_type('freeze_bn', freeze_bn, (int,), self.name) self.init_prim_io_names(inputs=['x', 'beta', 'gamma', 'batch_std', 'batch_mean', 'running_std', 'running_mean', 'global_step'], outputs=['y']) def infer_shape(self, x_shape, beta_shape, gamma_shape, batch_std_shape, running_std_shape, batch_mean_shape, running_mean_shape, global_step_shape): - validator.check("batch_std shape", batch_std_shape, - "running_std shape", running_std_shape) - validator.check("batch_std shape", batch_std_shape, - "batch_mean shape", batch_mean_shape) - validator.check("batch_std shape", batch_std_shape, - "beta shape", beta_shape) - validator.check("batch_std shape", batch_std_shape, - "running_mean shape", running_mean_shape) - validator.check("batch_std shape", batch_std_shape, - "batch_mean shape", gamma_shape) - validator.check( - "batch_std size", batch_std_shape[0], "x_shape channel size", x_shape[self.channel]) - validator.check_integer("global_step shape", - len(global_step_shape), 1, Rel.EQ) + validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name) + validator.check("batch_std shape", batch_std_shape, "batch_mean shape", batch_mean_shape, Rel.EQ, self.name) + validator.check("batch_std shape", batch_std_shape, "beta shape", beta_shape, Rel.EQ, self.name) + validator.check("batch_std shape", batch_std_shape, "running_mean shape", running_mean_shape, Rel.EQ, self.name) + validator.check("batch_std shape", batch_std_shape, "batch_mean shape", gamma_shape, Rel.EQ, self.name) + validator.check("batch_std_shape[0]", batch_std_shape[0], "x_shape channel size", x_shape[self.channel], + Rel.EQ, self.name) + validator.check_integer("global_step rank", len(global_step_shape), 1, Rel.EQ, self.name) return x_shape def infer_dtype(self, x_type, beta_type, gamma_type, batch_std_type, running_std_type, batch_mean_type, running_mean_type, global_step_type): - validator.check("batch_std type", batch_std_type, - "running_std type", running_std_type) - validator.check("batch_std type", batch_std_type, - "batch_mean type", batch_mean_type) - validator.check("batch_std type", batch_std_type, - "beta type", beta_type) - validator.check("batch_std type", batch_std_type, - "running_mean type", running_mean_type) - validator.check("batch_std type", batch_std_type, - "gamma type", gamma_type) - validator.check("x_type", x_type, "batch_std type", batch_std_type) - validator.check_typename( - "batch_std type", batch_std_type, (mstype.float16, mstype.float32)) - validator.check_typename( - "global_step type", global_step_type, (mstype.int32,)) + args = {"batch_std": batch_std_type, "running_std": running_std_type, "batch_mean": batch_mean_type, + "beta": beta_type, "running_mean": running_mean_type, "gamma": gamma_type, "x": x_type} + validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) + validator.check_tensor_type_same({"global_step": global_step_type}, (mstype.int32,), self.name) return x_type @@ -491,18 +433,13 @@ class BatchNormFold2Grad(PrimitiveWithInfer): def infer_shape(self, dout_shape, x_shape, gamma_shape, batch_std_shape, batch_mean_shape, running_std_shape, running_mean_shape, global_step_shape): - validator.check("batch_std shape", batch_std_shape, - "batch_mean shape", batch_mean_shape) - validator.check("batch_std shape", batch_std_shape, - "running_std shape", running_std_shape) - validator.check("batch_std shape", batch_std_shape, - "running_mean shape", running_mean_shape) - validator.check("batch_std shape", batch_std_shape, - "gamma shape", gamma_shape) - validator.check( - "batch_std size", batch_std_shape[0], "dout channel size", dout_shape[self.channel]) - validator.check_integer("global_step shape", - len(global_step_shape), 1, Rel.EQ) + validator.check("batch_std shape", batch_std_shape, "batch_mean shape", batch_mean_shape, Rel.EQ, self.name) + validator.check("batch_std shape", batch_std_shape, "running_std shape", running_std_shape, Rel.EQ, self.name) + validator.check("batch_std shape", batch_std_shape, "running_mean shape", running_mean_shape, Rel.EQ, self.name) + validator.check("batch_std shape", batch_std_shape, "gamma shape", gamma_shape, Rel.EQ, self.name) + validator.check("batch_std size", batch_std_shape[0], "dout channel size", dout_shape[self.channel], + Rel.EQ, self.name) + validator.check_integer("global_step rank", len(global_step_shape), 1, Rel.EQ, self.name) return gamma_shape, gamma_shape, gamma_shape, gamma_shape, x_shape def infer_dtype(self, dout_type, x_type, gamma_type, @@ -518,8 +455,8 @@ class BatchNormFold2Grad(PrimitiveWithInfer): "running_mean type", running_mean_type) validator.check("batch_std_type", batch_std_type, "dout type", dout_type) - validator.check_typename( - "batch_std type", batch_std_type, (mstype.float16, mstype.float32)) - validator.check_typename( - "global_step type", global_step_type, (mstype.int32,)) + args = {"batch_std": batch_std_type, "batch_mean": batch_mean_type, "gamma": gamma_type, + "running_std": running_std_type, "running_mean": running_mean_type, "dout": dout_type} + validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) + validator.check_tensor_type_same({"global_step": global_step_type}, (mstype.int32,), self.name) return gamma_type, gamma_type, gamma_type, gamma_type, gamma_type diff --git a/mindspore/ops/operations/comm_ops.py b/mindspore/ops/operations/comm_ops.py index a5a4c9f236..fbad5b49d3 100644 --- a/mindspore/ops/operations/comm_ops.py +++ b/mindspore/ops/operations/comm_ops.py @@ -15,7 +15,8 @@ """comm_ops""" -from ..._checkparam import ParamValidator as validator +from ..._checkparam import Validator as validator +from ..._checkparam import Rel from ...communication.management import get_rank, get_group_size, GlobalComm, get_group from ...common import dtype as mstype from ..primitive import PrimitiveWithInfer, prim_attr_register @@ -148,12 +149,10 @@ class AllGather(PrimitiveWithInfer): @prim_attr_register def __init__(self, group=GlobalComm.WORLD_COMM_GROUP): - if not isinstance(get_group(group), str): - raise TypeError("The group of AllGather should be str.") + validator.check_value_type('group', get_group(group), (str,), self.name) self.rank = get_rank(get_group(group)) self.rank_size = get_group_size(get_group(group)) - if self.rank >= self.rank_size: - raise ValueError("The rank of AllGather should be less than the rank_size.") + validator.check('rank', self.rank, 'rank_size', self.rank_size, Rel.LT, self.name) self.add_prim_attr('rank_size', self.rank_size) self.add_prim_attr('group', get_group(group)) @@ -163,7 +162,7 @@ class AllGather(PrimitiveWithInfer): def infer_dtype(self, x_dtype): if x_dtype == mstype.bool_: - raise TypeError("AllGather does not support 'Bool' as the dtype of input!") + raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!") return x_dtype def __call__(self, tensor): @@ -205,10 +204,8 @@ class ReduceScatter(PrimitiveWithInfer): @prim_attr_register def __init__(self, op=ReduceOp.SUM, group=GlobalComm.WORLD_COMM_GROUP): - if not isinstance(op, type(ReduceOp.SUM)): - raise TypeError("The operation of ReduceScatter should be {}.".format(type(ReduceOp.SUM))) - if not isinstance(get_group(group), str): - raise TypeError("The group of ReduceScatter should be str.") + validator.check_value_type('op', op, (type(ReduceOp.SUM),), self.name) + validator.check_value_type('group', get_group(group), (str,), self.name) self.op = op self.rank_size = get_group_size(get_group(group)) self.add_prim_attr('rank_size', self.rank_size) @@ -216,13 +213,13 @@ class ReduceScatter(PrimitiveWithInfer): def infer_shape(self, x_shape): if x_shape[0] % self.rank_size != 0: - raise ValueError("The first dimension of x should be divided by rank_size.") + raise ValueError(f"For '{self.name}' the first dimension of x should be divided by rank_size.") x_shape[0] = int(x_shape[0]/self.rank_size) return x_shape def infer_dtype(self, x_dtype): if x_dtype == mstype.bool_: - raise TypeError("ReduceScatter does not support 'Bool' as the dtype of input!") + raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!") return x_dtype def __call__(self, tensor): @@ -270,10 +267,8 @@ class Broadcast(PrimitiveWithInfer): @prim_attr_register def __init__(self, root_rank, group=GlobalComm.WORLD_COMM_GROUP): - if not isinstance(root_rank, int): - raise TypeError("The root_rank of Broadcast should be int.") - if not isinstance(get_group(group), str): - raise TypeError("The group of Broadcast should be str.") + validator.check_value_type('root_rank', root_rank, (int,), self.name) + validator.check_value_type('group', get_group(group), (str,), self.name) self.add_prim_attr('group', get_group(group)) def infer_shape(self, x_shape): @@ -281,7 +276,7 @@ class Broadcast(PrimitiveWithInfer): def infer_dtype(self, x_dtype): if x_dtype == mstype.bool_: - raise TypeError("Broadcast does not support 'Bool' as the dtype of input!") + raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!") return x_dtype @@ -311,8 +306,7 @@ class _AlltoAll(PrimitiveWithInfer): @prim_attr_register def __init__(self, split_count, split_dim, concat_dim, group=GlobalComm.WORLD_COMM_GROUP): """init AlltoAll""" - if not isinstance(get_group(group), str): - raise TypeError("The group of AllGather should be str.") + validator.check_value_type('group', get_group(group), (str,), self.name) self.split_count = split_count self.split_dim = split_dim self.concat_dim = concat_dim @@ -325,7 +319,7 @@ class _AlltoAll(PrimitiveWithInfer): def infer_dtype(self, x_dtype): if x_dtype == mstype.bool_: - raise TypeError("AlltoAll does not support 'Bool' as the dtype of input!") + raise TypeError(f"{self.name} does not support 'Bool' as the dtype of input!") return x_dtype def __call__(self, tensor): @@ -420,6 +414,6 @@ class _GetTensorSlice(PrimitiveWithInfer): def infer_value(self, x, dev_mat, tensor_map): from mindspore.parallel._tensor import _load_tensor - validator.check_type("dev_mat", dev_mat, [tuple]) - validator.check_type("tensor_map", tensor_map, [tuple]) + validator.check_value_type("dev_mat", dev_mat, [tuple], self.name) + validator.check_value_type("tensor_map", tensor_map, [tuple], self.name) return _load_tensor(x, dev_mat, tensor_map) diff --git a/mindspore/ops/operations/control_ops.py b/mindspore/ops/operations/control_ops.py index ca161cfad0..9743f9e3fd 100644 --- a/mindspore/ops/operations/control_ops.py +++ b/mindspore/ops/operations/control_ops.py @@ -16,7 +16,8 @@ """control_ops""" from ...common import dtype as mstype -from ..._checkparam import ParamValidator as validator +from ..._checkparam import Validator as validator +from ..._checkparam import Rel from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register @@ -123,11 +124,11 @@ class GeSwitch(PrimitiveWithInfer): raise NotImplementedError def infer_shape(self, data, pred): - validator.check_scalar_shape_input("pred", pred) + validator.check_integer("pred rank", len(pred), 0, Rel.EQ, self.name) return (data, data) def infer_dtype(self, data_type, pred_type): - validator.check_type("pred", pred_type, [type(mstype.bool_)]) + validator.check_tensor_type_same({"pred": pred_type}, [mstype.bool_], self.name) return (data_type, data_type) diff --git a/mindspore/ops/operations/debug_ops.py b/mindspore/ops/operations/debug_ops.py index 97fa883bac..21c9c519b9 100644 --- a/mindspore/ops/operations/debug_ops.py +++ b/mindspore/ops/operations/debug_ops.py @@ -14,7 +14,7 @@ # ============================================================================ """debug_ops""" -from ..._checkparam import ParamValidator as validator +from ..._checkparam import Validator as validator from ...common import dtype as mstype from ..primitive import Primitive, prim_attr_register, PrimitiveWithInfer @@ -219,5 +219,5 @@ class Print(PrimitiveWithInfer): def infer_dtype(self, *inputs): for dtype in inputs: - validator.check_subclass("input", dtype, (mstype.tensor, mstype.string)) + validator.check_subclass("input", dtype, (mstype.tensor, mstype.string), self.name) return mstype.int32 diff --git a/mindspore/ops/operations/other_ops.py b/mindspore/ops/operations/other_ops.py index 2ece6b7088..b98e7df77d 100644 --- a/mindspore/ops/operations/other_ops.py +++ b/mindspore/ops/operations/other_ops.py @@ -16,7 +16,7 @@ """Other operators.""" from ..._c_expression import signature_rw as sig_rw from ..._c_expression import signature_kind as sig_kind -from ..._checkparam import ParamValidator as validator, Rel +from ..._checkparam import Validator as validator, Rel from ...common import dtype as mstype from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register @@ -82,22 +82,21 @@ class BoundingBoxEncode(PrimitiveWithInfer): @prim_attr_register def __init__(self, means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0)): - validator.check_type('means', means, [tuple]) - validator.check_type('stds', stds, [tuple]) - validator.check("means len", len(means), '', 4) - validator.check("stds len", len(stds), '', 4) + validator.check_value_type('means', means, [tuple], self.name) + validator.check_value_type('stds', stds, [tuple], self.name) + validator.check_integer("means len", len(means), 4, Rel.EQ, self.name) + validator.check_integer("stds len", len(stds), 4, Rel.EQ, self.name) def infer_shape(self, anchor_box, groundtruth_box): - validator.check('anchor_box shape[0]', anchor_box[0], 'groundtruth_box shape[0]', groundtruth_box[0]) - validator.check('anchor_box shape[1]', anchor_box[1], '', 4) - validator.check('groundtruth_box shape[1]', groundtruth_box[1], '', 4) + validator.check('anchor_box shape[0]', anchor_box[0], 'groundtruth_box shape[0]', groundtruth_box[0], Rel.EQ, + self.name) + validator.check_integer('anchor_box shape[1]', anchor_box[1], 4, Rel.EQ, self.name) + validator.check_integer('groundtruth_box shape[1]', groundtruth_box[1], 4, Rel.EQ, self.name) return anchor_box def infer_dtype(self, anchor_box, groundtruth_box): - args = {"anchor_box": anchor_box, - "groundtruth_box": groundtruth_box - } - validator.check_type_same(args, mstype.number_type) + args = {"anchor_box": anchor_box, "groundtruth_box": groundtruth_box} + validator.check_tensor_type_same(args, mstype.number_type, self.name) return anchor_box @@ -126,26 +125,24 @@ class BoundingBoxDecode(PrimitiveWithInfer): @prim_attr_register def __init__(self, max_shape, means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0), wh_ratio_clip=0.016): - validator.check_type('means', means, [tuple]) - validator.check_type('stds', stds, [tuple]) - validator.check_type('wh_ratio_clip', wh_ratio_clip, [float]) - validator.check("means", len(means), '', 4) - validator.check("stds", len(stds), '', 4) + validator.check_value_type('means', means, [tuple], self.name) + validator.check_value_type('stds', stds, [tuple], self.name) + validator.check_value_type('wh_ratio_clip', wh_ratio_clip, [float], self.name) + validator.check_integer("means len", len(means), 4, Rel.EQ, self.name) + validator.check_integer("stds len", len(stds), 4, Rel.EQ, self.name) if max_shape is not None: - validator.check_type('max_shape', max_shape, [tuple]) - validator.check("max_shape", len(max_shape), '', 2) + validator.check_value_type('max_shape', max_shape, [tuple], self.name) + validator.check_integer("max_shape len", len(max_shape), 2, Rel.EQ, self.name) def infer_shape(self, anchor_box, deltas): - validator.check('anchor_box shape[0]', anchor_box[0], 'deltas shape[0]', deltas[0]) - validator.check('anchor_box shape[1]', anchor_box[1], '', 4) - validator.check('deltas shape[1]', deltas[1], '', 4) + validator.check('anchor_box shape[0]', anchor_box[0], 'deltas shape[0]', deltas[0], Rel.EQ, self.name) + validator.check_integer('anchor_box shape[1]', anchor_box[1], 4, Rel.EQ, self.name) + validator.check_integer('deltas shape[1]', deltas[1], 4, Rel.EQ, self.name) return anchor_box def infer_dtype(self, anchor_box, deltas): - args = {"anchor_box": anchor_box, - "deltas": deltas - } - validator.check_type_same(args, mstype.number_type) + args = {"anchor_box": anchor_box, "deltas": deltas} + validator.check_tensor_type_same(args, mstype.number_type, self.name) return anchor_box @@ -168,10 +165,10 @@ class CheckValid(PrimitiveWithInfer): self.init_prim_io_names(inputs=['bboxes', 'img_metas'], outputs=['output']) def infer_shape(self, bboxes_shape, metas_shape): - validator.check_shape_length("bboxes shape length", len(bboxes_shape), 2, Rel.EQ) - validator.check("bboxes_shape[-1]", bboxes_shape[-1], "", 4, Rel.EQ) - validator.check_shape_length("img_metas shape length", len(metas_shape), 1, Rel.EQ) - validator.check("img_metas shape[0]", metas_shape[0], "", 3, Rel.EQ) + validator.check_integer("bboxes rank", len(bboxes_shape), 2, Rel.EQ, self.name) + validator.check_integer("bboxes_shape[-1]", bboxes_shape[-1], 4, Rel.EQ, self.name) + validator.check_integer("img_metas rank", len(metas_shape), 1, Rel.EQ, self.name) + validator.check_integer("img_metas shape[0]", metas_shape[0], 3, Rel.EQ, self.name) return bboxes_shape[:-1] def infer_dtype(self, bboxes_type, metas_type): @@ -221,18 +218,16 @@ class IOU(PrimitiveWithInfer): self.init_prim_io_names(inputs=['anchor_boxes', 'gt_boxes'], outputs=['overlap']) def infer_shape(self, anchor_boxes, gt_boxes): - validator.check('gt_boxes shape[1]', gt_boxes[1], '', 4) - validator.check('anchor_boxes shape[1]', anchor_boxes[1], '', 4) - validator.check('anchor_boxes rank', len(anchor_boxes), '', 2) - validator.check('gt_boxes rank', len(gt_boxes), '', 2) + validator.check_integer('gt_boxes shape[1]', gt_boxes[1], 4, Rel.EQ, self.name) + validator.check_integer('anchor_boxes shape[1]', anchor_boxes[1], 4, Rel.EQ, self.name) + validator.check_integer('anchor_boxes rank', len(anchor_boxes), 2, Rel.EQ, self.name) + validator.check_integer('gt_boxes rank', len(gt_boxes), 2, Rel.EQ, self.name) iou = [gt_boxes[0], anchor_boxes[0]] return iou def infer_dtype(self, anchor_boxes, gt_boxes): - validator.check_subclass("anchor_boxes", anchor_boxes, mstype.tensor) - validator.check_subclass("gt_boxes", gt_boxes, mstype.tensor) args = {"anchor_boxes": anchor_boxes, "gt_boxes": gt_boxes} - validator.check_type_same(args, (mstype.float16,)) + validator.check_tensor_type_same(args, (mstype.float16,), self.name) return anchor_boxes @@ -270,7 +265,7 @@ class MakeRefKey(Primitive): @prim_attr_register def __init__(self, tag): - validator.check_type('tag', tag, (str,)) + validator.check_value_type('tag', tag, (str,), self.name) def __call__(self): pass diff --git a/mindspore/ops/operations/random_ops.py b/mindspore/ops/operations/random_ops.py index 18c2212b3d..2692b43b46 100644 --- a/mindspore/ops/operations/random_ops.py +++ b/mindspore/ops/operations/random_ops.py @@ -15,7 +15,7 @@ """Operators for random.""" -from ..._checkparam import ParamValidator as validator +from ..._checkparam import Validator as validator from ..._checkparam import Rel from ...common import dtype as mstype from ..primitive import PrimitiveWithInfer, prim_attr_register @@ -52,16 +52,15 @@ class RandomChoiceWithMask(PrimitiveWithInfer): @prim_attr_register def __init__(self, count=256, seed=0, seed2=0): """Init RandomChoiceWithMask""" - validator.check_type("count", count, [int]) - validator.check_integer("count", count, 0, Rel.GT) - validator.check_type('seed', seed, [int]) - validator.check_type('seed2', seed2, [int]) + validator.check_value_type("count", count, [int], self.name) + validator.check_integer("count", count, 0, Rel.GT, self.name) + validator.check_value_type('seed', seed, [int], self.name) + validator.check_value_type('seed2', seed2, [int], self.name) def infer_shape(self, x_shape): - validator.check_shape_length("input_x shape", len(x_shape), 1, Rel.GE) + validator.check_integer("input_x rank", len(x_shape), 1, Rel.GE, self.name) return ([self.count, len(x_shape)], [self.count]) def infer_dtype(self, x_dtype): - validator.check_subclass('x_dtype', x_dtype, mstype.tensor) - validator.check_typename('x_dtype', x_dtype, [mstype.bool_]) + validator.check_tensor_type_same({'x': x_dtype}, [mstype.bool_], self.name) return (mstype.int32, mstype.bool_) From 2891f0d20df3363fad0911aa3936b708440ea078 Mon Sep 17 00:00:00 2001 From: limingqi107 Date: Tue, 21 Apr 2020 10:19:16 +0800 Subject: [PATCH 096/142] gpu dynamic memory pool supports multi-allReduce --- .../ccsrc/device/gpu/gpu_kernel_runtime.cc | 122 ++++-------------- .../ccsrc/device/gpu/gpu_kernel_runtime.h | 3 - .../ccsrc/device/gpu/gpu_memory_manager.cc | 4 + .../ccsrc/device/gpu/gpu_memory_manager.h | 2 + mindspore/ccsrc/device/kernel_runtime.h | 2 +- mindspore/ccsrc/device/memory_manager.cc | 23 ++++ mindspore/ccsrc/device/memory_manager.h | 4 + .../kernel/gpu/arrays/transpose_gpu_kernel.h | 2 +- .../kernel/gpu/cuda_impl/unary_op_impl.cu | 16 +++ .../kernel/gpu/cuda_impl/unary_op_impl.cuh | 3 + .../kernel/gpu/math/unary_op_gpu_kernel.h | 1 + .../mem_reuse/mem_dynamic_allocator.cc | 31 +++++ .../mem_reuse/mem_dynamic_allocator.h | 2 + .../ccsrc/pre_activate/mem_reuse/mem_reuse.cc | 8 -- tests/st/nccl/test_nccl_all_reduce_op.py | 2 +- 15 files changed, 117 insertions(+), 108 deletions(-) diff --git a/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.cc b/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.cc index 11b8bdc162..5dd4facb25 100644 --- a/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.cc +++ b/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.cc @@ -111,7 +111,8 @@ void GPUKernelRuntime::AssignMemory(session::KernelGraph *graph) { MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(mem_manager_); mem_manager_->ResetDynamicMemory(); - AssignStaticMemory(graph); + AssignStaticMemoryInput(graph); + AssignStaticMemoryValueNode(graph); bool is_enable_dynamic_mem = context_ptr->enable_dynamic_mem_pool(); if (is_enable_dynamic_mem) { // Use the dynamic memory pool. @@ -181,7 +182,7 @@ void GPUKernelRuntime::InitKernelOutputAddress(const session::KernelGraph *graph bool GPUKernelRuntime::LaunchKernelDynamic(const session::KernelGraph *graph) { MS_EXCEPTION_IF_NULL(graph); auto graph_id = graph->graph_id(); - // The inputs and outputs memory of communication kernel are special, so separate processing. + // The inputs and outputs memory of communication kernel need be continuous, so separate processing. AllocCommunicationOpDynamicRes(graph); auto &kernels = graph->execution_order(); @@ -229,15 +230,12 @@ void GPUKernelRuntime::AllocKernelDynamicRes(const mindspore::kernel::KernelMod for (size_t i = 0; i < output_sizes.size(); ++i) { auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i); MS_EXCEPTION_IF_NULL(device_address); - auto device_ptr = device_address->ptr_; - if (device_ptr == nullptr) { - device_ptr = mem_manager_->MallocMemFromMemPool(output_sizes[i]); - MS_EXCEPTION_IF_NULL(device_ptr); - device_address->ptr_ = device_ptr; + if (device_address->ptr_ == nullptr) { + mem_manager_->MallocMemFromMemPool(device_address, output_sizes[i]); } kernel::AddressPtr output = std::make_shared(); MS_EXCEPTION_IF_NULL(output); - output->addr = device_ptr; + output->addr = device_address->ptr_; output->size = output_sizes[i]; kernel_outputs->push_back(output); } @@ -267,7 +265,6 @@ void GPUKernelRuntime::AllocCommunicationOpDynamicRes(const session::KernelGraph if (kernel_name == kAllReduceOpName) { AllocCommunicationOpInputDynamicRes(kernel); AllocCommunicationOpOutputDynamicRes(kernel); - return; } } } @@ -275,48 +272,30 @@ void GPUKernelRuntime::AllocCommunicationOpDynamicRes(const session::KernelGraph void GPUKernelRuntime::AllocCommunicationOpInputDynamicRes(const mindspore::AnfNodePtr &kernel) { MS_EXCEPTION_IF_NULL(kernel); MS_EXCEPTION_IF_NULL(mem_manager_); - // The reference count of communication kernel input is not 0. - if (communication_op_input_ref_count_ != 0) { - MS_LOG(ERROR) << "The reference count of communication kernel input is not 0."; - return; - } - - size_t total = 0; - std::vector> addr_size; + size_t total_size = 0; + std::vector size_list; + DeviceAddressPtrList addr_list; for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) { auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i); MS_EXCEPTION_IF_NULL(device_address); // The inputs of communication kernel are not released. - if ((i == 0) && (device_address->ptr_ != nullptr)) { - MS_LOG(ERROR) << "The inputs of communication kernel are not released."; - return; + if (device_address->ptr_ != nullptr) { + MS_LOG(INFO) << "The inputs of communication kernel are not released."; + mem_manager_->FreeMemFromMemPool(device_address); } - auto output_size = device_address->size_; - total += output_size; - addr_size.emplace_back(device_address.get(), output_size); - } - - auto device_mem_ptr = mem_manager_->MallocMemFromMemPool(total); - MS_EXCEPTION_IF_NULL(device_mem_ptr); - for (const auto &iter : addr_size) { - MS_EXCEPTION_IF_NULL(iter.first); - iter.first->set_ptr(device_mem_ptr); - communication_op_input_ref_count_++; - device_mem_ptr = AddressOffset(device_mem_ptr, iter.second); + total_size += device_address->size_; + size_list.emplace_back(device_address->size_); + addr_list.emplace_back(device_address); } + mem_manager_->MallocContinuousMemFromMemPool(addr_list, total_size, size_list); } void GPUKernelRuntime::AllocCommunicationOpOutputDynamicRes(const mindspore::AnfNodePtr &kernel) { MS_EXCEPTION_IF_NULL(kernel); MS_EXCEPTION_IF_NULL(mem_manager_); - // The reference count of communication kernel output is not 0. - if (communication_op_output_ref_count_ != 0) { - MS_LOG(ERROR) << "The reference count of communication kernel output is not 0."; - return; - } - - size_t total = 0; - std::vector> addr_size; + size_t total_size = 0; + std::vector size_list; + DeviceAddressPtrList addr_list; auto kernel_mod = AnfAlgo::GetKernelMod(kernel); MS_EXCEPTION_IF_NULL(kernel_mod); auto output_sizes = kernel_mod->GetOutputSizeList(); @@ -324,22 +303,15 @@ void GPUKernelRuntime::AllocCommunicationOpOutputDynamicRes(const mindspore::Anf auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i); MS_EXCEPTION_IF_NULL(device_address); // The outputs of communication kernel are not released. - if ((i == 0) && (device_address->ptr_ != nullptr)) { - MS_LOG(ERROR) << "The outputs of communication kernel are not released."; - return; + if (device_address->ptr_ != nullptr) { + MS_LOG(INFO) << "The outputs of communication kernel are not released."; + mem_manager_->FreeMemFromMemPool(device_address); } - total += output_sizes[i]; - addr_size.emplace_back(device_address.get(), output_sizes[i]); - } - - auto device_mem_ptr = mem_manager_->MallocMemFromMemPool(total); - MS_EXCEPTION_IF_NULL(device_mem_ptr); - for (const auto &iter : addr_size) { - MS_EXCEPTION_IF_NULL(iter.first); - iter.first->set_ptr(device_mem_ptr); - communication_op_output_ref_count_++; - device_mem_ptr = AddressOffset(device_mem_ptr, iter.second); + total_size += output_sizes[i]; + size_list.emplace_back(output_sizes[i]); + addr_list.emplace_back(device_address); } + mem_manager_->MallocContinuousMemFromMemPool(addr_list, total_size, size_list); } void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, @@ -362,14 +334,10 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, } kernel_ref_count_ptr->ref_count_dynamic_use_--; if (kernel_ref_count_ptr->ref_count_dynamic_use_ == 0) { + auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i); + mem_manager_->FreeMemFromMemPool(device_address); // Reset the reference count. kernel_ref_count_ptr->ref_count_dynamic_use_ = kernel_ref_count_ptr->ref_count_; - bool is_communication_op = false; - FreeCommunicationOpDynamicRes(kernel, i, &is_communication_op); - if (!is_communication_op) { - auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i); - mem_manager_->FreeMemFromMemPool(device_address); - } } } // Free the output of kernel, if output has no reference. @@ -393,40 +361,6 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, } } } - -void GPUKernelRuntime::FreeCommunicationOpDynamicRes(const mindspore::AnfNodePtr &kernel, size_t input_idx, - bool *is_communication_op) { - MS_EXCEPTION_IF_NULL(kernel); - MS_EXCEPTION_IF_NULL(mem_manager_); - // The inputs memory of communication kernel is one piece memory, need release together. - if (AnfAlgo::GetCNodeName(kernel) == kAllReduceOpName) { - communication_op_input_ref_count_--; - if (communication_op_input_ref_count_ == 0) { - auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, 0); - mem_manager_->FreeMemFromMemPool(device_address); - } - *is_communication_op = true; - return; - } - - auto cnode = kernel->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (input_idx + 1 >= cnode->inputs().size()) { - MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " << cnode->inputs().size() - 1 - << "."; - } - auto input_node = cnode->input(input_idx + 1); - auto kernel_input = AnfAlgo::VisitKernel(input_node, 0); - // The outputs memory of communication kernel is one piece memory, need release together. - if (AnfAlgo::GetCNodeName(kernel_input.first) == kAllReduceOpName) { - communication_op_output_ref_count_--; - if (communication_op_output_ref_count_ == 0) { - auto device_address = AnfAlgo::GetMutableOutputAddr(kernel_input.first, 0); - mem_manager_->FreeMemFromMemPool(device_address); - } - *is_communication_op = true; - } -} } // namespace gpu } // namespace device } // namespace mindspore diff --git a/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.h b/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.h index e0eb2dc3f1..33d4b4be70 100644 --- a/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.h +++ b/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.h @@ -60,9 +60,6 @@ class GPUKernelRuntime : public KernelRuntime { void AllocCommunicationOpOutputDynamicRes(const mindspore::AnfNodePtr &kernel); void FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel, const AddressPtrList &kernel_workspaces, uint32_t graph_id); - void FreeCommunicationOpDynamicRes(const mindspore::AnfNodePtr &kernel, size_t input_idx, bool *is_communication_op); - size_t communication_op_input_ref_count_{0}; - size_t communication_op_output_ref_count_{0}; std::unordered_map mem_reuse_util_map_; }; MS_REG_KERNEL_RUNTIME(kGPUDevice, GPUKernelRuntime); diff --git a/mindspore/ccsrc/device/gpu/gpu_memory_manager.cc b/mindspore/ccsrc/device/gpu/gpu_memory_manager.cc index 8bb65963d8..6e81130b9c 100644 --- a/mindspore/ccsrc/device/gpu/gpu_memory_manager.cc +++ b/mindspore/ccsrc/device/gpu/gpu_memory_manager.cc @@ -29,6 +29,10 @@ void GPUMemoryManager::FreeMemFromMemPool(void *device_ptr) { GPUMemoryAllocator::GetInstance().FreeTensorMem(device_ptr); } +std::vector GPUMemoryManager::MallocContinuousMemFromMemPool(size_t total_size, std::vector size_list) { + return GPUMemoryAllocator::GetInstance().AllocContinuousTensorMem(total_size, size_list); +} + void GPUMemoryManager::MallocDeviceMemory() { auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); diff --git a/mindspore/ccsrc/device/gpu/gpu_memory_manager.h b/mindspore/ccsrc/device/gpu/gpu_memory_manager.h index cc5dac2a5e..c79fb9cc22 100644 --- a/mindspore/ccsrc/device/gpu/gpu_memory_manager.h +++ b/mindspore/ccsrc/device/gpu/gpu_memory_manager.h @@ -16,6 +16,7 @@ #ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_GPU_GPU_MEMORY_MANAGER_H_ #define MINDSPORE_MINDSPORE_CCSRC_DEVICE_GPU_GPU_MEMORY_MANAGER_H_ +#include #include "device/memory_manager.h" namespace mindspore { namespace device { @@ -30,6 +31,7 @@ class GPUMemoryManager : public MemoryManager { void *MallocMemFromMemPool(size_t size) override; void FreeMemFromMemPool(void *device_ptr) override; + std::vector MallocContinuousMemFromMemPool(size_t total_size, std::vector size_list); protected: uint8_t *MallocStaticMem(size_t size, bool communication_mem) override; diff --git a/mindspore/ccsrc/device/kernel_runtime.h b/mindspore/ccsrc/device/kernel_runtime.h index 8f4f769f55..b15cb31e17 100644 --- a/mindspore/ccsrc/device/kernel_runtime.h +++ b/mindspore/ccsrc/device/kernel_runtime.h @@ -67,6 +67,7 @@ class KernelRuntime { TypeId type_id) = 0; virtual bool SyncStream() = 0; void AssignStaticMemory(session::KernelGraph *graph); + void AssignStaticMemoryValueNode(session::KernelGraph *graph); void AssignDynamicMemory(session::KernelGraph *graph); void ReuseAssignDynamicMemory(session::KernelGraph *graph); void AssignNodeOutputMem(int flag, const AnfNodePtr &node, int index); @@ -81,7 +82,6 @@ class KernelRuntime { private: void AssignStaticMemoryOutput(const session::KernelGraph *graph); - void AssignStaticMemoryValueNode(session::KernelGraph *graph); void GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod, const AnfNodePtr &kernel, AddressPtrList *kernel_inputs, AddressPtrList *kernel_workspaces, AddressPtrList *kernel_outputs); bool LaunchKernelMod(const session::KernelGraph &graph); diff --git a/mindspore/ccsrc/device/memory_manager.cc b/mindspore/ccsrc/device/memory_manager.cc index 2fad5fc10e..dce54495b0 100644 --- a/mindspore/ccsrc/device/memory_manager.cc +++ b/mindspore/ccsrc/device/memory_manager.cc @@ -167,5 +167,28 @@ void MemoryManager::FreeMemFromMemPool(void *device_ptr) { MS_LOG(ERROR) << "FreeMemFromMemPool device_ptr is null."; } } + +void MemoryManager::MallocContinuousMemFromMemPool(const DeviceAddressPtrList addr_list, size_t total_size, + std::vector size_list) { + auto device_ptr_list = MallocContinuousMemFromMemPool(total_size, size_list); + if (addr_list.size() != device_ptr_list.size()) { + MS_LOG(EXCEPTION) << "The size of device list is not equal to the size of address list."; + } + for (size_t i = 0; i < addr_list.size(); i++) { + MS_EXCEPTION_IF_NULL(device_ptr_list[i]); + MS_EXCEPTION_IF_NULL(addr_list[i]); + addr_list[i]->ptr_ = device_ptr_list[i]; + addr_list[i]->from_mem_pool_ = true; + } +} + +std::vector MemoryManager::MallocContinuousMemFromMemPool(size_t total_size, std::vector size_list) { + if (total_size == 0) { + MS_LOG(ERROR) << "MallocContinuousMemFromMemPool total_size is 0."; + } + std::vector device_ptr_list; + device_ptr_list.emplace_back(nullptr); + return device_ptr_list; +} } // namespace device } // namespace mindspore diff --git a/mindspore/ccsrc/device/memory_manager.h b/mindspore/ccsrc/device/memory_manager.h index c90ffc380e..dae0861506 100644 --- a/mindspore/ccsrc/device/memory_manager.h +++ b/mindspore/ccsrc/device/memory_manager.h @@ -17,6 +17,7 @@ #ifndef MINDSPORE_MINDSPORE_CCSRC_DEVICE_MEMORY_MANAGER_H_ #define MINDSPORE_MINDSPORE_CCSRC_DEVICE_MEMORY_MANAGER_H_ #include +#include #include "pre_activate/mem_reuse/mem_reuse.h" #include "pre_activate/mem_reuse/mem_reuse_allocator.h" namespace mindspore { @@ -49,6 +50,9 @@ class MemoryManager { virtual void *MallocMemFromMemPool(size_t size); virtual void FreeMemFromMemPool(const DeviceAddressPtr address); virtual void FreeMemFromMemPool(void *device_ptr); + virtual void MallocContinuousMemFromMemPool(const DeviceAddressPtrList addr_list, size_t total_size, + std::vector size_list); + virtual std::vector MallocContinuousMemFromMemPool(size_t total_size, std::vector size_list); size_t GetCommonAlignSize(size_t input_size) const; size_t GetCommunicationAlignSize(size_t input_size) const; diff --git a/mindspore/ccsrc/kernel/gpu/arrays/transpose_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/arrays/transpose_gpu_kernel.h index 198e8687fc..1c9cf925ea 100644 --- a/mindspore/ccsrc/kernel/gpu/arrays/transpose_gpu_kernel.h +++ b/mindspore/ccsrc/kernel/gpu/arrays/transpose_gpu_kernel.h @@ -44,7 +44,7 @@ class TransposeGpuFwdKernel : public GpuKernel { "cudaMemcpyAsync input_shape failed"); CHECK_CUDA_RET_WITH_EXCEPT(cudaMemcpyAsync(input_axis, &input_axis_[0], workspace_size_, cudaMemcpyHostToDevice, reinterpret_cast(stream_ptr)), - "cudaMemcphalfyAsync input_axis failed"); + "cudaMemcpyAsync input_axis failed"); int size = SizeToInt(input_size_ / sizeof(T)); CalTranspose(size, input, input_shape, input_axis, SizeToInt(shape_size_), output, reinterpret_cast(stream_ptr)); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/unary_op_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/unary_op_impl.cu index 6022485251..5e7a25b8e6 100755 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/unary_op_impl.cu +++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/unary_op_impl.cu @@ -60,6 +60,14 @@ __global__ void SquareKernel(T *input, T *output, size_t count) { return; } template +__global__ void ZeroslikeKernel(T *output, size_t count) { + T zero = 0.0; + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) { + output[i] = zero; + } + return; +} +template void Exponential(T *input, T *output, size_t count, cudaStream_t cuda_stream) { ExponentialKernel<<>>(input, output, count); return; @@ -84,13 +92,21 @@ void Square(T *input, T *output, size_t count, cudaStream_t cuda_stream) { SquareKernel<<>>(input, output, count); return; } +template +void Zeroslike(T *output, size_t count, cudaStream_t cuda_stream) { + ZeroslikeKernel<<>>(output, count); + return; +} + template void Exponential(float *input, float *output, size_t count, cudaStream_t cuda_stream); template void Logarithm(float *input, float *output, size_t count, cudaStream_t cuda_stream); template void Negative(float *input, float *output, size_t count, cudaStream_t cuda_stream); template void Reciprocal(float *input, float *output, size_t count, cudaStream_t cuda_stream); template void Square(float *input, float *output, size_t count, cudaStream_t cuda_stream); +template void Zeroslike(float *output, size_t count, cudaStream_t cuda_stream); template void Exponential(half *input, half *output, size_t count, cudaStream_t cuda_stream); template void Logarithm(half *input, half *output, size_t count, cudaStream_t cuda_stream); template void Negative(half *input, half *output, size_t count, cudaStream_t cuda_stream); template void Reciprocal(half *input, half *output, size_t count, cudaStream_t cuda_stream); template void Square(half *input, half *output, size_t count, cudaStream_t cuda_stream); +template void Zeroslike(half *output, size_t count, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/unary_op_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/unary_op_impl.cuh index f303c73d29..8ba9cb4a52 100755 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/unary_op_impl.cuh +++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/unary_op_impl.cuh @@ -28,4 +28,7 @@ template void Reciprocal(T *input, T *output, size_t count, cudaStream_t cuda_stream); template void Square(T *input, T *output, size_t count, cudaStream_t cuda_stream); +template +void Zeroslike(T *output, size_t count, cudaStream_t cuda_stream); + #endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNARYOPIMPL_H_ diff --git a/mindspore/ccsrc/kernel/gpu/math/unary_op_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/math/unary_op_gpu_kernel.h index 5b2414f8f1..d8fea7370b 100644 --- a/mindspore/ccsrc/kernel/gpu/math/unary_op_gpu_kernel.h +++ b/mindspore/ccsrc/kernel/gpu/math/unary_op_gpu_kernel.h @@ -81,6 +81,7 @@ class UnaryOpGpuKernel : public GpuKernel { break; } case UNARY_OP_ZEROSLIKE: { + Zeroslike(output_addr, output_size_ / sizeof(T), reinterpret_cast(stream_ptr)); return true; } default: { diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/mem_dynamic_allocator.cc b/mindspore/ccsrc/pre_activate/mem_reuse/mem_dynamic_allocator.cc index c9ef381f16..b7280f52ae 100644 --- a/mindspore/ccsrc/pre_activate/mem_reuse/mem_dynamic_allocator.cc +++ b/mindspore/ccsrc/pre_activate/mem_reuse/mem_dynamic_allocator.cc @@ -36,6 +36,37 @@ DeviceMemPtr DynamicMemPoolBestFit::AllocTensorMem(size_t size) { return device_addr; } +std::vector DynamicMemPoolBestFit::AllocContinuousTensorMem(size_t total_size, + std::vector size_list) { + // Pre-alloc the one whole piece memory. + auto device_addr = AllocTensorMem(total_size); + MS_EXCEPTION_IF_NULL(device_addr); + // Remove the pre-alloc memory. + auto mem_block = FindMemBlock(device_addr); + MS_EXCEPTION_IF_NULL(mem_block); + auto iter = mem_block->block_all_mem_buf_map_.find(device_addr); + if (iter == mem_block->block_all_mem_buf_map_.end()) { + MS_LOG(EXCEPTION) << "Can't find the device address[" << device_addr << "]."; + } + auto mem_buf = iter->second; + MS_EXCEPTION_IF_NULL(mem_buf); + auto rest_size = mem_buf->size_ - total_size; + (void)mem_block->block_all_mem_buf_map_.erase(iter); + // Split the pre-alloc memory into continuous memory by the size list. + DynamicMemBufPtr continuous_mem_buf; + std::vector device_addr_list; + auto buf_addr = device_addr; + for (size_t i = 0; i < size_list.size(); i++) { + continuous_mem_buf = std::make_shared(buf_addr, kMemBufUsed, size_list[i]); + (void)mem_block->block_all_mem_buf_map_.emplace(buf_addr, continuous_mem_buf); + device_addr_list.emplace_back(buf_addr); + buf_addr = AddressOffset(buf_addr, size_list[i]); + } + // Update the size of the last memory buf. + continuous_mem_buf->size_ += rest_size; + return device_addr_list; +} + size_t DynamicMemPoolBestFit::AlignMemorySize(size_t size) const { if (size == 0) { return DYNAMIC_MEM_ALIGN_SIZE; diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/mem_dynamic_allocator.h b/mindspore/ccsrc/pre_activate/mem_reuse/mem_dynamic_allocator.h index c628756070..07efa267aa 100644 --- a/mindspore/ccsrc/pre_activate/mem_reuse/mem_dynamic_allocator.h +++ b/mindspore/ccsrc/pre_activate/mem_reuse/mem_dynamic_allocator.h @@ -79,6 +79,8 @@ class DynamicMemPoolBestFit { virtual ~DynamicMemPoolBestFit(); // The main program entry of memory alloc. DeviceMemPtr AllocTensorMem(size_t size); + // The main program entry of continuous memory alloc. + std::vector AllocContinuousTensorMem(size_t total_size, std::vector size_list); // The main program entry of memory free. void FreeTensorMem(const DeviceMemPtr device_addr); // Release the real device memory. diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.cc b/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.cc index d25b60003f..952dfe97e4 100644 --- a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.cc +++ b/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.cc @@ -162,10 +162,6 @@ void MemReuseUtil::SetInputMap(const CNodePtr &kernel, KernelDef *kernel_def_ptr if (iter == kernel_def_ptr->inputs_.end()) { kernel_def_ptr->inputs_[key].push_back(ref_ptr); } else { - if (std::any_of(iter->second.begin(), iter->second.end(), - [ref_ptr](const KernelRefCountPtr &it) { return (it.get() == ref_ptr.get()); })) { - break; - } iter->second.push_back(ref_ptr); } } @@ -185,10 +181,6 @@ void MemReuseUtil::SetOutputMap(const CNodePtr &kernel, KernelDef *kernel_def_pt if (iter == kernel_def_ptr->outputs_.end()) { kernel_def_ptr->outputs_[key].push_back(kernel_ref); } else { - if (std::any_of(iter->second.begin(), iter->second.end(), - [kernel_ref](const KernelRefCountPtr &it) { return (it == kernel_ref); })) { - break; - } iter->second.push_back(kernel_ref); } } diff --git a/tests/st/nccl/test_nccl_all_reduce_op.py b/tests/st/nccl/test_nccl_all_reduce_op.py index 7c2e579463..3ba8b219e4 100644 --- a/tests/st/nccl/test_nccl_all_reduce_op.py +++ b/tests/st/nccl/test_nccl_all_reduce_op.py @@ -20,7 +20,7 @@ import mindspore.context as context from mindspore.common.initializer import initializer from mindspore.common.parameter import Parameter from mindspore.communication.management import init, NCCL_WORLD_COMM_GROUP, get_rank, get_group_size -context.set_context(mode=context.GRAPH_MODE, device_target='GPU', enable_dynamic_memory=False) +context.set_context(mode=context.GRAPH_MODE, device_target='GPU') init('nccl') rank = get_rank() From 3f5eaa5e07edba79445c3dc68737a4baee77c9c1 Mon Sep 17 00:00:00 2001 From: dinghao Date: Wed, 22 Apr 2020 11:45:48 +0800 Subject: [PATCH 097/142] modify tensor copy construct --- mindspore/ccsrc/ir/meta_tensor.cc | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/mindspore/ccsrc/ir/meta_tensor.cc b/mindspore/ccsrc/ir/meta_tensor.cc index 5bb9ae3c06..fe41abcef4 100644 --- a/mindspore/ccsrc/ir/meta_tensor.cc +++ b/mindspore/ccsrc/ir/meta_tensor.cc @@ -164,8 +164,11 @@ Tensor::Tensor(const py::float_ &input, const TypePtr &data_type) { init(py::arr Tensor::Tensor(const py::int_ &input, const TypePtr &data_type) { init(py::array(input), data_type); } Tensor::Tensor(const Tensor &tensor, const TypePtr &data_type) - : MetaTensor(tensor), device_address_(tensor.device_address()) { + : MetaTensor(tensor), dirty_(tensor.dirty_), device_address_(tensor.device_address_) { init(tensor.data_, data_type); + if (device_address_ != nullptr) { + (void)data_sync(); + } } Tensor &Tensor::operator=(const Tensor &tensor) { From b50bfbf7d014c779a4484ea84b181117ff295868 Mon Sep 17 00:00:00 2001 From: maoweiyong Date: Wed, 22 Apr 2020 15:53:33 +0800 Subject: [PATCH 098/142] fix hccl get data type bug --- mindspore/ccsrc/kernel/hccl/hcom_util.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mindspore/ccsrc/kernel/hccl/hcom_util.cc b/mindspore/ccsrc/kernel/hccl/hcom_util.cc index 8e5f9cb7e6..d1c0a30113 100644 --- a/mindspore/ccsrc/kernel/hccl/hcom_util.cc +++ b/mindspore/ccsrc/kernel/hccl/hcom_util.cc @@ -49,7 +49,7 @@ bool HcomUtil::GetHcomDataType(const AnfNodePtr &anf_node, vector Date: Wed, 22 Apr 2020 15:58:50 +0800 Subject: [PATCH 099/142] Rename davinci to ascend in ops st test module --- tests/st/ops/{davinci => ascend}/test_add.py | 0 tests/st/ops/{davinci => ascend}/test_addn.py | 0 .../test_apply_momentum.py | 88 ++++++------ .../st/ops/{davinci => ascend}/test_argmax.py | 0 .../{davinci => ascend}/test_biasAddGrad.py | 84 +++++------ .../{davinci => ascend}/test_bias_add_grad.py | 78 +++++------ tests/st/ops/{davinci => ascend}/test_conv.py | 0 .../test_conv2dGradFilter.py | 0 .../ops/{davinci => ascend}/test_conv_grad.py | 0 .../st/ops/{davinci => ascend}/test_dense.py | 0 .../{davinci => ascend}/test_dense_grad.py | 0 .../test_drop_out_gen_mask.py | 88 ++++++------ .../{davinci => ascend}/test_equal_count.py | 0 .../test_full_connection.py | 0 .../test_fused_batchnorm.py | 0 .../test_fused_batchnorm_grad.py | 0 .../test_image_gradients.py | 0 .../st/ops/{davinci => ascend}/test_matmul.py | 0 .../ops/{davinci => ascend}/test_maxpool.py | 0 .../{davinci => ascend}/test_maxpool_grad.py | 0 .../test_maxpool_with_argmax.py | 0 .../test_maxpool_with_argmax_grad.py | 0 tests/st/ops/{davinci => ascend}/test_relu.py | 0 .../ops/{davinci => ascend}/test_relu_grad.py | 0 .../ops/{davinci => ascend}/test_reshape.py | 0 .../{davinci => ascend}/test_simplemean.py | 0 .../test_simplemean_grad.py | 0 .../ops/{davinci => ascend}/test_softmax.py | 0 ...est_sparseSoftmaxCrossEntropyWithLogits.py | 0 ...parse_softmax_cross_entropy_with_logits.py | 0 ..._softmax_cross_entropy_with_logits_grad.py | 0 .../test_tbe_ops/test_AssignAdd.py | 0 .../test_tbe_ops/test_AssignSub.py | 0 .../test_tbe_ops/test_ReduceMean.py | 0 .../test_tbe_ops/test_add.py | 0 .../test_tbe_ops/test_addn.py | 0 .../test_tbe_ops/test_apply_adam.py | 0 .../test_tbe_ops/test_apply_momentum.py | 0 .../test_tbe_ops/test_batchmatmul.py | 0 .../test_tbe_ops/test_batchnorm.py | 0 .../test_tbe_ops/test_batchnorm_grad.py | 0 .../test_tbe_ops/test_bias_add.py | 0 .../test_tbe_ops/test_bias_add_grad.py | 0 .../test_tbe_ops/test_concat.py | 0 .../test_tbe_ops/test_conv.py | 0 .../test_conv2d_backprop_filter.py | 0 .../test_conv2d_backprop_input.py | 0 .../test_tbe_ops/test_dropout_do_mask.py | 0 .../test_tbe_ops/test_gelu.py | 0 .../test_tbe_ops/test_gelu_grad_sens.py | 0 .../test_tbe_ops/test_greater.py | 100 +++++++------- .../test_tbe_ops/test_layernorm.py | 110 +++++++-------- .../test_tbe_ops/test_layernorm_grad.py | 130 +++++++++--------- .../test_tbe_ops/test_less.py | 0 .../test_tbe_ops/test_less_equal.py | 0 .../test_tbe_ops/test_logical_and.py | 78 +++++------ .../test_tbe_ops/test_logical_not.py | 76 +++++----- .../test_tbe_ops/test_logical_or.py | 78 +++++------ .../test_tbe_ops/test_matmul.py | 0 .../test_tbe_ops/test_matmul_failed.py | 0 .../test_tbe_ops/test_maximum.py | 0 .../test_tbe_ops/test_maximum_grad.py | 0 .../test_tbe_ops/test_maxpool.py | 0 .../test_tbe_ops/test_maxpool_grad.py | 0 .../test_tbe_ops/test_minimum.py | 0 .../test_tbe_ops/test_minimum_grad.py | 0 .../test_tbe_ops/test_mul.py | 0 .../test_npu_alloc_float_status.py | 0 .../test_npu_clear_float_status.py | 0 .../test_tbe_ops/test_npu_get_float_status.py | 0 .../test_tbe_ops/test_pad.py | 0 .../test_tbe_ops/test_pow.py | 0 .../test_tbe_ops/test_realdiv.py | 0 .../test_tbe_ops/test_reciprocal.py | 0 .../test_tbe_ops/test_relu.py | 0 .../test_tbe_ops/test_relu_grad.py | 0 .../test_resize_nearest_neighbor.py | 0 .../test_resize_nearest_neighbor_grad.py | 0 .../test_tbe_ops/test_scatter_nd.py | 0 .../test_tbe_ops/test_select.py | 0 .../test_tbe_ops/test_sigmoid.py | 0 .../test_sigmoid_cross_entropy_with_logits.py | 0 ..._sigmoid_cross_entropy_with_logits_grad.py | 0 .../test_tbe_ops/test_sigmoid_grad.py | 0 .../test_tbe_ops/test_slice.py | 0 .../test_tbe_ops/test_smooth_l1_loss.py | 0 .../test_tbe_ops/test_smooth_l1_loss_grad.py | 0 .../test_tbe_ops/test_softmax.py | 0 .../test_softmax_cross_entropy_with_logits.py | 0 .../test_tbe_ops/test_split.py | 0 .../test_tbe_ops/test_sqrt.py | 0 .../test_tbe_ops/test_square.py | 0 .../test_tbe_ops/test_stridedslice.py | 0 .../test_tbe_ops/test_stridedslice_grad.py | 0 .../test_tbe_ops/test_sub.py | 0 .../test_tbe_ops/test_tanh.py | 0 .../test_tbe_ops/test_tanh_grad.py | 0 .../test_tbe_ops/test_tile.py | 0 .../test_tbe_ops/test_topk.py | 0 .../test_tbe_ops/test_transpose_d.py | 0 .../test_tbe_ops/test_unsorted_segment_sum.py | 0 .../{davinci => ascend}/test_tdt_data_ms.py | 0 102 files changed, 455 insertions(+), 455 deletions(-) rename tests/st/ops/{davinci => ascend}/test_add.py (100%) rename tests/st/ops/{davinci => ascend}/test_addn.py (100%) rename tests/st/ops/{davinci => ascend}/test_apply_momentum.py (97%) rename tests/st/ops/{davinci => ascend}/test_argmax.py (100%) rename tests/st/ops/{davinci => ascend}/test_biasAddGrad.py (97%) rename tests/st/ops/{davinci => ascend}/test_bias_add_grad.py (97%) rename tests/st/ops/{davinci => ascend}/test_conv.py (100%) rename tests/st/ops/{davinci => ascend}/test_conv2dGradFilter.py (100%) rename tests/st/ops/{davinci => ascend}/test_conv_grad.py (100%) rename tests/st/ops/{davinci => ascend}/test_dense.py (100%) rename tests/st/ops/{davinci => ascend}/test_dense_grad.py (100%) rename tests/st/ops/{davinci => ascend}/test_drop_out_gen_mask.py (97%) rename tests/st/ops/{davinci => ascend}/test_equal_count.py (100%) rename tests/st/ops/{davinci => ascend}/test_full_connection.py (100%) rename tests/st/ops/{davinci => ascend}/test_fused_batchnorm.py (100%) rename tests/st/ops/{davinci => ascend}/test_fused_batchnorm_grad.py (100%) rename tests/st/ops/{davinci => ascend}/test_image_gradients.py (100%) rename tests/st/ops/{davinci => ascend}/test_matmul.py (100%) rename tests/st/ops/{davinci => ascend}/test_maxpool.py (100%) rename tests/st/ops/{davinci => ascend}/test_maxpool_grad.py (100%) rename tests/st/ops/{davinci => ascend}/test_maxpool_with_argmax.py (100%) rename tests/st/ops/{davinci => ascend}/test_maxpool_with_argmax_grad.py (100%) rename tests/st/ops/{davinci => ascend}/test_relu.py (100%) rename tests/st/ops/{davinci => ascend}/test_relu_grad.py (100%) rename tests/st/ops/{davinci => ascend}/test_reshape.py (100%) rename tests/st/ops/{davinci => ascend}/test_simplemean.py (100%) rename tests/st/ops/{davinci => ascend}/test_simplemean_grad.py (100%) rename tests/st/ops/{davinci => ascend}/test_softmax.py (100%) rename tests/st/ops/{davinci => ascend}/test_sparseSoftmaxCrossEntropyWithLogits.py (100%) rename tests/st/ops/{davinci => ascend}/test_sparse_softmax_cross_entropy_with_logits.py (100%) rename tests/st/ops/{davinci => ascend}/test_sparse_softmax_cross_entropy_with_logits_grad.py (100%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_AssignAdd.py (100%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_AssignSub.py (100%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_ReduceMean.py (100%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_add.py (100%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_addn.py (100%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_apply_adam.py (100%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_apply_momentum.py (100%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_batchmatmul.py (100%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_batchnorm.py (100%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_batchnorm_grad.py (100%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_bias_add.py (100%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_bias_add_grad.py (100%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_concat.py (100%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_conv.py (100%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_conv2d_backprop_filter.py (100%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_conv2d_backprop_input.py (100%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_dropout_do_mask.py (100%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_gelu.py (100%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_gelu_grad_sens.py (100%) mode change 100755 => 100644 rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_greater.py (95%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_layernorm.py (97%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_layernorm_grad.py (97%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_less.py (100%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_less_equal.py (100%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_logical_and.py (97%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_logical_not.py (97%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_logical_or.py (97%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_matmul.py (100%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_matmul_failed.py (100%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_maximum.py (100%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_maximum_grad.py (100%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_maxpool.py (100%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_maxpool_grad.py (100%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_minimum.py (100%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_minimum_grad.py (100%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_mul.py (100%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_npu_alloc_float_status.py (100%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_npu_clear_float_status.py (100%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_npu_get_float_status.py (100%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_pad.py (100%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_pow.py (100%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_realdiv.py (100%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_reciprocal.py (100%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_relu.py (100%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_relu_grad.py (100%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_resize_nearest_neighbor.py (100%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_resize_nearest_neighbor_grad.py (100%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_scatter_nd.py (100%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_select.py (100%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_sigmoid.py (100%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_sigmoid_cross_entropy_with_logits.py (100%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_sigmoid_cross_entropy_with_logits_grad.py (100%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_sigmoid_grad.py (100%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_slice.py (100%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_smooth_l1_loss.py (100%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_smooth_l1_loss_grad.py (100%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_softmax.py (100%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_softmax_cross_entropy_with_logits.py (100%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_split.py (100%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_sqrt.py (100%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_square.py (100%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_stridedslice.py (100%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_stridedslice_grad.py (100%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_sub.py (100%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_tanh.py (100%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_tanh_grad.py (100%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_tile.py (100%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_topk.py (100%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_transpose_d.py (100%) rename tests/st/ops/{davinci => ascend}/test_tbe_ops/test_unsorted_segment_sum.py (100%) rename tests/st/ops/{davinci => ascend}/test_tdt_data_ms.py (100%) diff --git a/tests/st/ops/davinci/test_add.py b/tests/st/ops/ascend/test_add.py similarity index 100% rename from tests/st/ops/davinci/test_add.py rename to tests/st/ops/ascend/test_add.py diff --git a/tests/st/ops/davinci/test_addn.py b/tests/st/ops/ascend/test_addn.py similarity index 100% rename from tests/st/ops/davinci/test_addn.py rename to tests/st/ops/ascend/test_addn.py diff --git a/tests/st/ops/davinci/test_apply_momentum.py b/tests/st/ops/ascend/test_apply_momentum.py similarity index 97% rename from tests/st/ops/davinci/test_apply_momentum.py rename to tests/st/ops/ascend/test_apply_momentum.py index 885356ce48..e20c4f4746 100644 --- a/tests/st/ops/davinci/test_apply_momentum.py +++ b/tests/st/ops/ascend/test_apply_momentum.py @@ -1,44 +1,44 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from mindspore import Tensor -from mindspore.ops import operations as P -import mindspore.nn as nn -from mindspore.common.api import ms_function -import numpy as np -import mindspore.context as context -from mindspore.common.initializer import initializer -from mindspore.common.parameter import Parameter -context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") -class Net(nn.Cell): - def __init__(self): - super(Net, self).__init__() - self.apply_momentum = P.ApplyMomentum(gradient_scale=1024.0) - self.variable = Parameter(initializer( - 'normal', [2, 3, 3, 4]), name='variable') - self.accumulation = Parameter(initializer( - 'normal', [2, 3, 3, 4]), name='accumulation') - self.learning_rate = Parameter(initializer( - 'normal', [1, ]), name='learning_rate') - self.gradient = Parameter(initializer( - 'normal', [2, 3, 3, 4]), name='gradient') - self.momentum = Parameter(initializer( - 'normal', [1, ]), name='momentum') - def construct(self): - return self.apply_momentum(self.variable, self.accumulation, self.learning_rate, self.gradient, self.momentum) - -def test_net(): - apply_momentum = Net() - output = apply_momentum() - print(output.asnumpy()) +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from mindspore import Tensor +from mindspore.ops import operations as P +import mindspore.nn as nn +from mindspore.common.api import ms_function +import numpy as np +import mindspore.context as context +from mindspore.common.initializer import initializer +from mindspore.common.parameter import Parameter +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") +class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.apply_momentum = P.ApplyMomentum(gradient_scale=1024.0) + self.variable = Parameter(initializer( + 'normal', [2, 3, 3, 4]), name='variable') + self.accumulation = Parameter(initializer( + 'normal', [2, 3, 3, 4]), name='accumulation') + self.learning_rate = Parameter(initializer( + 'normal', [1, ]), name='learning_rate') + self.gradient = Parameter(initializer( + 'normal', [2, 3, 3, 4]), name='gradient') + self.momentum = Parameter(initializer( + 'normal', [1, ]), name='momentum') + def construct(self): + return self.apply_momentum(self.variable, self.accumulation, self.learning_rate, self.gradient, self.momentum) + +def test_net(): + apply_momentum = Net() + output = apply_momentum() + print(output.asnumpy()) diff --git a/tests/st/ops/davinci/test_argmax.py b/tests/st/ops/ascend/test_argmax.py similarity index 100% rename from tests/st/ops/davinci/test_argmax.py rename to tests/st/ops/ascend/test_argmax.py diff --git a/tests/st/ops/davinci/test_biasAddGrad.py b/tests/st/ops/ascend/test_biasAddGrad.py similarity index 97% rename from tests/st/ops/davinci/test_biasAddGrad.py rename to tests/st/ops/ascend/test_biasAddGrad.py index 29b63ac336..f2e8f7a9bc 100644 --- a/tests/st/ops/davinci/test_biasAddGrad.py +++ b/tests/st/ops/ascend/test_biasAddGrad.py @@ -1,42 +1,42 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from mindspore import Tensor -from mindspore.ops import operations as P -from mindspore.ops.operations import _grad_ops as G -import mindspore.nn as nn -from mindspore.common.api import ms_function -import numpy as np -import mindspore.context as context -from mindspore.common.initializer import initializer -from mindspore.common.parameter import Parameter -context.set_context(device_target="Ascend") -class Net(nn.Cell): - def __init__(self): - super(Net, self).__init__() - self.bias_add_grad = G.BiasAddGrad() - #self.dout = Parameter(initializer( - #'normal', [2, 3, 3, 4]), name='dout') - - - @ms_function - def construct(self, dout): - return self.bias_add_grad(dout) - -dout = np.ones([2,3,4,4]).astype(np.float32) -bias_add_grad = Net() -output = bias_add_grad(Tensor(dout)) -expect_output = np.array([32.,32.,32.]).astype(np.float32) -assert np.all(output.asnumpy()==expect_output), "bias_add_grad execute failed, please check current code commit" -print(output.asnumpy()) +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from mindspore import Tensor +from mindspore.ops import operations as P +from mindspore.ops.operations import _grad_ops as G +import mindspore.nn as nn +from mindspore.common.api import ms_function +import numpy as np +import mindspore.context as context +from mindspore.common.initializer import initializer +from mindspore.common.parameter import Parameter +context.set_context(device_target="Ascend") +class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.bias_add_grad = G.BiasAddGrad() + #self.dout = Parameter(initializer( + #'normal', [2, 3, 3, 4]), name='dout') + + + @ms_function + def construct(self, dout): + return self.bias_add_grad(dout) + +dout = np.ones([2,3,4,4]).astype(np.float32) +bias_add_grad = Net() +output = bias_add_grad(Tensor(dout)) +expect_output = np.array([32.,32.,32.]).astype(np.float32) +assert np.all(output.asnumpy()==expect_output), "bias_add_grad execute failed, please check current code commit" +print(output.asnumpy()) diff --git a/tests/st/ops/davinci/test_bias_add_grad.py b/tests/st/ops/ascend/test_bias_add_grad.py similarity index 97% rename from tests/st/ops/davinci/test_bias_add_grad.py rename to tests/st/ops/ascend/test_bias_add_grad.py index e58d376e93..c6a51d8b3b 100644 --- a/tests/st/ops/davinci/test_bias_add_grad.py +++ b/tests/st/ops/ascend/test_bias_add_grad.py @@ -1,39 +1,39 @@ -# Copyright 2019 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from mindspore import Tensor -from mindspore.ops import operations as P -from mindspore.ops.operations import _grad_ops as G -import mindspore.nn as nn -from mindspore.common.api import ms_function -import numpy as np -import mindspore.context as context -from mindspore.common.initializer import initializer -from mindspore.common.parameter import Parameter -context.set_context(device_target="Ascend") -class Net(nn.Cell): - def __init__(self): - super(Net, self).__init__() - self.bias_add_grad = G.BiasAddGrad() - - - @ms_function - def construct(self, dout): - return self.bias_add_grad(dout) - -def test_net(): - dout = np.random.rand(1, 1001).astype(np.float32) - bias_add_grad = Net() - output = bias_add_grad(dout) - print(output.asnumpy()) +# Copyright 2019 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from mindspore import Tensor +from mindspore.ops import operations as P +from mindspore.ops.operations import _grad_ops as G +import mindspore.nn as nn +from mindspore.common.api import ms_function +import numpy as np +import mindspore.context as context +from mindspore.common.initializer import initializer +from mindspore.common.parameter import Parameter +context.set_context(device_target="Ascend") +class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.bias_add_grad = G.BiasAddGrad() + + + @ms_function + def construct(self, dout): + return self.bias_add_grad(dout) + +def test_net(): + dout = np.random.rand(1, 1001).astype(np.float32) + bias_add_grad = Net() + output = bias_add_grad(dout) + print(output.asnumpy()) diff --git a/tests/st/ops/davinci/test_conv.py b/tests/st/ops/ascend/test_conv.py similarity index 100% rename from tests/st/ops/davinci/test_conv.py rename to tests/st/ops/ascend/test_conv.py diff --git a/tests/st/ops/davinci/test_conv2dGradFilter.py b/tests/st/ops/ascend/test_conv2dGradFilter.py similarity index 100% rename from tests/st/ops/davinci/test_conv2dGradFilter.py rename to tests/st/ops/ascend/test_conv2dGradFilter.py diff --git a/tests/st/ops/davinci/test_conv_grad.py b/tests/st/ops/ascend/test_conv_grad.py similarity index 100% rename from tests/st/ops/davinci/test_conv_grad.py rename to tests/st/ops/ascend/test_conv_grad.py diff --git a/tests/st/ops/davinci/test_dense.py b/tests/st/ops/ascend/test_dense.py similarity index 100% rename from tests/st/ops/davinci/test_dense.py rename to tests/st/ops/ascend/test_dense.py diff --git a/tests/st/ops/davinci/test_dense_grad.py b/tests/st/ops/ascend/test_dense_grad.py similarity index 100% rename from tests/st/ops/davinci/test_dense_grad.py rename to tests/st/ops/ascend/test_dense_grad.py diff --git a/tests/st/ops/davinci/test_drop_out_gen_mask.py b/tests/st/ops/ascend/test_drop_out_gen_mask.py similarity index 97% rename from tests/st/ops/davinci/test_drop_out_gen_mask.py rename to tests/st/ops/ascend/test_drop_out_gen_mask.py index 4d7c555219..ce7ebbfbe0 100644 --- a/tests/st/ops/davinci/test_drop_out_gen_mask.py +++ b/tests/st/ops/ascend/test_drop_out_gen_mask.py @@ -1,44 +1,44 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from mindspore import Tensor -from mindspore.ops import operations as P -import mindspore.nn as nn -import numpy as np -import mindspore.context as context -context.set_context(mode=context.GRAPH_MODE, - device_target="Ascend") - - -class Net(nn.Cell): - def __init__(self): - super(Net, self).__init__() - self.mask = P.DropoutGenMask(10, 28) - self.shape = P.Shape() - - def construct(self, x, y): - shape_x = self.shape(x) - return self.mask(shape_x, y) - - -x = np.ones([2, 4, 2, 2]).astype(np.int32) -y = np.array([1.0]).astype(np.float32) - - -def test_net(): - mask = Net() - tx, ty = Tensor(x), Tensor(y) - output = mask(tx, ty) - print(output.asnumpy()) - assert ([255, 255, 255, 255] == output.asnumpy()).all() +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from mindspore import Tensor +from mindspore.ops import operations as P +import mindspore.nn as nn +import numpy as np +import mindspore.context as context +context.set_context(mode=context.GRAPH_MODE, + device_target="Ascend") + + +class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.mask = P.DropoutGenMask(10, 28) + self.shape = P.Shape() + + def construct(self, x, y): + shape_x = self.shape(x) + return self.mask(shape_x, y) + + +x = np.ones([2, 4, 2, 2]).astype(np.int32) +y = np.array([1.0]).astype(np.float32) + + +def test_net(): + mask = Net() + tx, ty = Tensor(x), Tensor(y) + output = mask(tx, ty) + print(output.asnumpy()) + assert ([255, 255, 255, 255] == output.asnumpy()).all() diff --git a/tests/st/ops/davinci/test_equal_count.py b/tests/st/ops/ascend/test_equal_count.py similarity index 100% rename from tests/st/ops/davinci/test_equal_count.py rename to tests/st/ops/ascend/test_equal_count.py diff --git a/tests/st/ops/davinci/test_full_connection.py b/tests/st/ops/ascend/test_full_connection.py similarity index 100% rename from tests/st/ops/davinci/test_full_connection.py rename to tests/st/ops/ascend/test_full_connection.py diff --git a/tests/st/ops/davinci/test_fused_batchnorm.py b/tests/st/ops/ascend/test_fused_batchnorm.py similarity index 100% rename from tests/st/ops/davinci/test_fused_batchnorm.py rename to tests/st/ops/ascend/test_fused_batchnorm.py diff --git a/tests/st/ops/davinci/test_fused_batchnorm_grad.py b/tests/st/ops/ascend/test_fused_batchnorm_grad.py similarity index 100% rename from tests/st/ops/davinci/test_fused_batchnorm_grad.py rename to tests/st/ops/ascend/test_fused_batchnorm_grad.py diff --git a/tests/st/ops/davinci/test_image_gradients.py b/tests/st/ops/ascend/test_image_gradients.py similarity index 100% rename from tests/st/ops/davinci/test_image_gradients.py rename to tests/st/ops/ascend/test_image_gradients.py diff --git a/tests/st/ops/davinci/test_matmul.py b/tests/st/ops/ascend/test_matmul.py similarity index 100% rename from tests/st/ops/davinci/test_matmul.py rename to tests/st/ops/ascend/test_matmul.py diff --git a/tests/st/ops/davinci/test_maxpool.py b/tests/st/ops/ascend/test_maxpool.py similarity index 100% rename from tests/st/ops/davinci/test_maxpool.py rename to tests/st/ops/ascend/test_maxpool.py diff --git a/tests/st/ops/davinci/test_maxpool_grad.py b/tests/st/ops/ascend/test_maxpool_grad.py similarity index 100% rename from tests/st/ops/davinci/test_maxpool_grad.py rename to tests/st/ops/ascend/test_maxpool_grad.py diff --git a/tests/st/ops/davinci/test_maxpool_with_argmax.py b/tests/st/ops/ascend/test_maxpool_with_argmax.py similarity index 100% rename from tests/st/ops/davinci/test_maxpool_with_argmax.py rename to tests/st/ops/ascend/test_maxpool_with_argmax.py diff --git a/tests/st/ops/davinci/test_maxpool_with_argmax_grad.py b/tests/st/ops/ascend/test_maxpool_with_argmax_grad.py similarity index 100% rename from tests/st/ops/davinci/test_maxpool_with_argmax_grad.py rename to tests/st/ops/ascend/test_maxpool_with_argmax_grad.py diff --git a/tests/st/ops/davinci/test_relu.py b/tests/st/ops/ascend/test_relu.py similarity index 100% rename from tests/st/ops/davinci/test_relu.py rename to tests/st/ops/ascend/test_relu.py diff --git a/tests/st/ops/davinci/test_relu_grad.py b/tests/st/ops/ascend/test_relu_grad.py similarity index 100% rename from tests/st/ops/davinci/test_relu_grad.py rename to tests/st/ops/ascend/test_relu_grad.py diff --git a/tests/st/ops/davinci/test_reshape.py b/tests/st/ops/ascend/test_reshape.py similarity index 100% rename from tests/st/ops/davinci/test_reshape.py rename to tests/st/ops/ascend/test_reshape.py diff --git a/tests/st/ops/davinci/test_simplemean.py b/tests/st/ops/ascend/test_simplemean.py similarity index 100% rename from tests/st/ops/davinci/test_simplemean.py rename to tests/st/ops/ascend/test_simplemean.py diff --git a/tests/st/ops/davinci/test_simplemean_grad.py b/tests/st/ops/ascend/test_simplemean_grad.py similarity index 100% rename from tests/st/ops/davinci/test_simplemean_grad.py rename to tests/st/ops/ascend/test_simplemean_grad.py diff --git a/tests/st/ops/davinci/test_softmax.py b/tests/st/ops/ascend/test_softmax.py similarity index 100% rename from tests/st/ops/davinci/test_softmax.py rename to tests/st/ops/ascend/test_softmax.py diff --git a/tests/st/ops/davinci/test_sparseSoftmaxCrossEntropyWithLogits.py b/tests/st/ops/ascend/test_sparseSoftmaxCrossEntropyWithLogits.py similarity index 100% rename from tests/st/ops/davinci/test_sparseSoftmaxCrossEntropyWithLogits.py rename to tests/st/ops/ascend/test_sparseSoftmaxCrossEntropyWithLogits.py diff --git a/tests/st/ops/davinci/test_sparse_softmax_cross_entropy_with_logits.py b/tests/st/ops/ascend/test_sparse_softmax_cross_entropy_with_logits.py similarity index 100% rename from tests/st/ops/davinci/test_sparse_softmax_cross_entropy_with_logits.py rename to tests/st/ops/ascend/test_sparse_softmax_cross_entropy_with_logits.py diff --git a/tests/st/ops/davinci/test_sparse_softmax_cross_entropy_with_logits_grad.py b/tests/st/ops/ascend/test_sparse_softmax_cross_entropy_with_logits_grad.py similarity index 100% rename from tests/st/ops/davinci/test_sparse_softmax_cross_entropy_with_logits_grad.py rename to tests/st/ops/ascend/test_sparse_softmax_cross_entropy_with_logits_grad.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_AssignAdd.py b/tests/st/ops/ascend/test_tbe_ops/test_AssignAdd.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_AssignAdd.py rename to tests/st/ops/ascend/test_tbe_ops/test_AssignAdd.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_AssignSub.py b/tests/st/ops/ascend/test_tbe_ops/test_AssignSub.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_AssignSub.py rename to tests/st/ops/ascend/test_tbe_ops/test_AssignSub.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_ReduceMean.py b/tests/st/ops/ascend/test_tbe_ops/test_ReduceMean.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_ReduceMean.py rename to tests/st/ops/ascend/test_tbe_ops/test_ReduceMean.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_add.py b/tests/st/ops/ascend/test_tbe_ops/test_add.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_add.py rename to tests/st/ops/ascend/test_tbe_ops/test_add.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_addn.py b/tests/st/ops/ascend/test_tbe_ops/test_addn.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_addn.py rename to tests/st/ops/ascend/test_tbe_ops/test_addn.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_apply_adam.py b/tests/st/ops/ascend/test_tbe_ops/test_apply_adam.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_apply_adam.py rename to tests/st/ops/ascend/test_tbe_ops/test_apply_adam.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_apply_momentum.py b/tests/st/ops/ascend/test_tbe_ops/test_apply_momentum.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_apply_momentum.py rename to tests/st/ops/ascend/test_tbe_ops/test_apply_momentum.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_batchmatmul.py b/tests/st/ops/ascend/test_tbe_ops/test_batchmatmul.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_batchmatmul.py rename to tests/st/ops/ascend/test_tbe_ops/test_batchmatmul.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_batchnorm.py b/tests/st/ops/ascend/test_tbe_ops/test_batchnorm.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_batchnorm.py rename to tests/st/ops/ascend/test_tbe_ops/test_batchnorm.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_batchnorm_grad.py b/tests/st/ops/ascend/test_tbe_ops/test_batchnorm_grad.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_batchnorm_grad.py rename to tests/st/ops/ascend/test_tbe_ops/test_batchnorm_grad.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_bias_add.py b/tests/st/ops/ascend/test_tbe_ops/test_bias_add.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_bias_add.py rename to tests/st/ops/ascend/test_tbe_ops/test_bias_add.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_bias_add_grad.py b/tests/st/ops/ascend/test_tbe_ops/test_bias_add_grad.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_bias_add_grad.py rename to tests/st/ops/ascend/test_tbe_ops/test_bias_add_grad.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_concat.py b/tests/st/ops/ascend/test_tbe_ops/test_concat.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_concat.py rename to tests/st/ops/ascend/test_tbe_ops/test_concat.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_conv.py b/tests/st/ops/ascend/test_tbe_ops/test_conv.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_conv.py rename to tests/st/ops/ascend/test_tbe_ops/test_conv.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_conv2d_backprop_filter.py b/tests/st/ops/ascend/test_tbe_ops/test_conv2d_backprop_filter.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_conv2d_backprop_filter.py rename to tests/st/ops/ascend/test_tbe_ops/test_conv2d_backprop_filter.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_conv2d_backprop_input.py b/tests/st/ops/ascend/test_tbe_ops/test_conv2d_backprop_input.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_conv2d_backprop_input.py rename to tests/st/ops/ascend/test_tbe_ops/test_conv2d_backprop_input.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_dropout_do_mask.py b/tests/st/ops/ascend/test_tbe_ops/test_dropout_do_mask.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_dropout_do_mask.py rename to tests/st/ops/ascend/test_tbe_ops/test_dropout_do_mask.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_gelu.py b/tests/st/ops/ascend/test_tbe_ops/test_gelu.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_gelu.py rename to tests/st/ops/ascend/test_tbe_ops/test_gelu.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_gelu_grad_sens.py b/tests/st/ops/ascend/test_tbe_ops/test_gelu_grad_sens.py old mode 100755 new mode 100644 similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_gelu_grad_sens.py rename to tests/st/ops/ascend/test_tbe_ops/test_gelu_grad_sens.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_greater.py b/tests/st/ops/ascend/test_tbe_ops/test_greater.py similarity index 95% rename from tests/st/ops/davinci/test_tbe_ops/test_greater.py rename to tests/st/ops/ascend/test_tbe_ops/test_greater.py index 3976ad4301..b9dae700c2 100644 --- a/tests/st/ops/davinci/test_tbe_ops/test_greater.py +++ b/tests/st/ops/ascend/test_tbe_ops/test_greater.py @@ -1,51 +1,51 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -import numpy as np -import pytest -from mindspore.ops import operations as P -from mindspore.nn import Cell -from mindspore.common.tensor import Tensor -from mindspore.train.model import Model -from mindspore import log as logger -from mindspore import context -context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") - -class Greater(Cell): - def __init__(self): - super(Greater, self).__init__() - self.greater = P.Greater() - - def construct(self, inputa, inputb): - return self.greater(inputa, inputb) - -def me_greater(inputa, inputb): - net = Greater() - net.set_train() - model = Model(net) - - out = model.predict(inputa, inputb) - logger.info("Check input a: ") - logger.info(inputa) - logger.info("Check input b: ") - logger.info(inputb) - return out.asnumpy() - -@pytest.mark.ssd_tbe -def test_greater_2d_scalar0(): - a = np.random.randint(-5, 5, [8, 32]).astype(np.int32) - b = np.random.randint(-5, 5, [8, 32]).astype(np.int32) - out_me = me_greater(Tensor(a), Tensor(b)) - logger.info("Check me result:") +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np +import pytest +from mindspore.ops import operations as P +from mindspore.nn import Cell +from mindspore.common.tensor import Tensor +from mindspore.train.model import Model +from mindspore import log as logger +from mindspore import context +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + +class Greater(Cell): + def __init__(self): + super(Greater, self).__init__() + self.greater = P.Greater() + + def construct(self, inputa, inputb): + return self.greater(inputa, inputb) + +def me_greater(inputa, inputb): + net = Greater() + net.set_train() + model = Model(net) + + out = model.predict(inputa, inputb) + logger.info("Check input a: ") + logger.info(inputa) + logger.info("Check input b: ") + logger.info(inputb) + return out.asnumpy() + +@pytest.mark.ssd_tbe +def test_greater_2d_scalar0(): + a = np.random.randint(-5, 5, [8, 32]).astype(np.int32) + b = np.random.randint(-5, 5, [8, 32]).astype(np.int32) + out_me = me_greater(Tensor(a), Tensor(b)) + logger.info("Check me result:") logger.info(out_me) \ No newline at end of file diff --git a/tests/st/ops/davinci/test_tbe_ops/test_layernorm.py b/tests/st/ops/ascend/test_tbe_ops/test_layernorm.py similarity index 97% rename from tests/st/ops/davinci/test_tbe_ops/test_layernorm.py rename to tests/st/ops/ascend/test_tbe_ops/test_layernorm.py index 586c02cc1e..f3e4e43958 100644 --- a/tests/st/ops/davinci/test_tbe_ops/test_layernorm.py +++ b/tests/st/ops/ascend/test_tbe_ops/test_layernorm.py @@ -1,55 +1,55 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -import numpy as np -from mindspore.nn import LayerNorm -from mindspore.common.tensor import Tensor -from mindspore.nn import Cell -from mindspore.train.model import Model -from mindspore import log as logger -import pytest -from mindspore import context -context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") - -class Net(Cell): - def __init__(self, input_shape, begin_norm_axis, begin_params_axis, gamma, beta): - super(Net, self).__init__() - self.layernorm = LayerNorm(input_shape, begin_norm_axis, begin_params_axis, gamma, beta) - - def construct(self, input): - x = self.layernorm(input) - return x - -def pt_me_layernorm(input_data, normalized_shape, gamma, beta, axis): - net = Net(normalized_shape, begin_norm_axis=axis, - begin_params_axis=axis, - gamma=Tensor(gamma), - beta=Tensor(beta)) - net.set_train() - model = Model(net) - out_me = model.predict(Tensor(input_data)) - logger.info("Check me result:") - logger.info(out_me.asnumpy()) - -@pytest.mark.lower_bs -def test_normal_layernorm_1_128_1024_axis_2(): - """ - 2 input[1, 128, 1024],normalized_shape=[128, 1024] - """ - input_data = np.random.randn(1, 128, 1024).astype(np.float32) - gamma = np.random.randn(1024).astype(np.float32) - gamma.fill(1.1) - beta = np.random.randn(1024).astype(np.float32) - beta.fill(0.1) - pt_me_layernorm(input_data, (1024, ), gamma, beta, 2) +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np +from mindspore.nn import LayerNorm +from mindspore.common.tensor import Tensor +from mindspore.nn import Cell +from mindspore.train.model import Model +from mindspore import log as logger +import pytest +from mindspore import context +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + +class Net(Cell): + def __init__(self, input_shape, begin_norm_axis, begin_params_axis, gamma, beta): + super(Net, self).__init__() + self.layernorm = LayerNorm(input_shape, begin_norm_axis, begin_params_axis, gamma, beta) + + def construct(self, input): + x = self.layernorm(input) + return x + +def pt_me_layernorm(input_data, normalized_shape, gamma, beta, axis): + net = Net(normalized_shape, begin_norm_axis=axis, + begin_params_axis=axis, + gamma=Tensor(gamma), + beta=Tensor(beta)) + net.set_train() + model = Model(net) + out_me = model.predict(Tensor(input_data)) + logger.info("Check me result:") + logger.info(out_me.asnumpy()) + +@pytest.mark.lower_bs +def test_normal_layernorm_1_128_1024_axis_2(): + """ + 2 input[1, 128, 1024],normalized_shape=[128, 1024] + """ + input_data = np.random.randn(1, 128, 1024).astype(np.float32) + gamma = np.random.randn(1024).astype(np.float32) + gamma.fill(1.1) + beta = np.random.randn(1024).astype(np.float32) + beta.fill(0.1) + pt_me_layernorm(input_data, (1024, ), gamma, beta, 2) diff --git a/tests/st/ops/davinci/test_tbe_ops/test_layernorm_grad.py b/tests/st/ops/ascend/test_tbe_ops/test_layernorm_grad.py similarity index 97% rename from tests/st/ops/davinci/test_tbe_ops/test_layernorm_grad.py rename to tests/st/ops/ascend/test_tbe_ops/test_layernorm_grad.py index 9f8eefdb3f..5ae09886ce 100644 --- a/tests/st/ops/davinci/test_tbe_ops/test_layernorm_grad.py +++ b/tests/st/ops/ascend/test_tbe_ops/test_layernorm_grad.py @@ -1,65 +1,65 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -import numpy as np -from mindspore.nn import LayerNorm -from mindspore.common.tensor import Tensor -from mindspore.nn import Cell -from mindspore.ops.composite import GradOperation -from mindspore import log as logger -from mindspore import context -context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") - -class Grad(Cell): - def __init__(self, network): - super(Grad, self).__init__() - self.grad = GradOperation(name="get_all", get_all=True, sens_param=True) - self.network = network - - def construct(self, input, output_grad,): - gout = self.grad(self.network)(input, output_grad) - return gout - -class Net(Cell): - def __init__(self, input_shape, begin_norm_axis, begin_params_axis, gamma, beta): - super(Net, self).__init__() - self.layernorm = LayerNorm(input_shape, begin_norm_axis, begin_params_axis, gamma, beta) - - def construct(self, input): - x = self.layernorm(input) - return x - -def py_me_layernorm_grad(input_data, normalized_shape, gamma, beta, axis, gradients): - input_me = Tensor(input_data) - net_me = Grad(Net(normalized_shape, begin_norm_axis=axis, - begin_params_axis=axis, - gamma=Tensor(gamma), - beta=Tensor(beta))) - net_me.set_train() - out_pool_grad_me = Tensor(gradients) - out_grad = net_me(input_me, out_pool_grad_me) - logger.info("Check me result:") - logger.info(out_grad.asnumpy()) - -def test_normal_layernorm_grad_normalize_2d(): - """ - 1 input[1, 128, 1024],normalized_shape=[1024],element_affine=False - """ - input_data = np.ones([1, 128, 1024]).astype(np.float32) - gradients = np.ones((1, 128, 1024)).astype(np.float32) - gamma = np.random.randn(1024).astype(np.float32) - gamma.fill(1.1) - beta = np.random.randn(1024).astype(np.float32) - beta.fill(0.1) - py_me_layernorm_grad(input_data, (1024,), gamma, beta, 2, gradients) +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import numpy as np +from mindspore.nn import LayerNorm +from mindspore.common.tensor import Tensor +from mindspore.nn import Cell +from mindspore.ops.composite import GradOperation +from mindspore import log as logger +from mindspore import context +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + +class Grad(Cell): + def __init__(self, network): + super(Grad, self).__init__() + self.grad = GradOperation(name="get_all", get_all=True, sens_param=True) + self.network = network + + def construct(self, input, output_grad,): + gout = self.grad(self.network)(input, output_grad) + return gout + +class Net(Cell): + def __init__(self, input_shape, begin_norm_axis, begin_params_axis, gamma, beta): + super(Net, self).__init__() + self.layernorm = LayerNorm(input_shape, begin_norm_axis, begin_params_axis, gamma, beta) + + def construct(self, input): + x = self.layernorm(input) + return x + +def py_me_layernorm_grad(input_data, normalized_shape, gamma, beta, axis, gradients): + input_me = Tensor(input_data) + net_me = Grad(Net(normalized_shape, begin_norm_axis=axis, + begin_params_axis=axis, + gamma=Tensor(gamma), + beta=Tensor(beta))) + net_me.set_train() + out_pool_grad_me = Tensor(gradients) + out_grad = net_me(input_me, out_pool_grad_me) + logger.info("Check me result:") + logger.info(out_grad.asnumpy()) + +def test_normal_layernorm_grad_normalize_2d(): + """ + 1 input[1, 128, 1024],normalized_shape=[1024],element_affine=False + """ + input_data = np.ones([1, 128, 1024]).astype(np.float32) + gradients = np.ones((1, 128, 1024)).astype(np.float32) + gamma = np.random.randn(1024).astype(np.float32) + gamma.fill(1.1) + beta = np.random.randn(1024).astype(np.float32) + beta.fill(0.1) + py_me_layernorm_grad(input_data, (1024,), gamma, beta, 2, gradients) diff --git a/tests/st/ops/davinci/test_tbe_ops/test_less.py b/tests/st/ops/ascend/test_tbe_ops/test_less.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_less.py rename to tests/st/ops/ascend/test_tbe_ops/test_less.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_less_equal.py b/tests/st/ops/ascend/test_tbe_ops/test_less_equal.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_less_equal.py rename to tests/st/ops/ascend/test_tbe_ops/test_less_equal.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_logical_and.py b/tests/st/ops/ascend/test_tbe_ops/test_logical_and.py similarity index 97% rename from tests/st/ops/davinci/test_tbe_ops/test_logical_and.py rename to tests/st/ops/ascend/test_tbe_ops/test_logical_and.py index c9f180a56e..1df04b27d4 100644 --- a/tests/st/ops/davinci/test_tbe_ops/test_logical_and.py +++ b/tests/st/ops/ascend/test_tbe_ops/test_logical_and.py @@ -1,39 +1,39 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from mindspore import Tensor -from mindspore.ops import operations as P -import mindspore.nn as nn -from mindspore.common.api import ms_function -import numpy as np -import mindspore.context as context -context.set_context(device_target="Ascend") -class Net(nn.Cell): - def __init__(self): - super(Net, self).__init__() - self.logical_and = P.LogicalAnd() - - @ms_function - def construct(self, x1, x2): - return self.logical_and(x1, x2) - -x1 = [True, True, False, False, True, True, False, False] -x2 = [True, False, False, True, True, False, False, True] -def test_net(): - logical_and = Net() - output = logical_and(Tensor(x1), Tensor(x2)) - print(x1) - print(x2) - print(output.asnumpy()) - +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from mindspore import Tensor +from mindspore.ops import operations as P +import mindspore.nn as nn +from mindspore.common.api import ms_function +import numpy as np +import mindspore.context as context +context.set_context(device_target="Ascend") +class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.logical_and = P.LogicalAnd() + + @ms_function + def construct(self, x1, x2): + return self.logical_and(x1, x2) + +x1 = [True, True, False, False, True, True, False, False] +x2 = [True, False, False, True, True, False, False, True] +def test_net(): + logical_and = Net() + output = logical_and(Tensor(x1), Tensor(x2)) + print(x1) + print(x2) + print(output.asnumpy()) + diff --git a/tests/st/ops/davinci/test_tbe_ops/test_logical_not.py b/tests/st/ops/ascend/test_tbe_ops/test_logical_not.py similarity index 97% rename from tests/st/ops/davinci/test_tbe_ops/test_logical_not.py rename to tests/st/ops/ascend/test_tbe_ops/test_logical_not.py index 97e9caa5c9..5d13a48138 100644 --- a/tests/st/ops/davinci/test_tbe_ops/test_logical_not.py +++ b/tests/st/ops/ascend/test_tbe_ops/test_logical_not.py @@ -1,38 +1,38 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from mindspore import Tensor -from mindspore.ops import operations as P -import mindspore.nn as nn -from mindspore.common.api import ms_function -import numpy as np -import mindspore.context as context -context.set_context(device_target="Ascend") -class Net(nn.Cell): - def __init__(self): - super(Net, self).__init__() - self.logical_not = P.LogicalNot() - - @ms_function - def construct(self, x1): - return self.logical_not(x1) - -x1 = [True, True, False, False, True, True, False, False] - -def test_net(): - logical_not = Net() - output = logical_not(Tensor(x1)) - print(x1) - print(output.asnumpy()) - +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from mindspore import Tensor +from mindspore.ops import operations as P +import mindspore.nn as nn +from mindspore.common.api import ms_function +import numpy as np +import mindspore.context as context +context.set_context(device_target="Ascend") +class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.logical_not = P.LogicalNot() + + @ms_function + def construct(self, x1): + return self.logical_not(x1) + +x1 = [True, True, False, False, True, True, False, False] + +def test_net(): + logical_not = Net() + output = logical_not(Tensor(x1)) + print(x1) + print(output.asnumpy()) + diff --git a/tests/st/ops/davinci/test_tbe_ops/test_logical_or.py b/tests/st/ops/ascend/test_tbe_ops/test_logical_or.py similarity index 97% rename from tests/st/ops/davinci/test_tbe_ops/test_logical_or.py rename to tests/st/ops/ascend/test_tbe_ops/test_logical_or.py index e34d94c3e7..a2b7841c71 100644 --- a/tests/st/ops/davinci/test_tbe_ops/test_logical_or.py +++ b/tests/st/ops/ascend/test_tbe_ops/test_logical_or.py @@ -1,39 +1,39 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -from mindspore import Tensor -from mindspore.ops import operations as P -import mindspore.nn as nn -from mindspore.common.api import ms_function -import numpy as np -import mindspore.context as context -context.set_context(device_target="Ascend") -class Net(nn.Cell): - def __init__(self): - super(Net, self).__init__() - self.logical_or = P.LogicalOr() - - @ms_function - def construct(self, x1, x2): - return self.logical_or(x1, x2) - -x1 = [True, True, False, False, True, True, False, False] -x2 = [True, False, False, True, True, False, False, True] -def test_net(): - logical_or = Net() - output = logical_or(Tensor(x1), Tensor(x2)) - print(x1) - print(x2) - print(output.asnumpy()) - +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from mindspore import Tensor +from mindspore.ops import operations as P +import mindspore.nn as nn +from mindspore.common.api import ms_function +import numpy as np +import mindspore.context as context +context.set_context(device_target="Ascend") +class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.logical_or = P.LogicalOr() + + @ms_function + def construct(self, x1, x2): + return self.logical_or(x1, x2) + +x1 = [True, True, False, False, True, True, False, False] +x2 = [True, False, False, True, True, False, False, True] +def test_net(): + logical_or = Net() + output = logical_or(Tensor(x1), Tensor(x2)) + print(x1) + print(x2) + print(output.asnumpy()) + diff --git a/tests/st/ops/davinci/test_tbe_ops/test_matmul.py b/tests/st/ops/ascend/test_tbe_ops/test_matmul.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_matmul.py rename to tests/st/ops/ascend/test_tbe_ops/test_matmul.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_matmul_failed.py b/tests/st/ops/ascend/test_tbe_ops/test_matmul_failed.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_matmul_failed.py rename to tests/st/ops/ascend/test_tbe_ops/test_matmul_failed.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_maximum.py b/tests/st/ops/ascend/test_tbe_ops/test_maximum.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_maximum.py rename to tests/st/ops/ascend/test_tbe_ops/test_maximum.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_maximum_grad.py b/tests/st/ops/ascend/test_tbe_ops/test_maximum_grad.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_maximum_grad.py rename to tests/st/ops/ascend/test_tbe_ops/test_maximum_grad.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_maxpool.py b/tests/st/ops/ascend/test_tbe_ops/test_maxpool.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_maxpool.py rename to tests/st/ops/ascend/test_tbe_ops/test_maxpool.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_maxpool_grad.py b/tests/st/ops/ascend/test_tbe_ops/test_maxpool_grad.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_maxpool_grad.py rename to tests/st/ops/ascend/test_tbe_ops/test_maxpool_grad.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_minimum.py b/tests/st/ops/ascend/test_tbe_ops/test_minimum.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_minimum.py rename to tests/st/ops/ascend/test_tbe_ops/test_minimum.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_minimum_grad.py b/tests/st/ops/ascend/test_tbe_ops/test_minimum_grad.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_minimum_grad.py rename to tests/st/ops/ascend/test_tbe_ops/test_minimum_grad.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_mul.py b/tests/st/ops/ascend/test_tbe_ops/test_mul.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_mul.py rename to tests/st/ops/ascend/test_tbe_ops/test_mul.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_npu_alloc_float_status.py b/tests/st/ops/ascend/test_tbe_ops/test_npu_alloc_float_status.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_npu_alloc_float_status.py rename to tests/st/ops/ascend/test_tbe_ops/test_npu_alloc_float_status.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_npu_clear_float_status.py b/tests/st/ops/ascend/test_tbe_ops/test_npu_clear_float_status.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_npu_clear_float_status.py rename to tests/st/ops/ascend/test_tbe_ops/test_npu_clear_float_status.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_npu_get_float_status.py b/tests/st/ops/ascend/test_tbe_ops/test_npu_get_float_status.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_npu_get_float_status.py rename to tests/st/ops/ascend/test_tbe_ops/test_npu_get_float_status.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_pad.py b/tests/st/ops/ascend/test_tbe_ops/test_pad.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_pad.py rename to tests/st/ops/ascend/test_tbe_ops/test_pad.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_pow.py b/tests/st/ops/ascend/test_tbe_ops/test_pow.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_pow.py rename to tests/st/ops/ascend/test_tbe_ops/test_pow.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_realdiv.py b/tests/st/ops/ascend/test_tbe_ops/test_realdiv.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_realdiv.py rename to tests/st/ops/ascend/test_tbe_ops/test_realdiv.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_reciprocal.py b/tests/st/ops/ascend/test_tbe_ops/test_reciprocal.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_reciprocal.py rename to tests/st/ops/ascend/test_tbe_ops/test_reciprocal.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_relu.py b/tests/st/ops/ascend/test_tbe_ops/test_relu.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_relu.py rename to tests/st/ops/ascend/test_tbe_ops/test_relu.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_relu_grad.py b/tests/st/ops/ascend/test_tbe_ops/test_relu_grad.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_relu_grad.py rename to tests/st/ops/ascend/test_tbe_ops/test_relu_grad.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_resize_nearest_neighbor.py b/tests/st/ops/ascend/test_tbe_ops/test_resize_nearest_neighbor.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_resize_nearest_neighbor.py rename to tests/st/ops/ascend/test_tbe_ops/test_resize_nearest_neighbor.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_resize_nearest_neighbor_grad.py b/tests/st/ops/ascend/test_tbe_ops/test_resize_nearest_neighbor_grad.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_resize_nearest_neighbor_grad.py rename to tests/st/ops/ascend/test_tbe_ops/test_resize_nearest_neighbor_grad.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_scatter_nd.py b/tests/st/ops/ascend/test_tbe_ops/test_scatter_nd.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_scatter_nd.py rename to tests/st/ops/ascend/test_tbe_ops/test_scatter_nd.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_select.py b/tests/st/ops/ascend/test_tbe_ops/test_select.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_select.py rename to tests/st/ops/ascend/test_tbe_ops/test_select.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_sigmoid.py b/tests/st/ops/ascend/test_tbe_ops/test_sigmoid.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_sigmoid.py rename to tests/st/ops/ascend/test_tbe_ops/test_sigmoid.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_sigmoid_cross_entropy_with_logits.py b/tests/st/ops/ascend/test_tbe_ops/test_sigmoid_cross_entropy_with_logits.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_sigmoid_cross_entropy_with_logits.py rename to tests/st/ops/ascend/test_tbe_ops/test_sigmoid_cross_entropy_with_logits.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_sigmoid_cross_entropy_with_logits_grad.py b/tests/st/ops/ascend/test_tbe_ops/test_sigmoid_cross_entropy_with_logits_grad.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_sigmoid_cross_entropy_with_logits_grad.py rename to tests/st/ops/ascend/test_tbe_ops/test_sigmoid_cross_entropy_with_logits_grad.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_sigmoid_grad.py b/tests/st/ops/ascend/test_tbe_ops/test_sigmoid_grad.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_sigmoid_grad.py rename to tests/st/ops/ascend/test_tbe_ops/test_sigmoid_grad.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_slice.py b/tests/st/ops/ascend/test_tbe_ops/test_slice.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_slice.py rename to tests/st/ops/ascend/test_tbe_ops/test_slice.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_smooth_l1_loss.py b/tests/st/ops/ascend/test_tbe_ops/test_smooth_l1_loss.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_smooth_l1_loss.py rename to tests/st/ops/ascend/test_tbe_ops/test_smooth_l1_loss.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_smooth_l1_loss_grad.py b/tests/st/ops/ascend/test_tbe_ops/test_smooth_l1_loss_grad.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_smooth_l1_loss_grad.py rename to tests/st/ops/ascend/test_tbe_ops/test_smooth_l1_loss_grad.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_softmax.py b/tests/st/ops/ascend/test_tbe_ops/test_softmax.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_softmax.py rename to tests/st/ops/ascend/test_tbe_ops/test_softmax.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_softmax_cross_entropy_with_logits.py b/tests/st/ops/ascend/test_tbe_ops/test_softmax_cross_entropy_with_logits.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_softmax_cross_entropy_with_logits.py rename to tests/st/ops/ascend/test_tbe_ops/test_softmax_cross_entropy_with_logits.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_split.py b/tests/st/ops/ascend/test_tbe_ops/test_split.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_split.py rename to tests/st/ops/ascend/test_tbe_ops/test_split.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_sqrt.py b/tests/st/ops/ascend/test_tbe_ops/test_sqrt.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_sqrt.py rename to tests/st/ops/ascend/test_tbe_ops/test_sqrt.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_square.py b/tests/st/ops/ascend/test_tbe_ops/test_square.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_square.py rename to tests/st/ops/ascend/test_tbe_ops/test_square.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_stridedslice.py b/tests/st/ops/ascend/test_tbe_ops/test_stridedslice.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_stridedslice.py rename to tests/st/ops/ascend/test_tbe_ops/test_stridedslice.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_stridedslice_grad.py b/tests/st/ops/ascend/test_tbe_ops/test_stridedslice_grad.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_stridedslice_grad.py rename to tests/st/ops/ascend/test_tbe_ops/test_stridedslice_grad.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_sub.py b/tests/st/ops/ascend/test_tbe_ops/test_sub.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_sub.py rename to tests/st/ops/ascend/test_tbe_ops/test_sub.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_tanh.py b/tests/st/ops/ascend/test_tbe_ops/test_tanh.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_tanh.py rename to tests/st/ops/ascend/test_tbe_ops/test_tanh.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_tanh_grad.py b/tests/st/ops/ascend/test_tbe_ops/test_tanh_grad.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_tanh_grad.py rename to tests/st/ops/ascend/test_tbe_ops/test_tanh_grad.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_tile.py b/tests/st/ops/ascend/test_tbe_ops/test_tile.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_tile.py rename to tests/st/ops/ascend/test_tbe_ops/test_tile.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_topk.py b/tests/st/ops/ascend/test_tbe_ops/test_topk.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_topk.py rename to tests/st/ops/ascend/test_tbe_ops/test_topk.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_transpose_d.py b/tests/st/ops/ascend/test_tbe_ops/test_transpose_d.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_transpose_d.py rename to tests/st/ops/ascend/test_tbe_ops/test_transpose_d.py diff --git a/tests/st/ops/davinci/test_tbe_ops/test_unsorted_segment_sum.py b/tests/st/ops/ascend/test_tbe_ops/test_unsorted_segment_sum.py similarity index 100% rename from tests/st/ops/davinci/test_tbe_ops/test_unsorted_segment_sum.py rename to tests/st/ops/ascend/test_tbe_ops/test_unsorted_segment_sum.py diff --git a/tests/st/ops/davinci/test_tdt_data_ms.py b/tests/st/ops/ascend/test_tdt_data_ms.py similarity index 100% rename from tests/st/ops/davinci/test_tdt_data_ms.py rename to tests/st/ops/ascend/test_tdt_data_ms.py From b902f485a66f199830a16afb5299ebb99011b08c Mon Sep 17 00:00:00 2001 From: dengwentao Date: Mon, 20 Apr 2020 16:10:23 +0800 Subject: [PATCH 100/142] check tbe attr required --- mindspore/ccsrc/kernel/tbe/tbe_kernel_build.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mindspore/ccsrc/kernel/tbe/tbe_kernel_build.cc b/mindspore/ccsrc/kernel/tbe/tbe_kernel_build.cc index 496f99df1c..d2ad014e5b 100644 --- a/mindspore/ccsrc/kernel/tbe/tbe_kernel_build.cc +++ b/mindspore/ccsrc/kernel/tbe/tbe_kernel_build.cc @@ -383,6 +383,10 @@ bool TbeKernelJsonCreator::GenTbeAttrJson(const std::shared_ptr &anf_no attr_obj["name"] = attr_name; attr_obj["valid"] = true; (*attrs_json).push_back(attr_obj); + } else { + if (attr_ptr->param_type() == "required" && creater_type_ == SINGLE_BUILD && op_info->impl_path() != "") { + MS_LOG(EXCEPTION) << "op name: " << op_info->op_name() << " attr: " << attr_name << "is required, but not set."; + } } } return true; From f38d18c6657e2046984e22e9df0038e8d84c4cf1 Mon Sep 17 00:00:00 2001 From: "wangnan39@huawei.com" Date: Wed, 22 Apr 2020 12:06:05 +0800 Subject: [PATCH 101/142] fix bug in checkpoint when save scaler --- mindspore/nn/wrap/cell_wrapper.py | 2 +- mindspore/train/serialization.py | 7 ++-- tests/ut/python/nn/test_parameter.py | 57 ++++++++++++++++++++++++++++ 3 files changed, 62 insertions(+), 4 deletions(-) diff --git a/mindspore/nn/wrap/cell_wrapper.py b/mindspore/nn/wrap/cell_wrapper.py index 6c88b7d957..641558921a 100644 --- a/mindspore/nn/wrap/cell_wrapper.py +++ b/mindspore/nn/wrap/cell_wrapper.py @@ -344,5 +344,5 @@ class ParameterUpdate(Cell): self._param = param def construct(self, x): - self._param = x + F.assign(self._param, x) return x diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py index e933d40666..49cc5318fa 100644 --- a/mindspore/train/serialization.py +++ b/mindspore/train/serialization.py @@ -408,10 +408,11 @@ def _fill_param_into_net(net, parameter_list): for each_param in parameter_list: param_name = each_param["name"] np_val = each_param["data"].asnumpy() - if np_val.shape == (1,): # to scalar - parameter_dict[param_name] = Parameter(np_val[0], name=param_name) + if np_val.shape == (1,): + parameter_dict[param_name] = Parameter(np_val, name=param_name) elif np_val.shape == (): - parameter_dict[param_name] = Parameter(np_val.tolist(), name=param_name) + parameter_dict[param_name] = Parameter(Tensor(np_val.tolist(), mstype.pytype_to_dtype(np_val.dtype)), + name=param_name) else: parameter_dict[param_name] = Parameter(Tensor(np_val), name=param_name) diff --git a/tests/ut/python/nn/test_parameter.py b/tests/ut/python/nn/test_parameter.py index 529af532f7..d6bc40ba02 100644 --- a/tests/ut/python/nn/test_parameter.py +++ b/tests/ut/python/nn/test_parameter.py @@ -52,12 +52,69 @@ def test_parameter_tuple_illegal(): def test_parameter_init_illegal(): + import numpy as np + dat = np.array([[1, 2, 3], [2, 3, 4]]) + tensor = Tensor(dat) + data_none = None data_bool = True data_str = "nicai" + data_int = 3 + data_list = [1, "2", True] + data_tuple = (1, 2, 3) + + # test data + Parameter(tensor, name=data_str) + Parameter(data_int, name=data_str) + Parameter(dat, name=data_str) with pytest.raises(ValueError): Parameter(data_bool, name=data_str) + # test name + Parameter(tensor, name=data_none) + with pytest.raises(ValueError): + Parameter(tensor, name=dat) + with pytest.raises(ValueError): + Parameter(tensor, name=tensor) + with pytest.raises(ValueError): + Parameter(tensor, name=data_bool) + with pytest.raises(ValueError): + Parameter(tensor, name=data_int) + with pytest.raises(ValueError): + Parameter(tensor, name=data_list) + with pytest.raises(ValueError): + Parameter(tensor, name=data_tuple) + + Parameter(tensor, name=data_str, requires_grad=data_bool) + with pytest.raises(TypeError): + Parameter(tensor, name=data_str, requires_grad=data_none) + with pytest.raises(TypeError): + Parameter(tensor, name=data_str, requires_grad=dat) + with pytest.raises(TypeError): + Parameter(tensor, name=data_str, requires_grad=tensor) + with pytest.raises(TypeError): + Parameter(tensor, name=data_str, requires_grad=data_str) + with pytest.raises(TypeError): + Parameter(tensor, name=data_str, requires_grad=data_int) + with pytest.raises(TypeError): + Parameter(tensor, name=data_str, requires_grad=data_list) + with pytest.raises(TypeError): + Parameter(tensor, name=data_str, requires_grad=data_tuple) + Parameter(tensor, name=data_str, requires_grad=data_bool,layerwise_parallel=data_bool) + with pytest.raises(TypeError): + Parameter(tensor, name=data_str, requires_grad=data_bool,layerwise_parallel=dat) + with pytest.raises(TypeError): + Parameter(tensor, name=data_str, requires_grad=data_bool,layerwise_parallel=tensor) + with pytest.raises(TypeError): + Parameter(tensor, name=data_str, requires_grad=data_bool,layerwise_parallel=data_none) + with pytest.raises(TypeError): + Parameter(tensor, name=data_str, requires_grad=data_bool,layerwise_parallel=data_str) + with pytest.raises(TypeError): + Parameter(tensor, name=data_str, requires_grad=data_bool,layerwise_parallel=data_int) + with pytest.raises(TypeError): + Parameter(tensor, name=data_str, requires_grad=data_bool,layerwise_parallel=data_list) + with pytest.raises(TypeError): + Parameter(tensor, name=data_str, requires_grad=data_bool,layerwise_parallel=data_tuple) def test_check_str_by_regular(): From 146ac1263e068c7be39bbe8057acab2893882d5c Mon Sep 17 00:00:00 2001 From: YuJianfeng Date: Wed, 22 Apr 2020 10:40:59 +0800 Subject: [PATCH 102/142] Overlength functions rectification --- .../ascend/ascend_backend_optimization.cc | 53 +++++++++++-------- .../ir_fusion/parameter_and_transop_fusion.cc | 31 +++++++---- 2 files changed, 50 insertions(+), 34 deletions(-) diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc index 496a9b276f..a2d82525e9 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc @@ -70,6 +70,35 @@ namespace mindspore { namespace opt { +namespace { +void AddAscendBackendOptionalIRFusion(PassManager *ir_fusion_pm) { + MS_EXCEPTION_IF_NULL(ir_fusion_pm); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); +} +} // namespace + void RunOpAscendDataLayout(const std::shared_ptr &kernel_graph) { MS_EXCEPTION_IF_NULL(kernel_graph); auto optimizer = std::make_shared(); @@ -164,29 +193,7 @@ void AscendBackendIRFusionOptimization(const std::shared_ptrAddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); if (context_ptr->ir_fusion_flag()) { - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); - ir_fusion_pm->AddPass(std::make_shared()); + AddAscendBackendOptionalIRFusion(ir_fusion_pm.get()); } if (context_ptr->enable_task_sink() && context_ptr->loop_sink_flag() && ConfigManager::GetInstance().iter_num() > 1) { diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/parameter_and_transop_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/parameter_and_transop_fusion.cc index faa1308f8b..fe9b35a5e9 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/parameter_and_transop_fusion.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/parameter_and_transop_fusion.cc @@ -26,6 +26,7 @@ namespace mindspore { namespace opt { +namespace { const AnfNodePtr ParamTransRoad(const FuncGraphPtr &func_graph, const AnfNodePtr &node, bool first_flag, std::vector *trans_road) { if (node == nullptr) { @@ -59,6 +60,24 @@ const AnfNodePtr ParamTransRoad(const FuncGraphPtr &func_graph, const AnfNodePtr return nullptr; } +kernel::KernelBuildInfoPtr GetKernelBuildInfo(const CNodePtr &cast, const string &format, TypeId input_type, + TypeId output_type) { + MS_EXCEPTION_IF_NULL(cast); + auto kernel_info = cast->kernel_info(); + MS_EXCEPTION_IF_NULL(kernel_info); + auto cast_build_info = kernel_info->select_kernel_build_info(); + MS_EXCEPTION_IF_NULL(cast_build_info); + kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; + builder.SetOutputsFormat({format}); + builder.SetInputsFormat({format}); + builder.SetInputsDeviceType({input_type}); + builder.SetOutputsDeviceType({output_type}); + builder.SetKernelType(cast_build_info->kernel_type()); + builder.SetFusionType(cast_build_info->fusion_type()); + builder.SetProcessor(cast_build_info->processor()); + return builder.Build(); +} +} // namespace bool ParameterTransOpFusion::Run(const FuncGraphPtr &func_graph) { if (func_graph == nullptr) { MS_LOG(ERROR) << "Func graph is nullptr"; @@ -95,17 +114,7 @@ bool ParameterTransOpFusion::Run(const FuncGraphPtr &func_graph) { auto param_dtype = AnfAlgo::GetOutputDeviceDataType(final_node, 0); auto cast = trans_road[1]; - auto cast_format = AnfAlgo::GetOutputFormat(cast, 0); - auto cast_build_info = cast->kernel_info()->select_kernel_build_info(); - kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; - builder.SetOutputsFormat({format}); - builder.SetInputsFormat({format}); - builder.SetInputsDeviceType({param_dtype}); - builder.SetOutputsDeviceType({dtype}); - builder.SetKernelType(cast_build_info->kernel_type()); - builder.SetFusionType(cast_build_info->fusion_type()); - builder.SetProcessor(cast_build_info->processor()); - AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cast.get()); + AnfAlgo::SetSelectKernelBuildInfo(GetKernelBuildInfo(cast, format, param_dtype, dtype), cast.get()); if (param_format == format && param_dtype != dtype) { manager->Replace(trans_road[2], final_node); manager->Replace(cur_transop, cast); From f3bf699b8db970873b5666f0dcdf4e3096149047 Mon Sep 17 00:00:00 2001 From: jiangjinsheng Date: Wed, 22 Apr 2020 16:39:56 +0800 Subject: [PATCH 103/142] add example for maxpool --- mindspore/ops/operations/nn_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 23ddd9f021..61a7d6e833 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -915,7 +915,7 @@ class MaxPool(_Pool): Tensor, with shape :math:`(N, C_{out}, H_{out}, W_{out})`. Examples: - >>> input_tensor = Tensor(np.arange(1*3*3*4).reshape((1,3,3,4)),mindspore.float32) + >>> input_tensor = Tensor(np.arange(1 * 3 * 3 * 4).reshape((1, 3, 3, 4)), mindspore.float32) >>> maxpool_op = P.MaxPool(padding="VALID", ksize=2, strides=1) >>> output_tensor = maxpool_op(input_tensor) """ @@ -966,7 +966,7 @@ class MaxPoolWithArgmax(_Pool): - **mask** (Tensor) - Max values' index represented by the mask. Examples: - >>> input_tensor = Tensor(np.arange(1*3*3*4).reshape((1,3,3,4)),mindspore.float32) + >>> input_tensor = Tensor(np.arange(1 * 3 * 3 * 4).reshape((1, 3, 3, 4)), mindspore.float32) >>> maxpool_arg_op = P.MaxPoolWithArgmax(padding="VALID", ksize=2, strides=1) >>> output_tensor, argmax = maxpool_arg_op(input_tensor) """ From ba43dbc148fb0b0b4719cf59aee5752fa972f7f8 Mon Sep 17 00:00:00 2001 From: leonwanghui Date: Wed, 22 Apr 2020 16:44:19 +0800 Subject: [PATCH 104/142] Fix pylint warnings in mindspore st test module --- .../st/auto_parallel/onehot_model_parallel.py | 38 ++++---- .../soft_entropy_loss_expand_parallel.py | 90 +++++++++++-------- tests/st/auto_parallel/test_expand_loss.py | 4 +- .../test_model_parallel_onehot.py | 3 +- .../test_resnet50_expand_loss_2p.py | 27 +++--- tests/st/control/test_while.py | 14 +-- .../st/fusion/test_add_relu_buffer_fusion.py | 20 ++--- tests/st/fusion/test_conv_bn1_fusion.py | 25 +++--- tests/st/fusion/test_tbe_eltwise_fusion_1.py | 15 ++-- tests/st/fusion/test_tbe_eltwise_fusion_2.py | 17 ++-- .../test_tbe_multi_inout_eltwise_fusion.py | 14 ++- .../fusion/test_tbe_reduce_eltwise_fusion.py | 17 ++-- tests/st/mem_reuse/check_file.py | 3 +- tests/st/mem_reuse/resnet.py | 2 +- tests/st/mem_reuse/resnet_cifar_memreuse.py | 25 +++--- tests/st/mem_reuse/resnet_cifar_normal.py | 12 +-- tests/st/nccl/test_nccl_all.py | 4 + tests/st/nccl/test_nccl_all_gather_op.py | 13 +-- tests/st/nccl/test_nccl_all_reduce_op.py | 19 ++-- tests/st/nccl/test_nccl_lenet.py | 23 ++--- tests/st/nccl/test_nccl_reduce_scatter_op.py | 12 +-- tests/st/networks/models/alexnet.py | 1 + .../models/bert/bert_tdt_no_lossscale.py | 16 ++-- tests/st/networks/models/lenet.py | 3 +- tests/st/networks/models/resnetv1_5.py | 17 ++-- tests/st/networks/test_cpu_lenet.py | 13 ++- tests/st/networks/test_gpu_alexnet.py | 13 +-- tests/st/networks/test_gpu_lenet.py | 10 ++- tests/st/networks/test_gpu_lstm.py | 12 ++- tests/st/networks/test_gpu_resnet.py | 17 ++-- tests/st/networks/test_network_main.py | 22 +++-- tests/st/pynative/test_ascend_lenet.py | 8 +- tests/st/summary/test_davinci_summary.py | 7 +- tests/st/summary/test_gpu_summary.py | 2 +- tests/st/tbe_networks/export_geir.py | 11 +-- tests/st/tbe_networks/resnet.py | 2 +- tests/st/tbe_networks/resnet_cifar.py | 16 ++-- tests/st/tbe_networks/test_resnet_cifar_8p.py | 2 + 38 files changed, 317 insertions(+), 252 deletions(-) diff --git a/tests/st/auto_parallel/onehot_model_parallel.py b/tests/st/auto_parallel/onehot_model_parallel.py index 14b351c0ee..1f35ac1f80 100644 --- a/tests/st/auto_parallel/onehot_model_parallel.py +++ b/tests/st/auto_parallel/onehot_model_parallel.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# ============================================================================ import os import pytest @@ -26,6 +27,7 @@ device_num = 2 device_id = int(os.getenv('DEVICE_ID')) rank_id = 0 + def setup_module(): global device_num global rank_id @@ -42,9 +44,11 @@ def setup_module(): context.set_auto_parallel_context(device_num=device_num, global_rank=rank_id) + def teardown_module(): distributedTool.release() + class Onehot(Cell): def __init__(self, axis=-1, depth=1, on_value=1.0, off_value=0.0, strategy=None): super(Onehot, self).__init__() @@ -56,25 +60,26 @@ class Onehot(Cell): self.on_value = Tensor(on_value, ms.float32) self.off_value = Tensor(off_value, ms.float32) self.transpose = P.Transpose().set_strategy(strategy=trans_stra) - self.sub = P.Sub().set_strategy(strategy=((1,1),(1,1))) + self.sub = P.Sub().set_strategy(strategy=((1, 1), (1, 1))) def construct(self, input, indices): x = self.onehot(indices, self.depth, self.on_value, self.off_value) - x = self.transpose(x, (1,0)) + x = self.transpose(x, (1, 0)) x = self.sub(input, x) return x + class DataGenerator(): def get_parallel_blocks(self, input_, strategy): blocks = [input_] i = 0 for stra in strategy: temp = [] - while len(blocks)>0: + while len(blocks) > 0: block = blocks.pop(0) temp.extend(np.split(block, stra, axis=i)) blocks.extend(temp) - i+=1 + i += 1 return blocks def generate_data(self, shape): @@ -93,32 +98,33 @@ class DataGenerator(): stra = [1]*len(shape) stra[0] = device_num datas = self.get_parallel_blocks(data, stra) - return Tensor(data),Tensor(datas[rank_id]) + return Tensor(data), Tensor(datas[rank_id]) + class OneHotFactory: def __init__(self, batch_size, classes, on_value=1.0, off_value=0.0, axis=None, strategy=None): dataGen = DataGenerator() self.input_full, self.input_part = dataGen.input_data((classes, batch_size)) - self.label_full, self.label_part = dataGen.label_data((batch_size,),classes) + self.label_full, self.label_part = dataGen.label_data((batch_size,), classes) self.depth = classes self.on_value = on_value self.off_value = off_value self.axis = axis self.strategy = strategy - + def forward_mindspore_single_impl(self): - net = Onehot(axis=self.axis, - depth=self.depth, - on_value=self.on_value, + net = Onehot(axis=self.axis, + depth=self.depth, + on_value=self.on_value, off_value=self.off_value) out = net(self.input_full, self.label_full) return out - + def forward_mindspore_parallel_impl(self): context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") - net = Onehot(axis=self.axis, - depth=self.depth, - on_value=self.on_value, + net = Onehot(axis=self.axis, + depth=self.depth, + on_value=self.on_value, off_value=self.off_value, strategy=self.strategy) out = net.compile_and_run(self.input_full, self.label_full) return out @@ -137,7 +143,7 @@ def test_reid_onehot_forward_int32_128_depth1024_model_parallel(): on_value=1.000000, off_value=0.000000, axis=-1, - strategy=((1,device_num),(),())) + strategy=((1, device_num), (), ())) fact.forward_cmp() @@ -147,5 +153,5 @@ def test_reid_onehot_forward_int32_1024_depth128_model_parallel(): on_value=1.000000, off_value=0.000000, axis=-1, - strategy=((1,device_num),(),())) + strategy=((1, device_num), (), ())) fact.forward_cmp() diff --git a/tests/st/auto_parallel/soft_entropy_loss_expand_parallel.py b/tests/st/auto_parallel/soft_entropy_loss_expand_parallel.py index 17dbe8f304..86a8b89521 100644 --- a/tests/st/auto_parallel/soft_entropy_loss_expand_parallel.py +++ b/tests/st/auto_parallel/soft_entropy_loss_expand_parallel.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# ============================================================================ import os import pytest @@ -31,7 +32,7 @@ from mindspore.nn.optim.momentum import Momentum from mindspore.train.callback import Callback np.set_printoptions(threshold=np.inf) -device_num=2 +device_num = 2 device_id = int(os.getenv('DEVICE_ID')) rank_id = 0 embed = 128 @@ -39,6 +40,7 @@ classes = 32 batch_size = 32*2 MatmulParamShape = (classes, embed) + def setup_module(): global device_num global rank_id @@ -55,26 +57,28 @@ def setup_module(): context.set_auto_parallel_context(device_num=device_num, global_rank=device_id) + def teardown_module(): distributedTool.release() + class DataGenerator(): def get_parallel_blocks(self, input_, strategy): blocks = [input_] i = 0 for stra in strategy: temp = [] - while len(blocks)>0: + while len(blocks) > 0: block = blocks.pop(0) temp.extend(np.split(block, stra, axis=i)) blocks.extend(temp) - i+=1 + i += 1 return blocks def generate_data(self, shape): size = np.cumprod(shape)[-1] num_range = min(size, 1000) - data = (np.arange(0, size)%num_range)/num_range + data = (np.arange(0, size) % num_range)/num_range data = np.reshape(data, shape) return data @@ -83,14 +87,15 @@ class DataGenerator(): stra = [1]*len(shape) stra[0] = device_num datas = self.get_parallel_blocks(data, stra) - return Tensor(data), Tensor(datas[rank_id]) + return Tensor(data), Tensor(datas[rank_id]) def label_data(self, shape, embed): data = (self.generate_data(shape)*(embed-1)).astype(np.int32) stra = [1]*len(shape) stra[0] = device_num datas = self.get_parallel_blocks(data, stra) - return Tensor(data),Tensor(datas[rank_id]) + return Tensor(data), Tensor(datas[rank_id]) + class Dataset(): def __init__(self, predict, label, length=1, input_num=2): @@ -121,15 +126,18 @@ class Dataset(): def get_repeat_count(self): return self.length + class ModelCallback(Callback): def __init__(self): super(ModelCallback, self).__init__() self.loss_list = [] + def epoch_end(self, run_context, *args): cb_params = run_context.original_args() result = cb_params.net_outputs self.loss_list.append(result.asnumpy().mean()) + class SoftmaxCrossEntropyExpand(Cell): def __init__(self, sparse=False, stra_list=[]): super(SoftmaxCrossEntropyExpand, self).__init__() @@ -164,22 +172,25 @@ class SoftmaxCrossEntropyExpand(Cell): loss = self.reduce_mean(loss, -1) return loss + class MatmulNet(Cell): - def __init__(self, matmul_stra = None, loss_stra_list=[]): + def __init__(self, matmul_stra=None, loss_stra_list=[]): super(MatmulNet, self).__init__() self.matmul = P.MatMul(transpose_b=True).set_strategy(strategy=matmul_stra) self.loss = SoftmaxCrossEntropyExpand(sparse=True, stra_list=loss_stra_list) - self.weight = Parameter(Tensor(np.ones(MatmulParamShape), dtype=ms.float32), name="weight") + self.weight = Parameter(Tensor(np.ones(MatmulParamShape), dtype=ms.float32), name="weight") + def construct(self, x, label): loss_input = self.matmul(x, self.weight) out = self.loss(loss_input, label) return out + class LossFactory(): def __init__(self): dataGen = DataGenerator() self.input_full, self.input_part = dataGen.input_data((batch_size, embed)) - self.label_full, self.label_part = dataGen.label_data((batch_size,),embed) + self.label_full, self.label_part = dataGen.label_data((batch_size,), embed) def single_matmul_trains(self): single_callback = ModelCallback() @@ -196,32 +207,33 @@ class LossFactory(): parallel_callback = ModelCallback() context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") net = MatmulNet() - optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) + optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) model = Model(net, optimizer=optimizer) epoch_size = 6 dataset = Dataset(self.input_part, self.label_part) model.train(epoch_size, dataset, callbacks=parallel_callback, dataset_sink_mode=False) loss_value = np.array(parallel_callback.loss_list) return loss_value - + def model_parallel_matmul_trains(self): parallel_callback = ModelCallback() - matmul_stra = ((1,1),(device_num,1)) - reduce_max_stra = ((1,device_num),) - sub_stra = ((1,device_num),(1,1)) - exp_stra = ((1,device_num),) - reduce_sum_stra = ((1,device_num),) - div_stra = ((1,device_num),(1,1)) - log_stra = ((1,device_num),) - mul_stra = ((1,device_num),(1,device_num)) - sum_cross_entropy_stra = ((1,device_num),) - mul2_stra = ((),(device_num,)) + matmul_stra = ((1, 1), (device_num, 1)) + reduce_max_stra = ((1, device_num),) + sub_stra = ((1, device_num), (1, 1)) + exp_stra = ((1, device_num),) + reduce_sum_stra = ((1, device_num),) + div_stra = ((1, device_num), (1, 1)) + log_stra = ((1, device_num),) + mul_stra = ((1, device_num), (1, device_num)) + sum_cross_entropy_stra = ((1, device_num),) + mul2_stra = ((), (device_num,)) reduce_mean_stra = ((device_num,),) - onehot_stra = ((1,device_num),(),()) - loss_stra_list = [exp_stra, reduce_sum_stra, onehot_stra, div_stra, log_stra, sum_cross_entropy_stra, mul_stra, mul2_stra, reduce_mean_stra, reduce_max_stra, sub_stra] + onehot_stra = ((1, device_num), (), ()) + loss_stra_list = [exp_stra, reduce_sum_stra, onehot_stra, div_stra, log_stra, + sum_cross_entropy_stra, mul_stra, mul2_stra, reduce_mean_stra, reduce_max_stra, sub_stra] context.set_auto_parallel_context(parallel_mode="auto_parallel") - net = MatmulNet(matmul_stra = matmul_stra, loss_stra_list = loss_stra_list) - optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) + net = MatmulNet(matmul_stra=matmul_stra, loss_stra_list=loss_stra_list) + optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) model = Model(net, optimizer=optimizer) epoch_size = 6 dataset = Dataset(self.input_part, self.label_part) @@ -231,21 +243,22 @@ class LossFactory(): def mix_parallel_matmul_trains(self): parallel_callback = ModelCallback() - matmul_stra = ((device_num,1),(1,1)) - reduce_max_stra = ((1,device_num),) - sub_stra = ((device_num,1),(device_num,1)) - exp_stra = ((1,device_num),) - reduce_sum_stra = ((1,device_num),) - div_stra = ((1,device_num),(1,1)) - log_stra = ((1,device_num),) - mul_stra = ((1,device_num),(1,device_num)) - sum_cross_entropy_stra = ((1,device_num),) - mul2_stra = ((),(device_num,)) + matmul_stra = ((device_num, 1), (1, 1)) + reduce_max_stra = ((1, device_num),) + sub_stra = ((device_num, 1), (device_num, 1)) + exp_stra = ((1, device_num),) + reduce_sum_stra = ((1, device_num),) + div_stra = ((1, device_num), (1, 1)) + log_stra = ((1, device_num),) + mul_stra = ((1, device_num), (1, device_num)) + sum_cross_entropy_stra = ((1, device_num),) + mul2_stra = ((), (device_num,)) reduce_mean_stra = ((device_num,),) - onehot_stra = ((1,device_num),(),()) - loss_stra_list = [exp_stra, reduce_sum_stra, onehot_stra, div_stra, log_stra, sum_cross_entropy_stra, mul_stra, mul2_stra, reduce_mean_stra, reduce_max_stra, sub_stra] + onehot_stra = ((1, device_num), (), ()) + loss_stra_list = [exp_stra, reduce_sum_stra, onehot_stra, div_stra, log_stra, + sum_cross_entropy_stra, mul_stra, mul2_stra, reduce_mean_stra, reduce_max_stra, sub_stra] context.set_auto_parallel_context(parallel_mode="auto_parallel") - net = MatmulNet(matmul_stra = matmul_stra, loss_stra_list = loss_stra_list) + net = MatmulNet(matmul_stra=matmul_stra, loss_stra_list=loss_stra_list) optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9) model = Model(net, optimizer=optimizer) epoch_size = 6 @@ -254,6 +267,7 @@ class LossFactory(): loss_value = np.array(parallel_callback.loss_list) return loss_value + def test_all_trains(): loss_factory = LossFactory() context.reset_auto_parallel_context() diff --git a/tests/st/auto_parallel/test_expand_loss.py b/tests/st/auto_parallel/test_expand_loss.py index 786cbff980..e89ee5d3c8 100644 --- a/tests/st/auto_parallel/test_expand_loss.py +++ b/tests/st/auto_parallel/test_expand_loss.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ - import os import pytest + @pytest.mark.level0 @pytest.mark.platform_x86_ascend_training @pytest.mark.platform_arm_ascend_training @@ -23,4 +23,4 @@ import pytest def test_expand_loss(): sh_path = os.path.split(os.path.realpath(__file__))[0] ret = os.system(f"sh {sh_path}/run_auto_parallel_loss_expand.sh") - assert(ret==0) + assert(ret == 0) diff --git a/tests/st/auto_parallel/test_model_parallel_onehot.py b/tests/st/auto_parallel/test_model_parallel_onehot.py index 1df7ad1e99..55217421a4 100644 --- a/tests/st/auto_parallel/test_model_parallel_onehot.py +++ b/tests/st/auto_parallel/test_model_parallel_onehot.py @@ -16,6 +16,7 @@ import os import pytest + def test_expand_loss(): ret = os.system("sh run_onehot_model_parallel.sh") - assert(ret==0) + assert(ret == 0) diff --git a/tests/st/auto_parallel/test_resnet50_expand_loss_2p.py b/tests/st/auto_parallel/test_resnet50_expand_loss_2p.py index 62711ccf6a..b28ad510e3 100644 --- a/tests/st/auto_parallel/test_resnet50_expand_loss_2p.py +++ b/tests/st/auto_parallel/test_resnet50_expand_loss_2p.py @@ -11,10 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# ============================================================================ +import os import numpy as np import pytest -from numpy import allclose +import mindspore.context as context import mindspore.nn as nn import mindspore.common.dtype as mstype from mindspore import Tensor @@ -22,21 +24,21 @@ from mindspore.ops import operations as P from mindspore.nn.optim.momentum import Momentum from mindspore.common.initializer import One from mindspore.train.model import Model, ParallelMode -from mindspore import context -import os from mindspore.communication.management import init import mindspore.ops.functional as F from mindspore.nn.loss.loss import _Loss from mindspore.train.callback import Callback from mindspore.parallel import set_algo_parameters + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") context.set_context(enable_hccl=True) -context.set_context(enable_task_sink=True,device_id=int(os.getenv('DEVICE_ID'))) +context.set_context(enable_task_sink=True, device_id=int(os.getenv('DEVICE_ID'))) context.set_context(enable_ir_fusion=True) context.set_context(enable_loop_sink=False) init() context.set_auto_parallel_context(mirror_mean=True, parallel_mode=ParallelMode.AUTO_PARALLEL) + def weight_variable(shape, factor=0.1): return One() @@ -52,6 +54,7 @@ def _conv1x1(in_channels, out_channels, stride=1, padding=0, pad_mode='same'): return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=padding, pad_mode=pad_mode, weight_init=init_value) + def _conv7x7(in_channels, out_channels, stride=1, padding=0, pad_mode='same'): init_value = weight_variable((out_channels, in_channels, 7, 7)) return nn.Conv2d(in_channels, out_channels, @@ -63,6 +66,7 @@ def _fused_bn(channels, momentum=0.9): init_bias = weight_variable((channels,)) return nn.BatchNorm2d(channels, momentum=momentum) + class BasicBlock(nn.Cell): expansion = 1 @@ -172,7 +176,7 @@ class ResNet(nn.Cell): layer_nums, in_channels, out_channels, - strides=[1,2,2,2], + strides=[1, 2, 2, 2], num_classes=100): super(ResNet, self).__init__() @@ -292,17 +296,19 @@ class SoftmaxCrossEntropyExpand(_Loss): rank_id = int(os.environ["RANK_ID"]) device_num = int(os.environ["RANK_SIZE"]) + + class DataGenerator(): def get_parallel_blocks(self, input_, strategy): blocks = [input_] i = 0 for stra in strategy: temp = [] - while len(blocks)>0: + while len(blocks) > 0: block = blocks.pop(0) temp.extend(np.split(block, stra, axis=i)) blocks.extend(temp) - i+=1 + i += 1 return blocks def generate_data(self, shape): @@ -321,7 +327,7 @@ class DataGenerator(): stra = [1]*len(shape) stra[0] = device_num datas = self.get_parallel_blocks(data, stra) - return Tensor(data),Tensor(datas[rank_id]) + return Tensor(data), Tensor(datas[rank_id]) class Dataset(): @@ -359,6 +365,7 @@ class ModelCallback(Callback): def __init__(self): super(ModelCallback, self).__init__() self.loss_list = [] + def epoch_end(self, run_context, *args): cb_params = run_context.original_args() result = cb_params.net_outputs @@ -382,7 +389,7 @@ def test_train_feed(num_classes=8192): model.train(5, dataset, dataset_sink_mode=False, callbacks=parallel_callback) loss_value = np.array(parallel_callback.loss_list) expect_out = [9.010913, 8.855984, 8.56246, 8.146317, 7.624489] - assert allclose(loss_value, expect_out, 0.0001, 0.0001) + assert np.allclose(loss_value, expect_out, 0.0001, 0.0001) @pytest.mark.level0 @@ -402,4 +409,4 @@ def test_train_feed2(num_classes=1001): model.train(5, dataset, dataset_sink_mode=False, callbacks=parallel_callback) loss_value = np.array(parallel_callback.loss_list) expect_out = [6.908755, 6.8358116, 6.6986914, 6.506859, 6.2708097] - assert allclose(loss_value, expect_out, 0.0001, 0.0001) + assert np.allclose(loss_value, expect_out, 0.0001, 0.0001) diff --git a/tests/st/control/test_while.py b/tests/st/control/test_while.py index 56b38f7f9a..6c659b6581 100644 --- a/tests/st/control/test_while.py +++ b/tests/st/control/test_while.py @@ -13,12 +13,12 @@ # limitations under the License. # ============================================================================ import numpy as np -from mindspore.common.tensor import Tensor -from mindspore.common import dtype as mstype import mindspore.context as context -from mindspore.ops import operations as P import mindspore.nn as nn -from mindspore.common import ms_function +from mindspore import Tensor, ms_function +from mindspore.common import dtype as mstype +from mindspore.ops import operations as P + @ms_function def t1_while(x, y, z): @@ -28,8 +28,9 @@ def t1_while(x, y, z): x = x + 3 return x + def test_net(): - context.set_context(mode=context.GRAPH_MODE,device_target="Ascend") + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") context.set_context(enable_task_sink=True) c1 = Tensor([2], mstype.int32) c2 = Tensor([14], mstype.int32) @@ -38,5 +39,6 @@ def test_net(): ret = t1_while(c1, c2, c3) assert (ret == expect) + if __name__ == "__main__": - test_net() \ No newline at end of file + test_net() diff --git a/tests/st/fusion/test_add_relu_buffer_fusion.py b/tests/st/fusion/test_add_relu_buffer_fusion.py index fbb0f73991..d63c8b355d 100644 --- a/tests/st/fusion/test_add_relu_buffer_fusion.py +++ b/tests/st/fusion/test_add_relu_buffer_fusion.py @@ -12,17 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -from mindspore import Tensor -from mindspore.ops import operations as P -import mindspore.nn as nn -from mindspore.common.api import ms_function -import mindspore.common.dtype as mstype import numpy as np import mindspore.context as context -from mindspore.common.initializer import initializer -from mindspore.common.parameter import Parameter +import mindspore.nn as nn +import mindspore.common.dtype as mstype +from mindspore import Tensor, ms_function +from mindspore.ops import operations as P + context.set_context(mode=context.GRAPH_MODE, device_id=5, device_target="Ascend") -#context.set_context(enable_task_sink=True) + + class Net(nn.Cell): def __init__(self): super(Net, self).__init__() @@ -35,17 +34,14 @@ class Net(nn.Cell): def construct(self, x, y): x = self.cast(x, mstype.float16) y = self.cast(y, mstype.float16) - #x = self.softmax(x) x = self.add(x, y) - #x = self.relu(x) x = self.relu(x) - #x = self.softmax(x) x = self.reduce_mean(x) return x + def test_net(): x = np.random.randn(32, 10).astype(np.float32) relu = Net() output = relu(Tensor(x), Tensor(x)) - print(x) print(output.asnumpy()) diff --git a/tests/st/fusion/test_conv_bn1_fusion.py b/tests/st/fusion/test_conv_bn1_fusion.py index 6149b9fd71..c3547ae1cf 100644 --- a/tests/st/fusion/test_conv_bn1_fusion.py +++ b/tests/st/fusion/test_conv_bn1_fusion.py @@ -13,15 +13,13 @@ # limitations under the License. # ============================================================================ import numpy as np +import mindspore.context as context import mindspore.nn as nn +from mindspore import Tensor, Parameter, Model, ms_function from mindspore.ops import operations as P from mindspore.common.initializer import initializer -from mindspore import Tensor, Parameter, Model from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits from mindspore.nn.optim import Momentum -from mindspore.common.api import ms_function -import mindspore.nn as wrap -import mindspore.context as context context.set_context(device_target="Ascend", enable_task_sink=True) @@ -35,6 +33,7 @@ class MsWrapper(nn.Cell): def __init__(self, network): super(MsWrapper, self).__init__(auto_prefix=False) self._network = network + @ms_function def construct(self, *args): return self._network(*args) @@ -42,16 +41,16 @@ class MsWrapper(nn.Cell): def me_train_tensor(net, input_np, label_np, epoch_size=2): loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) - opt = nn.Momentum(Tensor(np.array([0.1])), Tensor(np.array([0.9])), filter(lambda x: x.requires_grad, net.get_parameters())) + opt = nn.Momentum(Tensor(np.array([0.1])), Tensor(np.array([0.9])), + filter(lambda x: x.requires_grad, net.get_parameters())) context.set_context(mode=context.GRAPH_MODE) Model(net, loss, opt) - _network = wrap.WithLossCell(net, loss) - _train_net = MsWrapper(wrap.TrainOneStepCell(_network, opt)) + _network = nn.WithLossCell(net, loss) + _train_net = MsWrapper(nn.TrainOneStepCell(_network, opt)) _train_net.set_train() for epoch in range(0, epoch_size): - print(f"epoch %d"%(epoch)) + print(f"epoch %d" % (epoch)) output = _train_net(Tensor(input_np), Tensor(label_np)) - print("********output***********") print(output.asnumpy()) @@ -60,9 +59,9 @@ def test_conv_bn_add_relu_fusion(): def __init__(self): super(Net, self).__init__() self.conv = nn.Conv2d(input_channel, output_channel, - kernel_size=1, stride=1, padding=0, has_bias=False, pad_mode="same") + kernel_size=1, stride=1, padding=0, has_bias=False, pad_mode="same") self.conv1 = nn.Conv2d(input_channel, output_channel, - kernel_size=1, stride=1, padding=0, has_bias=False, pad_mode="same") + kernel_size=1, stride=1, padding=0, has_bias=False, pad_mode="same") self.bn = nn.BatchNorm2d(output_channel, momentum=0.1, eps=0.0001) self.add = P.TensorAdd() self.relu = P.ReLU() @@ -91,7 +90,7 @@ def test_conv_bn_relu_fusion(): def __init__(self): super(Net, self).__init__() self.conv = nn.Conv2d(input_channel, output_channel, - kernel_size=1, stride=1, padding=0, has_bias=False, pad_mode="same") + kernel_size=1, stride=1, padding=0, has_bias=False, pad_mode="same") self.bn = nn.BatchNorm2d(output_channel, momentum=0.1, eps=0.0001) self.relu = P.ReLU() self.mean = P.ReduceMean(keep_dims=True) @@ -118,7 +117,7 @@ def test_conv_bn_fusion(): def __init__(self): super(Net, self).__init__() self.conv = nn.Conv2d(input_channel, output_channel, - kernel_size=1, stride=1, padding=0, has_bias=False, pad_mode="same") + kernel_size=1, stride=1, padding=0, has_bias=False, pad_mode="same") self.bn = nn.BatchNorm2d(output_channel, momentum=0.1, eps=0.0001) self.mean = P.ReduceMean(keep_dims=True) self.reshape = P.Reshape() diff --git a/tests/st/fusion/test_tbe_eltwise_fusion_1.py b/tests/st/fusion/test_tbe_eltwise_fusion_1.py index 0b9ae1fe63..5d4fd09d30 100644 --- a/tests/st/fusion/test_tbe_eltwise_fusion_1.py +++ b/tests/st/fusion/test_tbe_eltwise_fusion_1.py @@ -13,16 +13,15 @@ # limitations under the License. # ============================================================================ import pytest -from mindspore import Tensor -from mindspore.ops import operations as P -import mindspore.nn as nn -from mindspore.common.api import ms_function -import mindspore.common.dtype as mstype import numpy as np import mindspore.context as context -from mindspore.common.initializer import initializer -from mindspore.common.parameter import Parameter +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + + class Net(nn.Cell): def __init__(self): super(Net, self).__init__() @@ -35,6 +34,7 @@ class Net(nn.Cell): x = self.relu(x) return x + @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training @@ -43,5 +43,4 @@ def test_net(): x = np.random.randn(32, 10).astype(np.float32) relu_relu = Net() output = relu_relu(Tensor(x)) - print(x) print(output.asnumpy()) diff --git a/tests/st/fusion/test_tbe_eltwise_fusion_2.py b/tests/st/fusion/test_tbe_eltwise_fusion_2.py index 8f6084ae5b..3ae754d385 100644 --- a/tests/st/fusion/test_tbe_eltwise_fusion_2.py +++ b/tests/st/fusion/test_tbe_eltwise_fusion_2.py @@ -13,16 +13,15 @@ # limitations under the License. # ============================================================================ import pytest -from mindspore import Tensor -from mindspore.ops import operations as P -import mindspore.nn as nn -from mindspore.common.api import ms_function -import mindspore.common.dtype as mstype import numpy as np import mindspore.context as context -from mindspore.common.initializer import initializer -from mindspore.common.parameter import Parameter +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + + class Net(nn.Cell): def __init__(self): super(Net, self).__init__() @@ -41,6 +40,7 @@ class Net(nn.Cell): x = self.relu(x) return x + @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training @@ -50,5 +50,4 @@ def test_net(): y = np.random.randn(10).astype(np.float32) net = Net() output = net(Tensor(x), Tensor(y)) - print(x) - print(output.asnumpy()) \ No newline at end of file + print(output.asnumpy()) diff --git a/tests/st/fusion/test_tbe_multi_inout_eltwise_fusion.py b/tests/st/fusion/test_tbe_multi_inout_eltwise_fusion.py index 9a900a4739..76cf639da0 100644 --- a/tests/st/fusion/test_tbe_multi_inout_eltwise_fusion.py +++ b/tests/st/fusion/test_tbe_multi_inout_eltwise_fusion.py @@ -12,15 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -from mindspore import Tensor -from mindspore.ops import operations as P -import mindspore.nn as nn -import mindspore.common.dtype as mstype import numpy as np import mindspore.context as context -from mindspore.common.parameter import Parameter +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P + context.set_context(mode=context.GRAPH_MODE, device_id=4, device_target="Ascend") -#context.set_context(enable_task_sink=True) + class Net(nn.Cell): def __init__(self): @@ -39,6 +38,7 @@ class Net(nn.Cell): z = self.add(z1, z2) return z + def test_net(): x = np.random.randn(32, 10).astype(np.float32) y = np.random.randn(32, 10).astype(np.float32) @@ -46,6 +46,4 @@ def test_net(): h = np.random.randn(10).astype(np.float32) relu_relu = Net() output = relu_relu(Tensor(x), Tensor(y), Tensor(k), Tensor(h)) - print(x) print(output.asnumpy()) - diff --git a/tests/st/fusion/test_tbe_reduce_eltwise_fusion.py b/tests/st/fusion/test_tbe_reduce_eltwise_fusion.py index 63b1cc542d..93c7174b52 100644 --- a/tests/st/fusion/test_tbe_reduce_eltwise_fusion.py +++ b/tests/st/fusion/test_tbe_reduce_eltwise_fusion.py @@ -13,17 +13,16 @@ # limitations under the License. # ============================================================================ import pytest +import numpy as np +import mindspore.context as context +import mindspore.nn as nn from mindspore import Tensor from mindspore.ops import operations as P from mindspore.ops.operations import _grad_ops as G -import mindspore.nn as nn -from mindspore.common.api import ms_function -import mindspore.common.dtype as mstype -import numpy as np -import mindspore.context as context -from mindspore.common.initializer import initializer -from mindspore.common.parameter import Parameter + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + + class Net(nn.Cell): def __init__(self): super(Net, self).__init__() @@ -41,6 +40,7 @@ class Net(nn.Cell): x = self.relu(x) return x + @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training @@ -49,5 +49,4 @@ def test_net(): x = np.random.randn(32, 10).astype(np.float32) net = Net() output = net(Tensor(x)) - print(x) - print(output.asnumpy()) \ No newline at end of file + print(output.asnumpy()) diff --git a/tests/st/mem_reuse/check_file.py b/tests/st/mem_reuse/check_file.py index 2f6fe82d2d..30b3b690a4 100644 --- a/tests/st/mem_reuse/check_file.py +++ b/tests/st/mem_reuse/check_file.py @@ -14,6 +14,7 @@ # ============================================================================ import os import filecmp + curr_path = os.path.abspath(os.curdir) file_memreuse = curr_path + "/mem_reuse_check/memreuse.ir" file_normal = curr_path + "/mem_reuse_check/normal_mem.ir" @@ -23,5 +24,3 @@ checker = os.path.exists(file_normal) assert (checker, True) checker = filecmp.cmp(file_memreuse, file_normal) assert (checker, True) - - diff --git a/tests/st/mem_reuse/resnet.py b/tests/st/mem_reuse/resnet.py index fb4075f0a4..1c1b880b57 100644 --- a/tests/st/mem_reuse/resnet.py +++ b/tests/st/mem_reuse/resnet.py @@ -19,6 +19,7 @@ from mindspore.ops import operations as P from mindspore.common.initializer import initializer from mindspore.common import dtype as mstype + def weight_variable(shape): return initializer('XavierUniform', shape=shape, dtype=mstype.float32) @@ -297,4 +298,3 @@ class ResNet(nn.Cell): def resnet50(batch_size, num_classes): return ResNet(ResidualBlock, [3, 4, 6, 3], num_classes, batch_size) - diff --git a/tests/st/mem_reuse/resnet_cifar_memreuse.py b/tests/st/mem_reuse/resnet_cifar_memreuse.py index 4edcdd8fb8..d6310612b6 100644 --- a/tests/st/mem_reuse/resnet_cifar_memreuse.py +++ b/tests/st/mem_reuse/resnet_cifar_memreuse.py @@ -12,16 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ +import argparse +import os +import numpy as np +import mindspore.context as context import mindspore.nn as nn +import mindspore.common.dtype as mstype from mindspore import Tensor from mindspore.ops import operations as P +from mindspore.ops import functional as F from mindspore.nn.optim.momentum import Momentum from mindspore.train.model import Model, ParallelMode -from mindspore import context -import mindspore.common.dtype as mstype -import os -import numpy as np -import mindspore.ops.functional as F from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor from mindspore.train.serialization import load_checkpoint, load_param_into_net import mindspore.dataset as de @@ -30,11 +31,11 @@ import mindspore.dataset.transforms.vision.c_transforms as vision from mindspore.communication.management import init from resnet import resnet50 import random + random.seed(1) np.random.seed(1) de.config.set_seed(1) -import argparse parser = argparse.ArgumentParser(description='Image classification') parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute') parser.add_argument('--device_num', type=int, default=1, help='Device num.') @@ -47,9 +48,9 @@ parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoin parser.add_argument('--dataset_path', type=str, default="/var/log/npu/datasets/cifar", help='Dataset path') args_opt = parser.parse_args() -device_id=int(os.getenv('DEVICE_ID')) +device_id = int(os.getenv('DEVICE_ID')) -data_home=args_opt.dataset_path +data_home = args_opt.dataset_path context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") context.set_context(enable_task_sink=True, device_id=device_id) @@ -64,8 +65,8 @@ def create_dataset(repeat_num=1, training=True): ds = de.Cifar10Dataset(data_dir) if args_opt.run_distribute: - rank_id=int(os.getenv('RANK_ID')) - rank_size=int(os.getenv('RANK_SIZE')) + rank_id = int(os.getenv('RANK_ID')) + rank_size = int(os.getenv('RANK_SIZE')) ds = de.Cifar10Dataset(data_dir, num_shards=rank_size, shard_id=rank_id) resize_height = 224 @@ -74,9 +75,9 @@ def create_dataset(repeat_num=1, training=True): shift = 0.0 # define map operations - random_crop_op = vision.RandomCrop((32, 32), (4, 4, 4, 4)) # padding_mode default CONSTANT + random_crop_op = vision.RandomCrop((32, 32), (4, 4, 4, 4)) # padding_mode default CONSTANT random_horizontal_op = vision.RandomHorizontalFlip() - resize_op = vision.Resize((resize_height, resize_width)) # interpolation default BILINEAR + resize_op = vision.Resize((resize_height, resize_width)) # interpolation default BILINEAR rescale_op = vision.Rescale(rescale, shift) normalize_op = vision.Normalize((0.4465, 0.4822, 0.4914), (0.2010, 0.1994, 0.2023)) changeswap_op = vision.HWC2CHW() diff --git a/tests/st/mem_reuse/resnet_cifar_normal.py b/tests/st/mem_reuse/resnet_cifar_normal.py index 39f6e7fe59..2b6741e57a 100644 --- a/tests/st/mem_reuse/resnet_cifar_normal.py +++ b/tests/st/mem_reuse/resnet_cifar_normal.py @@ -12,16 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ +import argparse +import os +import numpy as np +import mindspore.context as context import mindspore.nn as nn +import mindspore.common.dtype as mstype from mindspore import Tensor from mindspore.ops import operations as P +from mindspore.ops import functional as F from mindspore.nn.optim.momentum import Momentum from mindspore.train.model import Model, ParallelMode -from mindspore import context -import mindspore.common.dtype as mstype -import os -import numpy as np -import mindspore.ops.functional as F from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor from mindspore.train.serialization import load_checkpoint, load_param_into_net import mindspore.dataset as de @@ -35,7 +36,6 @@ random.seed(1) np.random.seed(1) de.config.set_seed(1) -import argparse parser = argparse.ArgumentParser(description='Image classification') parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute') diff --git a/tests/st/nccl/test_nccl_all.py b/tests/st/nccl/test_nccl_all.py index 99494bb741..faa6394f9a 100644 --- a/tests/st/nccl/test_nccl_all.py +++ b/tests/st/nccl/test_nccl_all.py @@ -15,6 +15,7 @@ import os import pytest + @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @pytest.mark.env_single @@ -22,6 +23,7 @@ def test_nccl_lenet(): return_code = os.system("mpirun -n 8 pytest -s test_nccl_lenet.py") assert(return_code == 0) + @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @pytest.mark.env_single @@ -29,6 +31,7 @@ def test_nccl_all_reduce_op(): return_code = os.system("mpirun -n 8 pytest -s test_nccl_all_reduce_op.py") assert(return_code == 0) + @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @pytest.mark.env_single @@ -36,6 +39,7 @@ def test_nccl_all_gather_op(): return_code = os.system("mpirun -n 8 pytest -s test_nccl_all_gather_op.py") assert(return_code == 0) + @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @pytest.mark.env_single diff --git a/tests/st/nccl/test_nccl_all_gather_op.py b/tests/st/nccl/test_nccl_all_gather_op.py index f2a2c7133c..0a37a692da 100644 --- a/tests/st/nccl/test_nccl_all_gather_op.py +++ b/tests/st/nccl/test_nccl_all_gather_op.py @@ -12,23 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -from mindspore import Tensor -from mindspore.ops import operations as P -import mindspore.nn as nn import numpy as np import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P from mindspore.common.initializer import initializer from mindspore.common.parameter import Parameter from mindspore.communication.management import init, NCCL_WORLD_COMM_GROUP, get_rank, get_group_size + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') init('nccl') rank = get_rank() size = get_group_size() -x = np.ones([1,1,3,3]).astype(np.float32) * 0.01 * (rank + 1) +x = np.ones([1, 1, 3, 3]).astype(np.float32) * 0.01 * (rank + 1) + class Net(nn.Cell): - def __init__( self): + def __init__(self): super(Net, self).__init__() self.all_gather = P.AllGather(group=NCCL_WORLD_COMM_GROUP) self.x = Parameter(initializer(Tensor(x), x.shape), name='x') @@ -36,6 +38,7 @@ class Net(nn.Cell): def construct(self): return self.all_gather(self.x) + def test_AllGather(): all_gather = Net() output = all_gather() diff --git a/tests/st/nccl/test_nccl_all_reduce_op.py b/tests/st/nccl/test_nccl_all_reduce_op.py index 3ba8b219e4..a1a732fd08 100644 --- a/tests/st/nccl/test_nccl_all_reduce_op.py +++ b/tests/st/nccl/test_nccl_all_reduce_op.py @@ -12,23 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -from mindspore import Tensor -from mindspore.ops import operations as P -import mindspore.nn as nn import numpy as np import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P from mindspore.common.initializer import initializer from mindspore.common.parameter import Parameter from mindspore.communication.management import init, NCCL_WORLD_COMM_GROUP, get_rank, get_group_size + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') init('nccl') rank = get_rank() size = get_group_size() -x = np.ones([3,1,3,3]).astype(np.float32) * 0.01 * (rank + 1) +x = np.ones([3, 1, 3, 3]).astype(np.float32) * 0.01 * (rank + 1) + class Net(nn.Cell): - def __init__( self): + def __init__(self): super(Net, self).__init__() self.x1 = Parameter(initializer(Tensor(x), x.shape), name='x1') self.x2 = Parameter(initializer(Tensor(x), x.shape), name='x2') @@ -47,6 +49,7 @@ class Net(nn.Cell): self.all_reduce2(self.x2), self.all_reduce3(self.x3)) + def test_AllReduce(): all_reduce = Net() output = all_reduce() @@ -58,16 +61,16 @@ def test_AllReduce(): diff0 = output[0].asnumpy() - expect0 error0 = np.ones(shape=expect0.shape) * 1.0e-5 assert np.all(diff0 < error0) - assert (output[0].shape() == expect0.shape) + assert output[0].shape() == expect0.shape expect1 = expect0 diff1 = output[1].asnumpy() - expect1 error1 = np.ones(shape=expect1.shape) * 1.0e-5 assert np.all(diff1 < error1) - assert (output[1].shape() == expect1.shape) + assert output[1].shape() == expect1.shape expect2 = expect1 diff2 = output[2].asnumpy() - expect2 error2 = np.ones(shape=expect2.shape) * 1.0e-5 assert np.all(diff2 < error2) - assert (output[2].shape() == expect2.shape) + assert output[2].shape() == expect2.shape diff --git a/tests/st/nccl/test_nccl_lenet.py b/tests/st/nccl/test_nccl_lenet.py index 2aebc5da50..3880f1d473 100644 --- a/tests/st/nccl/test_nccl_lenet.py +++ b/tests/st/nccl/test_nccl_lenet.py @@ -12,16 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -import numpy as np -from mindspore.nn import Dense -import mindspore.nn as nn import datetime +import numpy as np import mindspore.context as context -from mindspore.communication.management import init, NCCL_WORLD_COMM_GROUP, get_rank, get_group_size +import mindspore.nn as nn +from mindspore import Tensor from mindspore.nn.optim import Momentum from mindspore.nn import TrainOneStepCell, WithLossCell from mindspore.ops import operations as P -from mindspore.common.tensor import Tensor +from mindspore.communication.management import init, get_rank, get_group_size context.set_context(mode=context.GRAPH_MODE, device_target="GPU") init('nccl') @@ -31,6 +30,7 @@ total = 5000 batch_size = 32 mini_batch = total // batch_size + class LeNet(nn.Cell): def __init__(self): super(LeNet, self).__init__() @@ -43,15 +43,15 @@ class LeNet(nn.Cell): self.conv2 = nn.Conv2d(6, 16, (5, 5), weight_init=weight2, pad_mode='valid', stride=1, padding=0) self.pool = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode="valid") self.reshape = P.Reshape() - + weight1 = Tensor(np.ones([120, 400]).astype(np.float32) * 0.01) - self.fc1 = Dense(400, 120, weight_init=weight1) - + self.fc1 = nn.Dense(400, 120, weight_init=weight1) + weight2 = Tensor(np.ones([84, 120]).astype(np.float32) * 0.01) - self.fc2 = Dense(120, 84, weight_init=weight2) - + self.fc2 = nn.Dense(120, 84, weight_init=weight2) + weight3 = Tensor(np.ones([10, 84]).astype(np.float32) * 0.01) - self.fc3 = Dense(84, 10, weight_init=weight3) + self.fc3 = nn.Dense(84, 10, weight_init=weight3) def construct(self, input_x): output = self.conv1(input_x) @@ -66,6 +66,7 @@ class LeNet(nn.Cell): output = self.fc3(output) return output + def test_lenet_nccl(): net = LeNet() net.set_train() diff --git a/tests/st/nccl/test_nccl_reduce_scatter_op.py b/tests/st/nccl/test_nccl_reduce_scatter_op.py index 32c1f31788..f3322d07a3 100644 --- a/tests/st/nccl/test_nccl_reduce_scatter_op.py +++ b/tests/st/nccl/test_nccl_reduce_scatter_op.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ -from mindspore import Tensor -from mindspore.ops import operations as P -import mindspore.nn as nn import numpy as np import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P from mindspore.common.initializer import initializer from mindspore.common.parameter import Parameter from mindspore.communication.management import init, NCCL_WORLD_COMM_GROUP, get_rank, get_group_size @@ -27,8 +27,9 @@ rank = get_rank() size = get_group_size() x = np.ones([size, 1, 3, 3]).astype(np.float32) * 0.01 * (rank + 1) + class Net(nn.Cell): - def __init__( self): + def __init__(self): super(Net, self).__init__() self.x = Parameter(initializer(Tensor(x), x.shape), name='x') @@ -46,6 +47,7 @@ class Net(nn.Cell): self.reduce_scatter2(self.x), self.reduce_scatter3(self.x)) + def test_ReduceScatter(): reduce_scatter = Net() output = reduce_scatter() @@ -53,7 +55,7 @@ def test_ReduceScatter(): sum = np.ones([size, 1, 3, 3]).astype(np.float32) * 0 for i in range(size): sum += np.ones([size, 1, 3, 3]).astype(np.float32) * 0.01 * (i + 1) - expect0 = sum[rank : rank + 1] + expect0 = sum[rank: rank + 1] diff0 = output[0].asnumpy() - expect0 error0 = np.ones(shape=expect0.shape) * 1.0e-5 assert np.all(diff0 < error0) diff --git a/tests/st/networks/models/alexnet.py b/tests/st/networks/models/alexnet.py index 4c8981f04a..f74d09353c 100644 --- a/tests/st/networks/models/alexnet.py +++ b/tests/st/networks/models/alexnet.py @@ -16,6 +16,7 @@ import mindspore.nn as nn from mindspore.ops import operations as P from mindspore.nn import Dense + class AlexNet(nn.Cell): def __init__(self, num_classes=10): super(AlexNet, self).__init__() diff --git a/tests/st/networks/models/bert/bert_tdt_no_lossscale.py b/tests/st/networks/models/bert/bert_tdt_no_lossscale.py index 5b6268505b..7d30592044 100644 --- a/tests/st/networks/models/bert/bert_tdt_no_lossscale.py +++ b/tests/st/networks/models/bert/bert_tdt_no_lossscale.py @@ -18,21 +18,22 @@ import os import pytest import numpy as np -from numpy import allclose +import mindspore.context as context import mindspore.common.dtype as mstype import mindspore.dataset.engine.datasets as de import mindspore.dataset.transforms.c_transforms as C -from mindspore import context -from mindspore.common.tensor import Tensor +from mindspore import Tensor from mindspore.train.model import Model from mindspore.train.callback import Callback from mindspore.model_zoo.Bert_NEZHA import BertConfig, BertNetworkWithLoss, BertTrainOneStepCell from mindspore.nn.optim import Momentum from mindspore import log as logger + _current_dir = os.path.dirname(os.path.realpath(__file__)) DATA_DIR = ["/home/workspace/mindspore_dataset/bert/example/examples.tfrecord"] SCHEMA_DIR = "/home/workspace/mindspore_dataset/bert/example/datasetSchema.json" + def get_config(version='base', batch_size=1): """get config""" if version == 'base': @@ -99,13 +100,14 @@ def get_config(version='base', batch_size=1): bert_config = BertConfig(batch_size=batch_size) return bert_config + def me_de_train_dataset(): """test me de train dataset""" # apply repeat operations repeat_count = 1 ds = de.TFRecordDataset(DATA_DIR, SCHEMA_DIR, columns_list=["input_ids", "input_mask", "segment_ids", - "next_sentence_labels", "masked_lm_positions", - "masked_lm_ids", "masked_lm_weights"], shuffle=False) + "next_sentence_labels", "masked_lm_positions", + "masked_lm_ids", "masked_lm_weights"], shuffle=False) type_cast_op = C.TypeCast(mstype.int32) ds = ds.map(input_columns="masked_lm_ids", operations=type_cast_op) ds = ds.map(input_columns="masked_lm_positions", operations=type_cast_op) @@ -137,6 +139,7 @@ class ModelCallback(Callback): self.loss_list.append(cb_params.net_outputs.asnumpy()[0]) logger.info("epoch: {}, outputs are {}".format(cb_params.cur_epoch_num, str(cb_params.net_outputs))) + @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training @@ -180,7 +183,8 @@ def test_bert_tdt(): expect_out = [12.19179, 11.965041, 11.969687, 11.97815, 11.969171, 12.603289, 12.165594, 12.824818, 12.38842, 12.604046] logger.info("expected loss value output: {}".format(expect_out)) - assert allclose(loss_value, expect_out, 0.00001, 0.00001) + assert np.allclose(loss_value, expect_out, 0.00001, 0.00001) + if __name__ == '__main__': test_bert_tdt() diff --git a/tests/st/networks/models/lenet.py b/tests/st/networks/models/lenet.py index 9df91822f7..8f6b969cd7 100644 --- a/tests/st/networks/models/lenet.py +++ b/tests/st/networks/models/lenet.py @@ -14,9 +14,10 @@ # ============================================================================ import numpy as np import mindspore.nn as nn +from mindspore import Tensor from mindspore.ops import operations as P from mindspore.nn import Dense -from mindspore import Tensor + class LeNet(nn.Cell): def __init__(self): diff --git a/tests/st/networks/models/resnetv1_5.py b/tests/st/networks/models/resnetv1_5.py index 855aec7014..604389547e 100644 --- a/tests/st/networks/models/resnetv1_5.py +++ b/tests/st/networks/models/resnetv1_5.py @@ -13,9 +13,10 @@ # limitations under the License. # ============================================================================ import numpy as np -from mindspore.common.tensor import Tensor import mindspore.nn as nn -import mindspore.ops.operations as P +from mindspore import Tensor +from mindspore.ops import operations as P + def weight_variable(shape): ones = np.ones(shape).astype(np.float32) @@ -37,7 +38,7 @@ def conv3x3(in_channels, out_channels, stride=1, padding=0): weight_shape = (out_channels, in_channels, 3, 3) weight = weight_variable(weight_shape) return nn.Conv2d(in_channels, out_channels, - kernel_size=3, stride=stride, padding=padding, weight_init=weight, has_bias=False, pad_mode="same") + kernel_size=3, stride=stride, padding=padding, weight_init=weight, has_bias=False, pad_mode="same") def conv1x1(in_channels, out_channels, stride=1, padding=0): @@ -45,7 +46,7 @@ def conv1x1(in_channels, out_channels, stride=1, padding=0): weight_shape = (out_channels, in_channels, 1, 1) weight = weight_variable(weight_shape) return nn.Conv2d(in_channels, out_channels, - kernel_size=1, stride=stride, padding=padding, weight_init=weight, has_bias=False, pad_mode="same") + kernel_size=1, stride=stride, padding=padding, weight_init=weight, has_bias=False, pad_mode="same") def conv7x7(in_channels, out_channels, stride=1, padding=0): @@ -53,7 +54,7 @@ def conv7x7(in_channels, out_channels, stride=1, padding=0): weight_shape = (out_channels, in_channels, 7, 7) weight = weight_variable(weight_shape) return nn.Conv2d(in_channels, out_channels, - kernel_size=7, stride=stride, padding=padding, weight_init=weight, has_bias=False, pad_mode="same") + kernel_size=7, stride=stride, padding=padding, weight_init=weight, has_bias=False, pad_mode="same") def bn_with_initialize(out_channels): @@ -63,7 +64,7 @@ def bn_with_initialize(out_channels): beta = weight_variable_0(shape) gamma = weight_variable_1(shape) bn = nn.BatchNorm2d(out_channels, momentum=0.1, eps=0.0001, gamma_init=gamma, - beta_init=beta, moving_mean_init=mean, moving_var_init=var) + beta_init=beta, moving_mean_init=mean, moving_var_init=var) return bn @@ -74,7 +75,7 @@ def bn_with_initialize_last(out_channels): beta = weight_variable_0(shape) gamma = weight_variable_0(shape) bn = nn.BatchNorm2d(out_channels, momentum=0.1, eps=0.0001, gamma_init=gamma, - beta_init=beta, moving_mean_init=mean, moving_var_init=var) + beta_init=beta, moving_mean_init=mean, moving_var_init=var) return bn @@ -294,6 +295,6 @@ class ResNet(nn.Cell): x = self.fc(x) return x + def resnet50(batch_size, num_classes): return ResNet(ResidualBlock, [3, 4, 6, 3], num_classes, batch_size) - diff --git a/tests/st/networks/test_cpu_lenet.py b/tests/st/networks/test_cpu_lenet.py index 9fd50f5d9b..7101e29aa9 100644 --- a/tests/st/networks/test_cpu_lenet.py +++ b/tests/st/networks/test_cpu_lenet.py @@ -13,13 +13,15 @@ # limitations under the License. # ============================================================================ import pytest -from mindspore.nn import TrainOneStepCell, WithLossCell -import mindspore.context as context -from mindspore.nn.optim import Momentum import numpy as np +import mindspore.context as context import mindspore.nn as nn -from mindspore.ops import operations as P from mindspore import Tensor +from mindspore.nn import TrainOneStepCell, WithLossCell +from mindspore.nn.optim import Momentum +from mindspore.ops import operations as P + +context.set_context(mode=context.GRAPH_MODE, device_target="CPU") class LeNet(nn.Cell): @@ -52,9 +54,6 @@ class LeNet(nn.Cell): return output -context.set_context(mode=context.GRAPH_MODE, device_target="CPU") - - def train(net, data, label): learning_rate = 0.01 momentum = 0.9 diff --git a/tests/st/networks/test_gpu_alexnet.py b/tests/st/networks/test_gpu_alexnet.py index 9f92fc630e..699617b384 100644 --- a/tests/st/networks/test_gpu_alexnet.py +++ b/tests/st/networks/test_gpu_alexnet.py @@ -19,15 +19,17 @@ from __future__ import print_function import pytest import numpy as np +import mindspore.context as context import mindspore.nn as nn +from mindspore import Tensor from mindspore.nn.optim import Momentum from mindspore.ops import operations as P from mindspore.nn import TrainOneStepCell, WithLossCell -from mindspore import Tensor from mindspore.common.initializer import initializer -import mindspore.context as context + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + class AlexNet(nn.Cell): def __init__(self, num_classes=10): super(AlexNet, self).__init__() @@ -66,6 +68,7 @@ class AlexNet(nn.Cell): x = self.fc3(x) return x + @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard @@ -73,14 +76,14 @@ def test_trainTensor(num_classes=10, epoch=15, batch_size=32): net = AlexNet(num_classes) lr = 0.1 momentum = 0.9 - optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, momentum, weight_decay = 0.0001) + optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, momentum, weight_decay=0.0001) criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) net_with_criterion = WithLossCell(net, criterion) train_network = TrainOneStepCell(net_with_criterion, optimizer) train_network.set_train() - losses=[] + losses = [] for i in range(0, epoch): - data = Tensor(np.ones([batch_size, 3 ,227, 227]).astype(np.float32) * 0.01) + data = Tensor(np.ones([batch_size, 3, 227, 227]).astype(np.float32) * 0.01) label = Tensor(np.ones([batch_size]).astype(np.int32)) loss = train_network(data, label) losses.append(loss) diff --git a/tests/st/networks/test_gpu_lenet.py b/tests/st/networks/test_gpu_lenet.py index 4dac2247d0..b6b94cd23d 100644 --- a/tests/st/networks/test_gpu_lenet.py +++ b/tests/st/networks/test_gpu_lenet.py @@ -16,16 +16,19 @@ import pytest import numpy as np import mindspore.nn as nn +import mindspore.context as context +from mindspore import Tensor from mindspore.nn.optim import Momentum from mindspore.ops import operations as P from mindspore.nn import TrainOneStepCell, WithLossCell from mindspore.nn import Dense -from mindspore import Tensor from mindspore.common.initializer import initializer from mindspore.common import dtype as mstype -import mindspore.context as context + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + + class LeNet(nn.Cell): def __init__(self): super(LeNet, self).__init__() @@ -65,6 +68,7 @@ def multisteplr(total_steps, gap, base_lr=0.9, gamma=0.1, dtype=mstype.float32): lr.append(lr_) return Tensor(np.array(lr), dtype) + @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard @@ -81,7 +85,7 @@ def test_train_lenet(): train_network.set_train() losses = [] for i in range(epoch): - data = Tensor(np.ones([net.batch_size, 3 ,32, 32]).astype(np.float32) * 0.01) + data = Tensor(np.ones([net.batch_size, 3, 32, 32]).astype(np.float32) * 0.01) label = Tensor(np.ones([net.batch_size]).astype(np.int32)) loss = train_network(data, label) losses.append(loss) diff --git a/tests/st/networks/test_gpu_lstm.py b/tests/st/networks/test_gpu_lstm.py index e5208ff669..acf5ca9396 100644 --- a/tests/st/networks/test_gpu_lstm.py +++ b/tests/st/networks/test_gpu_lstm.py @@ -15,18 +15,20 @@ import pytest import numpy as np +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor from mindspore.nn.optim import Momentum from mindspore.ops import operations as P from mindspore.nn import TrainOneStepCell, WithLossCell from mindspore.nn import Dense -from mindspore import Tensor from mindspore.common.initializer import initializer from mindspore.common.parameter import Parameter -import mindspore.context as context -import mindspore.nn as nn + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + def InitialLstmWeight(input_size, hidden_size, num_layers, bidirectional, has_bias=False): num_directions = 1 if bidirectional: @@ -56,6 +58,7 @@ def InitialLstmWeight(input_size, hidden_size, num_layers, bidirectional, has_bi return h, c, w + class SentimentNet(nn.Cell): def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, bidirectional, weight, labels, batch_size): @@ -99,6 +102,7 @@ class SentimentNet(nn.Cell): outputs = self.decoder(encoding) return outputs + batch_size = 64 @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @@ -130,7 +134,7 @@ def test_LSTM(): train_network.set_train() train_features = Tensor(np.ones([64, max_len]).astype(np.int32)) - train_labels = Tensor(np.ones([64,]).astype(np.int32)[0:64]) + train_labels = Tensor(np.ones([64, ]).astype(np.int32)[0:64]) losses = [] for epoch in range(num_epochs): loss = train_network(train_features, train_labels) diff --git a/tests/st/networks/test_gpu_resnet.py b/tests/st/networks/test_gpu_resnet.py index 6d8337a6a9..a5f450d5e3 100644 --- a/tests/st/networks/test_gpu_resnet.py +++ b/tests/st/networks/test_gpu_resnet.py @@ -19,36 +19,34 @@ from __future__ import print_function import pytest import numpy as np - +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor from mindspore.nn.cell import Cell from mindspore.nn.layer.conv import Conv2d from mindspore.nn.layer.basic import Flatten from mindspore.nn.layer.normalization import BatchNorm2d from mindspore.nn.layer.pooling import MaxPool2d from mindspore.ops.operations import TensorAdd -import mindspore.nn as nn - from mindspore.nn.optim import Momentum from mindspore.ops import operations as P from mindspore.nn import TrainOneStepCell, WithLossCell from mindspore.nn import Dense -from mindspore import Tensor from mindspore.common.initializer import initializer -import mindspore.context as context - context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + def random_normal_init(shape, mean=0.0, stddev=0.01, seed=None): init_value = np.ones(shape).astype(np.float32) * 0.01 return Tensor(init_value) + def variance_scaling_raw(shape): variance_scaling_value = np.ones(shape).astype(np.float32) * 0.01 return Tensor(variance_scaling_value) - def weight_variable_0(shape): zeros = np.zeros(shape).astype(np.float32) return Tensor(zeros) @@ -323,6 +321,7 @@ class ResNet(Cell): def resnet50(num_classes): return ResNet(ResidualBlock, [3, 4, 6, 3], num_classes) + @pytest.mark.level0 @pytest.mark.platform_x86_gpu_training @pytest.mark.env_onecard @@ -335,9 +334,9 @@ def test_trainTensor(num_classes=10, epoch=8, batch_size=1): net_with_criterion = WithLossCell(net, criterion) train_network = TrainOneStepCell(net_with_criterion, optimizer) # optimizer train_network.set_train() - losses=[] + losses = [] for i in range(0, epoch): - data = Tensor(np.ones([batch_size, 3 ,224, 224]).astype(np.float32) * 0.01) + data = Tensor(np.ones([batch_size, 3, 224, 224]).astype(np.float32) * 0.01) label = Tensor(np.ones([batch_size]).astype(np.int32)) loss = train_network(data, label) losses.append(loss) diff --git a/tests/st/networks/test_network_main.py b/tests/st/networks/test_network_main.py index 7601739f8c..79bd46d87a 100644 --- a/tests/st/networks/test_network_main.py +++ b/tests/st/networks/test_network_main.py @@ -13,25 +13,27 @@ # limitations under the License. # ============================================================================ """ -Function: +Function: test network -Usage: +Usage: python test_network_main.py --net lenet --target Ascend """ import os import time import numpy as np import argparse +import mindspore.context as context import mindspore.nn as nn -from mindspore.common.tensor import Tensor +from mindspore import Tensor from mindspore.nn import TrainOneStepCell, WithLossCell -import mindspore.context as context from mindspore.nn.optim import Momentum from models.lenet import LeNet from models.resnetv1_5 import resnet50 from models.alexnet import AlexNet + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + def train(net, data, label): learning_rate = 0.01 momentum = 0.9 @@ -42,29 +44,31 @@ def train(net, data, label): train_network = TrainOneStepCell(net_with_criterion, optimizer) # optimizer train_network.set_train() res = train_network(data, label) - print("+++++++++Loss+++++++++++++") print(res) - print("+++++++++++++++++++++++++++") assert res + def test_resnet50(): - data = Tensor(np.ones([32, 3 ,224, 224]).astype(np.float32) * 0.01) + data = Tensor(np.ones([32, 3, 224, 224]).astype(np.float32) * 0.01) label = Tensor(np.ones([32]).astype(np.int32)) net = resnet50(32, 10) train(net, data, label) + def test_lenet(): - data = Tensor(np.ones([32, 1 ,32, 32]).astype(np.float32) * 0.01) + data = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01) label = Tensor(np.ones([32]).astype(np.int32)) net = LeNet() train(net, data, label) + def test_alexnet(): - data = Tensor(np.ones([32, 3 ,227, 227]).astype(np.float32) * 0.01) + data = Tensor(np.ones([32, 3, 227, 227]).astype(np.float32) * 0.01) label = Tensor(np.ones([32]).astype(np.int32)) net = AlexNet() train(net, data, label) + parser = argparse.ArgumentParser(description='MindSpore Testing Network') parser.add_argument('--net', default='resnet50', type=str, help='net name') parser.add_argument('--device', default='Ascend', type=str, help='device target') diff --git a/tests/st/pynative/test_ascend_lenet.py b/tests/st/pynative/test_ascend_lenet.py index 4009844791..5a84aaf930 100644 --- a/tests/st/pynative/test_ascend_lenet.py +++ b/tests/st/pynative/test_ascend_lenet.py @@ -14,7 +14,8 @@ # ============================================================================ import pytest import numpy as np -import time, math +import time +import math import mindspore.nn as nn from mindspore import context, Tensor, ParameterTuple from mindspore.ops import operations as P @@ -28,6 +29,7 @@ from mindspore.nn.optim import Momentum np.random.seed(1) + def weight_variable(): """weight initial""" return TruncatedNormal(0.02) @@ -58,6 +60,7 @@ class LeNet(nn.Cell): Examples: >>> LeNet(num_class=10) """ + def __init__(self, num_class=10): super(LeNet, self).__init__() self.num_class = num_class @@ -91,6 +94,7 @@ class CrossEntropyLoss(nn.Cell): """ Define loss for network """ + def __init__(self): super(CrossEntropyLoss, self).__init__() self.cross_entropy = P.SoftmaxCrossEntropyWithLogits() @@ -111,6 +115,7 @@ class GradWrap(nn.Cell): """ GradWrap definition """ + def __init__(self, network): super(GradWrap, self).__init__() self.network = network @@ -154,4 +159,3 @@ def test_ascend_pynative_lenet(): print("======epoch: ", epoch, " loss: ", loss_output.asnumpy(), " cost time: ", cost_time) assert(loss_output.asnumpy() < 0.1) - \ No newline at end of file diff --git a/tests/st/summary/test_davinci_summary.py b/tests/st/summary/test_davinci_summary.py index 1611ca8ec7..a2ed840515 100644 --- a/tests/st/summary/test_davinci_summary.py +++ b/tests/st/summary/test_davinci_summary.py @@ -33,10 +33,12 @@ SUMMARY_DIR = CUR_DIR + "/test_temp_summary_event_file/" context.set_context(device_target="Ascend") + class MsWrapper(nn.Cell): def __init__(self, network): super(MsWrapper, self).__init__(auto_prefix=False) self._network = network + @ms_function def construct(self, *args): return self._network(*args) @@ -45,14 +47,15 @@ class MsWrapper(nn.Cell): def me_train_tensor(net, input_np, label_np, epoch_size=2): context.set_context(mode=context.GRAPH_MODE) loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True) - opt = ApplyMomentum(Tensor(np.array([0.1])), Tensor(np.array([0.9])), filter(lambda x: x.requires_grad, net.get_parameters())) + opt = ApplyMomentum(Tensor(np.array([0.1])), Tensor(np.array([0.9])), + filter(lambda x: x.requires_grad, net.get_parameters())) Model(net, loss, opt) _network = wrap.WithLossCell(net, loss) _train_net = MsWrapper(wrap.TrainOneStepCell(_network, opt)) _train_net.set_train() summary_writer = SummaryRecord(SUMMARY_DIR, file_suffix="_MS_GRAPH", network=_train_net) for epoch in range(0, epoch_size): - print(f"epoch %d"%(epoch)) + print(f"epoch %d" % (epoch)) output = _train_net(Tensor(input_np), Tensor(label_np)) summary_writer.record(i) print("********output***********") diff --git a/tests/st/summary/test_gpu_summary.py b/tests/st/summary/test_gpu_summary.py index c97c08c4e1..e8eadc66ab 100644 --- a/tests/st/summary/test_gpu_summary.py +++ b/tests/st/summary/test_gpu_summary.py @@ -108,6 +108,6 @@ def me_scalar_summary(steps, tag=None, value=None): def test_scalarsummary_scalar1_step10_summaryrecord1(): clean_environment_file(SUMMARY_DIR_ME_TEMP) output_dict = me_scalar_summary(10) - print("test_scalarsummary_scalar1_step10_summaryrecord1 \n",output_dict) + print("test_scalarsummary_scalar1_step10_summaryrecord1 \n", output_dict) save_summary_events_file(SUMMARY_DIR_ME_TEMP, SUMMARY_DIR_ME) clean_environment_file(SUMMARY_DIR_ME) diff --git a/tests/st/tbe_networks/export_geir.py b/tests/st/tbe_networks/export_geir.py index 467388c5e8..a4368e6320 100644 --- a/tests/st/tbe_networks/export_geir.py +++ b/tests/st/tbe_networks/export_geir.py @@ -24,12 +24,13 @@ import mindspore.nn as nn from mindspore import context from mindspore.train.serialization import save, load, save_checkpoint, load_checkpoint,\ - load_param_into_net, _exec_save_checkpoint,\ - _check_filedir_or_create, _chg_model_file_name_if_same_exist, \ - _read_file_last_line, context, export + load_param_into_net, _exec_save_checkpoint,\ + _check_filedir_or_create, _chg_model_file_name_if_same_exist, \ + _read_file_last_line, context, export + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", + enable_task_sink=True, enable_loop_sink=True, enable_ir_fusion=True) -context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", -enable_task_sink=True,enable_loop_sink=True,enable_ir_fusion=True) def test_resnet50_export(batch_size=1, num_classes=5): context.set_context(enable_ir_fusion=False) diff --git a/tests/st/tbe_networks/resnet.py b/tests/st/tbe_networks/resnet.py index 2024286b8f..4f2ff79a86 100644 --- a/tests/st/tbe_networks/resnet.py +++ b/tests/st/tbe_networks/resnet.py @@ -19,6 +19,7 @@ from mindspore.ops import operations as P from mindspore.common.initializer import initializer from mindspore.common import dtype as mstype + def weight_variable(shape): return initializer('XavierUniform', shape=shape, dtype=mstype.float32) @@ -297,4 +298,3 @@ class ResNet(nn.Cell): def resnet50(batch_size, num_classes): return ResNet(ResidualBlock, [3, 4, 6, 3], num_classes, batch_size) - diff --git a/tests/st/tbe_networks/resnet_cifar.py b/tests/st/tbe_networks/resnet_cifar.py index f1ab02afa3..7bd03f5d81 100644 --- a/tests/st/tbe_networks/resnet_cifar.py +++ b/tests/st/tbe_networks/resnet_cifar.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ +import argparse import mindspore.nn as nn from mindspore import Tensor from mindspore.ops import operations as P @@ -35,7 +36,6 @@ random.seed(1) np.random.seed(1) ds.config.set_seed(1) -import argparse parser = argparse.ArgumentParser(description='Image classification') parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute') parser.add_argument('--device_num', type=int, default=1, help='Device num.') @@ -48,15 +48,16 @@ parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoin parser.add_argument('--dataset_path', type=str, default="/var/log/npu/datasets/cifar", help='Dataset path') args_opt = parser.parse_args() -device_id=int(os.getenv('DEVICE_ID')) +device_id = int(os.getenv('DEVICE_ID')) -data_home=args_opt.dataset_path +data_home = args_opt.dataset_path context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") context.set_context(enable_task_sink=True, device_id=device_id) context.set_context(enable_loop_sink=True) context.set_context(enable_mem_reuse=True) + def create_dataset(repeat_num=1, training=True): data_dir = data_home + "/cifar-10-batches-bin" if not training: @@ -64,8 +65,8 @@ def create_dataset(repeat_num=1, training=True): data_set = ds.Cifar10Dataset(data_dir) if args_opt.run_distribute: - rank_id=int(os.getenv('RANK_ID')) - rank_size=int(os.getenv('RANK_SIZE')) + rank_id = int(os.getenv('RANK_ID')) + rank_size = int(os.getenv('RANK_SIZE')) data_set = ds.Cifar10Dataset(data_dir, num_shards=rank_size, shard_id=rank_id) resize_height = 224 @@ -74,9 +75,9 @@ def create_dataset(repeat_num=1, training=True): shift = 0.0 # define map operations - random_crop_op = vision.RandomCrop((32, 32), (4, 4, 4, 4)) # padding_mode default CONSTANT + random_crop_op = vision.RandomCrop((32, 32), (4, 4, 4, 4)) # padding_mode default CONSTANT random_horizontal_op = vision.RandomHorizontalFlip() - resize_op = vision.Resize((resize_height, resize_width)) # interpolation default BILINEAR + resize_op = vision.Resize((resize_height, resize_width)) # interpolation default BILINEAR rescale_op = vision.Rescale(rescale, shift) normalize_op = vision.Normalize((0.4465, 0.4822, 0.4914), (0.2010, 0.1994, 0.2023)) changeswap_op = vision.HWC2CHW() @@ -103,6 +104,7 @@ def create_dataset(repeat_num=1, training=True): return data_set + class CrossEntropyLoss(nn.Cell): def __init__(self): super(CrossEntropyLoss, self).__init__() diff --git a/tests/st/tbe_networks/test_resnet_cifar_8p.py b/tests/st/tbe_networks/test_resnet_cifar_8p.py index 6e83f4180e..69f0a80d12 100644 --- a/tests/st/tbe_networks/test_resnet_cifar_8p.py +++ b/tests/st/tbe_networks/test_resnet_cifar_8p.py @@ -112,6 +112,7 @@ class CrossEntropyLoss(nn.Cell): loss = self.mean(loss, (-1,)) return loss + class LossGet(Callback): def __init__(self, per_print_times=1): super(LossGet, self).__init__() @@ -143,6 +144,7 @@ class LossGet(Callback): def get_loss(self): return self._loss + def train_process(q, device_id, epoch_size, num_classes, device_num, batch_size, enable_hccl): os.system("mkdir " + str(device_id)) os.chdir(str(device_id)) From c78630d7370b7b11e2358a4fe2053708c07981f6 Mon Sep 17 00:00:00 2001 From: lichenever Date: Mon, 20 Apr 2020 21:30:42 +0800 Subject: [PATCH 105/142] support multiple subgraphs --- .../allreduce_fusion/allreduce_fusion.cc | 7 +- mindspore/ccsrc/parallel/step_parallel.cc | 363 +++++++++--------- mindspore/ccsrc/parallel/step_parallel.h | 5 +- .../parallel/test_semi_auto_two_subgraphs.py | 108 ++++++ 4 files changed, 306 insertions(+), 177 deletions(-) create mode 100644 tests/ut/python/parallel/test_semi_auto_two_subgraphs.py diff --git a/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_fusion.cc b/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_fusion.cc index b4f4cb5b22..30173e533c 100644 --- a/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_fusion.cc +++ b/mindspore/ccsrc/parallel/allreduce_fusion/allreduce_fusion.cc @@ -399,7 +399,12 @@ Status AllreduceFusion::ProcessAllreduceFusion(const CNodePtr &ret) { ret_ = ret; root_graph_ = ret_->func_graph(); MS_EXCEPTION_IF_NULL(root_graph_); - auto forward_graph = ForwardGraph(root_graph_); + auto graph_set = ForwardGraph(root_graph_); + if (graph_set.size() > 1) { + MS_LOG(WARNING) << "AllReduce fusion don't support multiple subgraphs now."; + return SUCCESS; + } + auto forward_graph = *(graph_set.begin()); MS_EXCEPTION_IF_NULL(forward_graph); forward_ret_ = forward_graph->get_return(); MS_EXCEPTION_IF_NULL(forward_ret_); diff --git a/mindspore/ccsrc/parallel/step_parallel.cc b/mindspore/ccsrc/parallel/step_parallel.cc index d1390db899..c24c14abf6 100644 --- a/mindspore/ccsrc/parallel/step_parallel.cc +++ b/mindspore/ccsrc/parallel/step_parallel.cc @@ -1607,72 +1607,79 @@ void ReshapeInit(const std::vector &all_nodes) { } } -// Sens node satisfies the following conditions: cnode(sens)-->cnode(tuple_getitem)-->cnode-->cnode(J) -bool IsGradSensNode(const AnfNodePtr &node) { - if (!node->isa()) { - return false; +CNodePtr FindLossCNode(const FuncGraphPtr &func_graph) { + MS_EXCEPTION_IF_NULL(func_graph); + CNodePtr return_node = func_graph->get_return(); + MS_EXCEPTION_IF_NULL(return_node); + if (return_node->size() < 2) { + MS_LOG(EXCEPTION) << "Failure: " << return_node->ToString() << " size is smaller than 2"; } + AnfNodePtr pre_node = return_node->input(1); + MS_EXCEPTION_IF_NULL(pre_node); - // cnode(sens)-->cnode(tuple_getitem) - auto cnode = node->cast(); - AnfNodePtr expect_tuple_getitem = cnode->input(0); - MS_EXCEPTION_IF_NULL(expect_tuple_getitem); - if (!expect_tuple_getitem->isa()) { - return false; - } - auto expect_tuple_getitem_cnode = expect_tuple_getitem->cast(); - MS_EXCEPTION_IF_NULL(expect_tuple_getitem_cnode); - if (!IsValueNode(expect_tuple_getitem_cnode->input(0))) { - return false; + auto pre_cnode = pre_node->cast(); + MS_EXCEPTION_IF_NULL(pre_cnode); + auto current_prim = GetValueNode(pre_cnode->input(0)); + + // return -> cast + if (current_prim->name() == CAST && pre_cnode->operator_info() == nullptr) { + pre_cnode = pre_cnode->input(1)->cast(); + MS_EXCEPTION_IF_NULL(pre_cnode); + current_prim = GetValueNode(pre_cnode->input(0)); } - ValueNodePtr expect_tuple_getitem_value_node = expect_tuple_getitem_cnode->input(0)->cast(); - MS_EXCEPTION_IF_NULL(expect_tuple_getitem_value_node); - PrimitivePtr expect_tuple_getitem_prim = expect_tuple_getitem_value_node->value()->cast(); - MS_EXCEPTION_IF_NULL(expect_tuple_getitem_prim); - if (expect_tuple_getitem_prim->name() != TUPLE_GETITEM) { - return false; + + // notice: the GetNext op has not input + if (INVALID_LOSS_OPS.find(current_prim->name()) != INVALID_LOSS_OPS.end()) { + MS_LOG(INFO) << "The loss is: " << current_prim->name(); + return pre_cnode; } - // cnode(sens)-->cnode(tuple_getitem)-->cnode - AnfNodePtr expect_anonymous = expect_tuple_getitem_cnode->input(1); - MS_EXCEPTION_IF_NULL(expect_anonymous); - if (!expect_anonymous->isa()) { - return false; + // size of common cnode is larger than 1 + if (pre_cnode->size() < 2) { + MS_LOG(EXCEPTION) << pre_cnode->ToString() << " size( " << pre_cnode->inputs().size() << " ) is smaller than 2"; } - // cnode(sens)-->cnode(tuple_getitem)-->cnode-->cnode(J) - auto expect_anonymous_cnode = expect_anonymous->cast(); - MS_EXCEPTION_IF_NULL(expect_anonymous_cnode); - AnfNodePtr expect_j = expect_anonymous_cnode->input(0); - MS_EXCEPTION_IF_NULL(expect_j); - if (!expect_j->isa()) { - return false; + // return -> tuple_getitem -> loss + if (current_prim->name() == TUPLE_GETITEM) { + AnfNodePtr pre_pre_node = pre_cnode->input(1); + MS_EXCEPTION_IF_NULL(pre_pre_node); + + auto pre_pre_cnode = pre_pre_node->cast(); + auto value = pre_pre_cnode->input(0)->cast(); + MS_EXCEPTION_IF_NULL(value); + PrimitivePtr prim = value->value()->cast(); + MS_EXCEPTION_IF_NULL(prim); + MS_LOG(DEBUG) << "The loss name is " << prim->name(); + return pre_pre_cnode; } - auto expect_j_cnode = expect_j->cast(); - MS_EXCEPTION_IF_NULL(expect_j_cnode); - if (!IsValueNode(expect_j_cnode->input(0))) { - return false; + + // return -> make_tuple + if (current_prim->name() == MAKE_TUPLE) { + MS_LOG(EXCEPTION) << "The loss have make_tuple, it is not supported"; } - ValueNodePtr expect_j_value_node = expect_j_cnode->input(0)->cast(); - MS_EXCEPTION_IF_NULL(expect_j_value_node); - PrimitivePtr expect_j_prim = expect_j_value_node->value()->cast(); - MS_EXCEPTION_IF_NULL(expect_j_prim); - return (expect_j_prim->name() == J); + + // return -> loss + MS_LOG(DEBUG) << "The loss name is " << current_prim->name(); + return pre_cnode; } -TensorLayouts GetLossNodeGradOutputLayout(const CNodePtr &loss_cnode) { +TensorLayouts GetLossNodeGradOutputLayout(const CNodePtr &cnode) { + MS_EXCEPTION_IF_NULL(cnode); + TensorLayouts ret; + if (!IsValueNode(cnode->input(1))) { + MS_LOG(EXCEPTION) << "Sens can't find the corresponding graph."; + } + auto func_graph = GetValueNode(cnode->input(1)); + auto loss_cnode = FindLossCNode(func_graph); MS_EXCEPTION_IF_NULL(loss_cnode); AnfNodePtr node = loss_cnode->cast(); MS_EXCEPTION_IF_NULL(node); LossNodeInfo node_info = GetLossNodeInfo(node); - ValueNodePtr prim_anf_node = loss_cnode->input(0)->cast(); MS_EXCEPTION_IF_NULL(prim_anf_node); PrimitivePtr prim = prim_anf_node->value()->cast(); MS_EXCEPTION_IF_NULL(prim); - - TensorLayouts ret; if (INVALID_LOSS_OPS.find(prim->name()) != INVALID_LOSS_OPS.end()) { MS_LOG(WARNING) << "The loss name is: " << prim->name() << ", do nothing for split sens now"; return ret; @@ -1680,7 +1687,6 @@ TensorLayouts GetLossNodeGradOutputLayout(const CNodePtr &loss_cnode) { OperatorInfoPtr operator_info = loss_cnode->operator_info(); MS_EXCEPTION_IF_NULL(operator_info); - TensorInfo loss_grad_tensor_info; size_t op_output_size = operator_info->outputs_tensor_info().size(); MS_LOG(INFO) << "The loss name is " << operator_info->name() << ", the has tuple item is " @@ -1805,6 +1811,100 @@ void HandleSpecialNode(const OperatorInfoPtr &distribute_operator, const CNodePt HandleDropoutNode(distribute_operator, cnode); } +std::set FindForwardGraphByRootNodes(const AnfNodeSet &root_all_nodes) { + // J->CNode->Graph + std::set graph_set; + for (auto &node : root_all_nodes) { + MS_EXCEPTION_IF_NULL(node); + if (!node->isa()) { + continue; + } + + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if ((cnode->size() < 2) || !IsValueNode(cnode->input(0))) { + continue; + } + auto expect_j_prim = GetValueNode(cnode->input(0)); + if (expect_j_prim->name() != J) { + continue; + } + if (IsValueNode(cnode->input(1))) { + auto graph = GetValueNode(cnode->input(1)); + MS_LOG(DEBUG) << "Find the forward graph success"; + graph_set.insert(graph); + } + } + return graph_set; +} + +// Sens node satisfies the following conditions: cnode(sens)-->cnode(tuple_getitem)-->cnode-->cnode(J) +void StepSplitSens(const AnfNodePtr &node) { + if (!node->isa()) { + return; + } + + // cnode(sens)-->cnode(tuple_getitem) + auto cnode = node->cast(); + AnfNodePtr expect_tuple_getitem = cnode->input(0); + MS_EXCEPTION_IF_NULL(expect_tuple_getitem); + if (!expect_tuple_getitem->isa()) { + return; + } + auto expect_tuple_getitem_cnode = expect_tuple_getitem->cast(); + MS_EXCEPTION_IF_NULL(expect_tuple_getitem_cnode); + if (!IsValueNode(expect_tuple_getitem_cnode->input(0))) { + return; + } + auto expect_tuple_getitem_prim = GetValueNode(expect_tuple_getitem_cnode->input(0)); + if (expect_tuple_getitem_prim->name() != TUPLE_GETITEM) { + return; + } + + // cnode(sens)-->cnode(tuple_getitem)-->cnode + AnfNodePtr expect_anonymous = expect_tuple_getitem_cnode->input(1); + MS_EXCEPTION_IF_NULL(expect_anonymous); + if (!expect_anonymous->isa()) { + return; + } + + // cnode(sens)-->cnode(tuple_getitem)-->cnode-->cnode(J) + auto expect_anonymous_cnode = expect_anonymous->cast(); + MS_EXCEPTION_IF_NULL(expect_anonymous_cnode); + AnfNodePtr expect_j = expect_anonymous_cnode->input(0); + MS_EXCEPTION_IF_NULL(expect_j); + if (!expect_j->isa()) { + return; + } + auto expect_j_cnode = expect_j->cast(); + MS_EXCEPTION_IF_NULL(expect_j_cnode); + if (!IsValueNode(expect_j_cnode->input(0))) { + return; + } + auto expect_j_prim = GetValueNode(expect_j_cnode->input(0)); + if (expect_j_prim->name() == J) { + auto loss_grad_layout = GetLossNodeGradOutputLayout(expect_j_cnode); + if (!loss_grad_layout.empty()) { + SplitSens(node, loss_grad_layout[0]); + } + } +} + +std::vector FindLossCNodeFromRoot(const FuncGraphPtr &root) { + MS_EXCEPTION_IF_NULL(root); + AnfNodePtr root_return_node = root->get_return(); + MS_EXCEPTION_IF_NULL(root_return_node); + std::vector loss_node; + const auto &all_nodes = root->nodes(); + std::set graph_set = FindForwardGraphByRootNodes(all_nodes); + if (graph_set.empty()) { + loss_node.push_back(FindLossCNode(root)); + } + (void)std::transform(graph_set.begin(), graph_set.end(), std::back_inserter(loss_node), + [](const FuncGraphPtr &graph) { return FindLossCNode(graph); }); + return loss_node; +} + void ParallelCommunication(const FuncGraphPtr &root, const std::vector &all_nodes, const FuncGraphManagerPtr &manager) { MS_EXCEPTION_IF_NULL(root); @@ -1812,18 +1912,15 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector loss_cnode = FindLossCNodeFromRoot(root); + // split sens must before inserting the operators. for (auto &node : all_nodes) { - // find sens node - if ((grad_sens_node == nullptr) && IsGradSensNode(node)) { - grad_sens_node = node; - MS_LOG(INFO) << "Find the sens node success"; - } + // If the shape of grad-sens tensor is not [] or [1], use get tensor slice to handel it. + // If the type of sens node is not Tensor, it is unsupported now, do nothing default. + StepSplitSens(node); + } + for (auto &node : all_nodes) { MS_EXCEPTION_IF_NULL(node); if (node->isa()) { auto cnode = node->cast(); @@ -1837,7 +1934,8 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vectorget_return(); - MS_EXCEPTION_IF_NULL(return_node); - if (return_node->inputs().size() < 2) { - MS_LOG(EXCEPTION) << "Failure: " << return_node->ToString() << " size is smaller than 2"; - } - AnfNodePtr pre_node = return_node->input(1); - MS_EXCEPTION_IF_NULL(pre_node); - - auto pre_cnode = pre_node->cast(); - MS_EXCEPTION_IF_NULL(pre_cnode); - auto current_value = pre_cnode->input(0)->cast(); - MS_EXCEPTION_IF_NULL(current_value); - PrimitivePtr current_prim = current_value->value()->cast(); - MS_EXCEPTION_IF_NULL(current_prim); - - // return -> cast - if (current_prim->name() == CAST && pre_cnode->operator_info() == nullptr) { - pre_cnode = pre_cnode->input(1)->cast(); - MS_EXCEPTION_IF_NULL(pre_cnode); - current_prim = GetValueNode(pre_cnode->input(0)); - } - - // notice: the GetNext op has not input - if (INVALID_LOSS_OPS.find(current_prim->name()) != INVALID_LOSS_OPS.end()) { - MS_LOG(INFO) << "The loss is: " << current_prim->name(); - return pre_cnode; - } - - // size of common cnode is larger than 1 - if (pre_cnode->inputs().size() < 2) { - MS_LOG(EXCEPTION) << pre_cnode->ToString() << " size( " << pre_cnode->inputs().size() << " ) is smaller than 2"; - } - - // return -> tuple_getitem -> loss - if (current_prim->name() == TUPLE_GETITEM) { - AnfNodePtr pre_pre_node = pre_cnode->input(1); - MS_EXCEPTION_IF_NULL(pre_pre_node); - - auto pre_pre_cnode = pre_pre_node->cast(); - auto value = pre_pre_cnode->input(0)->cast(); - MS_EXCEPTION_IF_NULL(value); - PrimitivePtr prim = value->value()->cast(); - MS_EXCEPTION_IF_NULL(prim); - MS_LOG(INFO) << "The loss name is " << prim->name(); - return pre_pre_cnode; - } else if (current_prim->name() == MAKE_TUPLE) { - MS_LOG(EXCEPTION) << "The loss have make_tuple, it is not supported"; - } - - // return -> loss - MS_LOG(INFO) << "The loss name is " << current_prim->name(); - return pre_cnode; +std::set ForwardGraph(const FuncGraphPtr &root) { + MS_EXCEPTION_IF_NULL(root); + const auto &all_nodes = root->nodes(); + std::set graph_set = FindForwardGraphByRootNodes(all_nodes); + return graph_set; } -FuncGraphPtr FindForwardGraphByRootNodes(const AnfNodeSet &root_all_nodes) { - for (auto &node : root_all_nodes) { +std::vector FindRootForwardCNode(const FuncGraphPtr &graph, const AnfNodeSet &all_nodes) { + MS_EXCEPTION_IF_NULL(graph); + auto loss_cnode = FindLossCNode(graph); + MS_EXCEPTION_IF_NULL(loss_cnode); + auto loss_cnode_id = loss_cnode->UniqueIdThroughCopy(); + std::vector root_forward_nodes; + for (auto &node : all_nodes) { MS_EXCEPTION_IF_NULL(node); if (!node->isa()) { continue; } - auto cnode = node->cast(); MS_EXCEPTION_IF_NULL(cnode); - if ((cnode->inputs().size() < 2) || !IsValueNode(cnode->input(0))) { - continue; - } - ValueNodePtr expect_j_value_node = cnode->input(0)->cast(); - MS_EXCEPTION_IF_NULL(expect_j_value_node); - PrimitivePtr expect_j_prim = expect_j_value_node->value()->cast(); - MS_EXCEPTION_IF_NULL(expect_j_prim); - if (expect_j_prim->name() != J) { - continue; - } - MS_LOG(DEBUG) << "Find J prim: " << expect_j_value_node->DebugString() << "."; - if (IsValueNode(cnode->input(1))) { - auto graph = GetValueNode(cnode->input(1)); - MS_LOG(INFO) << "Find the forward graph success"; - return graph; + auto root_node_id = node->UniqueIdThroughCopy(); + if (loss_cnode_id == root_node_id) { + root_forward_nodes = DeepLinkedGraphSearch(cnode); + break; } } - return nullptr; -} - -CNodePtr FindLossCNodeFromRoot(const FuncGraphPtr &root) { - MS_EXCEPTION_IF_NULL(root); - AnfNodePtr root_return_node = root->get_return(); - MS_EXCEPTION_IF_NULL(root_return_node); - const auto &all_nodes = root->nodes(); - FuncGraphPtr func_graph = FindForwardGraphByRootNodes(all_nodes); - if (func_graph == nullptr) { - return FindLossCNode(root); - } else { - return FindLossCNode(func_graph); - } -} - -FuncGraphPtr ForwardGraph(const FuncGraphPtr &root) { - FuncGraphPtr forward_graph = root; - MS_EXCEPTION_IF_NULL(root); - AnfNodePtr root_return_node = root->get_return(); - MS_EXCEPTION_IF_NULL(root_return_node); - const auto &all_nodes = root->nodes(); - FuncGraphPtr func_graph = FindForwardGraphByRootNodes(all_nodes); - if (func_graph != nullptr) { - forward_graph = func_graph; - } - return forward_graph; + return root_forward_nodes; } void MarkForwardCNode(const FuncGraphPtr &root) { MS_EXCEPTION_IF_NULL(root); - AnfNodePtr root_return_node = root->get_return(); - MS_EXCEPTION_IF_NULL(root_return_node); - auto &all_nodes = root->nodes(); - FuncGraphPtr func_graph = FindForwardGraphByRootNodes(all_nodes); + auto all_nodes = root->nodes(); + std::set graph_set = FindForwardGraphByRootNodes(all_nodes); - if (func_graph == nullptr) { - // Can not find the forward graph, so the ops in root graph are forward. + if (graph_set.empty()) { MS_LOG(INFO) << "Can not find the forward graph, so mark the ops in root graph"; SetForwardFlag(all_nodes); } else { - MS_LOG(INFO) << "The sub graph size of root is " << root->func_graphs_used().size(); - AnfNodePtr return_node = func_graph->get_return(); - MS_EXCEPTION_IF_NULL(return_node); - std::vector all_dfs_nodes = DeepLinkedGraphSearch(return_node); - SetForwardFlag(all_dfs_nodes); + for (auto &func_graph : graph_set) { + MS_LOG(INFO) << "The sub graph size of root is " << root->func_graphs_used().size(); + auto return_node = func_graph->get_return(); + MS_EXCEPTION_IF_NULL(return_node); + auto all_dfs_nodes = DeepLinkedGraphSearch(return_node); + SetForwardFlag(all_dfs_nodes); + auto root_forward_nodes = FindRootForwardCNode(func_graph, all_nodes); + if (root_forward_nodes.empty()) { + continue; + } + // Mark forward flag for the nodes in root graph. + SetForwardFlag(root_forward_nodes); + } } } diff --git a/mindspore/ccsrc/parallel/step_parallel.h b/mindspore/ccsrc/parallel/step_parallel.h index 184d11d173..b0d128f515 100644 --- a/mindspore/ccsrc/parallel/step_parallel.h +++ b/mindspore/ccsrc/parallel/step_parallel.h @@ -24,6 +24,7 @@ #include #include #include +#include #include "./common.h" #include "optimizer/opt.h" @@ -142,13 +143,13 @@ bool StepParallel(const FuncGraphPtr &func_graph, const opt::OptimizerPtr &optim int32_t GetTupleGetItemIndex(const CNodePtr &cnode); -CNodePtr FindLossCNodeFromRoot(const FuncGraphPtr &root); +std::vector FindLossCNodeFromRoot(const FuncGraphPtr &root); Status ParallelInit(); std::vector ExtractInputsTensorName(const CNodePtr &node); -FuncGraphPtr ForwardGraph(const FuncGraphPtr &root); +std::set ForwardGraph(const FuncGraphPtr &root); } // namespace parallel } // namespace mindspore diff --git a/tests/ut/python/parallel/test_semi_auto_two_subgraphs.py b/tests/ut/python/parallel/test_semi_auto_two_subgraphs.py new file mode 100644 index 0000000000..b572968a4f --- /dev/null +++ b/tests/ut/python/parallel/test_semi_auto_two_subgraphs.py @@ -0,0 +1,108 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +import mindspore as ms +from mindspore import Tensor, Parameter, ParameterTuple, context +from mindspore import nn +from mindspore.common.api import _executor +from mindspore.nn.optim import Adam, FTRL +from mindspore.ops import operations as P +from mindspore.ops import composite as C +from mindspore.ops import functional as F +import numpy as np + + +class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.mul = P.Mul() + self.relu = P.ReLU() + self.param1 = Parameter(Tensor(np.ones([8, 8, 8, 8]).astype(np.float32)), name="wide") + self.param2 = Parameter(Tensor(np.ones([8, 8, 8, 8]).astype(np.float32)), name="deep") + + def construct(self, x): + out = self.mul(x, self.param1) + out = self.mul(out, self.param2) + out = self.relu(out) + return out + + +class NetWithLoss(nn.Cell): + def __init__(self, network): + super(NetWithLoss, self).__init__() + self.sum = P.ReduceSum(keep_dims=False).set_strategy(strategy=((4, 1, 1, 1),)) + self.mean = P.ReduceMean(keep_dims=False).set_strategy(strategy=((8, 1, 1, 1),)) + self.net = network + + def construct(self, x): + net_out = self.net(x) + loss1 = self.sum(net_out, -1) + loss2 = self.mean(net_out, -1) + return loss1, loss2 + + +class IthOutputCell(nn.Cell): + def __init__(self, network, output_index): + super(IthOutputCell, self).__init__() + self.network = network + self.output_index = output_index + + def construct(self, x1): + predict = self.network(x1)[self.output_index] + return predict + + +class TrainStepWrap(nn.Cell): + def __init__(self, network, sens=1000.0): + super(TrainStepWrap, self).__init__() + self.network = network + self.network.set_train() + self.trainable_params = network.trainable_params() + weights_w = [] + weights_d = [] + for params in self.trainable_params: + weights_w.append(params) + weights_d.append(params) + + self.weights_w = ParameterTuple(weights_w) + self.weights_d = ParameterTuple(weights_d) + self.optimizer_w = FTRL(learning_rate=1e-2, params=self.weights_w, + l1=1e-8, l2=1e-8, initial_accum=1.0) + self.optimizer_d = Adam(self.weights_d, learning_rate=3.5e-4, eps=1e-8, + loss_scale=sens) + self.hyper_map = C.HyperMap() + self.grad_w = C.GradOperation('grad_w', get_by_list=True, + sens_param=True) + self.grad_d = C.GradOperation('grad_d', get_by_list=True, + sens_param=True) + self.sens = sens + self.loss_net_w = IthOutputCell(network, output_index=0) + self.loss_net_d = IthOutputCell(network, output_index=1) + + def construct(self, x): + weights_w = self.weights_w + weights_d = self.weights_d + loss_w, loss_d = self.network(x) + sens_w = P.Fill()(P.DType()(loss_w), P.Shape()(loss_w), self.sens) + sens_d = P.Fill()(P.DType()(loss_d), P.Shape()(loss_d), self.sens) + grads_w = self.grad_w(self.loss_net_w, weights_w)(x, sens_w) + grads_d = self.grad_d(self.loss_net_d, weights_d)(x, sens_d) + return F.depend(loss_w, self.optimizer_w(grads_w)), F.depend(loss_d, self.optimizer_d(grads_d)) + + +def test_two_subgraphs(): + context.set_auto_parallel_context(device_num=8, global_rank=0, parallel_mode="semi_auto_parallel") + net = TrainStepWrap(NetWithLoss(Net())) + input_x = Tensor(np.ones([8, 8, 8, 8]), dtype=ms.float32) + _executor.compile(net, input_x) From 313018015b78ca6dbfd2042c41ae65dcf0a81016 Mon Sep 17 00:00:00 2001 From: laiyongqiang Date: Wed, 22 Apr 2020 18:45:06 +0800 Subject: [PATCH 106/142] fix node check bug in convert_tuple_output_to_maketuple pass --- .../pre_activate/pass/convert_tuple_output_to_maketuple.cc | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/mindspore/ccsrc/pre_activate/pass/convert_tuple_output_to_maketuple.cc b/mindspore/ccsrc/pre_activate/pass/convert_tuple_output_to_maketuple.cc index 3f283e5d24..93c1b73038 100644 --- a/mindspore/ccsrc/pre_activate/pass/convert_tuple_output_to_maketuple.cc +++ b/mindspore/ccsrc/pre_activate/pass/convert_tuple_output_to_maketuple.cc @@ -68,9 +68,8 @@ const AnfNodePtr ConvertTupleOutputToMaketuple::Process(const FuncGraphPtr &func if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimTupleGetItem->name()) { return nullptr; } - if (std::any_of(cnode->inputs().begin() + 1, cnode->inputs().end(), [](const AnfNodePtr &node) { - return AnfAlgo::IsTupleOutput(node) && AnfAlgo::GetCNodeName(node) != prim::kPrimMakeTuple->name(); - })) { + if (std::any_of(cnode->inputs().begin() + 1, cnode->inputs().end(), + [](const AnfNodePtr &node) { return AnfAlgo::IsRealKernel(node) && AnfAlgo::IsTupleOutput(node); })) { return ConvertTupleInputToMakeTuple(func_graph, cnode); } return nullptr; From b31946750bcc5bcc597c1bf1202e06e99e964115 Mon Sep 17 00:00:00 2001 From: panfengfeng Date: Wed, 22 Apr 2020 19:19:57 +0800 Subject: [PATCH 107/142] =?UTF-8?q?=E5=9B=9E=E9=80=80=20'Pull=20Request=20?= =?UTF-8?q?!182=20:=20Tuning=20mindrecord=20writer=20performance'?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- example/convert_to_mindrecord/README.md | 46 ----- .../imagenet/__init__.py | 0 .../convert_to_mindrecord/imagenet/mr_api.py | 122 ------------ example/convert_to_mindrecord/run_imagenet.sh | 8 - example/convert_to_mindrecord/run_template.sh | 6 - .../template/__init__.py | 0 .../convert_to_mindrecord/template/mr_api.py | 73 ------- example/convert_to_mindrecord/writer.py | 149 -------------- .../ccsrc/mindrecord/common/shard_pybind.cc | 9 +- .../ccsrc/mindrecord/include/shard_header.h | 4 - .../ccsrc/mindrecord/include/shard_writer.h | 37 +--- .../mindrecord/io/shard_index_generator.cc | 3 - mindspore/ccsrc/mindrecord/io/shard_writer.cc | 188 ++---------------- .../ccsrc/mindrecord/meta/shard_header.cc | 38 ---- mindspore/mindrecord/filewriter.py | 15 +- mindspore/mindrecord/shardwriter.py | 5 +- 16 files changed, 35 insertions(+), 668 deletions(-) delete mode 100644 example/convert_to_mindrecord/README.md delete mode 100644 example/convert_to_mindrecord/imagenet/__init__.py delete mode 100644 example/convert_to_mindrecord/imagenet/mr_api.py delete mode 100644 example/convert_to_mindrecord/run_imagenet.sh delete mode 100644 example/convert_to_mindrecord/run_template.sh delete mode 100644 example/convert_to_mindrecord/template/__init__.py delete mode 100644 example/convert_to_mindrecord/template/mr_api.py delete mode 100644 example/convert_to_mindrecord/writer.py diff --git a/example/convert_to_mindrecord/README.md b/example/convert_to_mindrecord/README.md deleted file mode 100644 index 8d3b25e311..0000000000 --- a/example/convert_to_mindrecord/README.md +++ /dev/null @@ -1,46 +0,0 @@ -# MindRecord generating guidelines - - - -- [MindRecord generating guidelines](#mindrecord-generating-guidelines) - - [Create work space](#create-work-space) - - [Implement data generator](#implement-data-generator) - - [Run data generator](#run-data-generator) - - - -## Create work space - -Assume the dataset name is 'xyz' -* Create work space from template - ```shell - cd ${your_mindspore_home}/example/convert_to_mindrecord - cp -r template xyz - ``` - -## Implement data generator - -Edit dictionary data generator -* Edit file - ```shell - cd ${your_mindspore_home}/example/convert_to_mindrecord - vi xyz/mr_api.py - ``` - - Two API, 'mindrecord_task_number' and 'mindrecord_dict_data', must be implemented -- 'mindrecord_task_number()' returns number of tasks. Return 1 if data row is generated serially. Return N if generator can be split into N parallel-run tasks. -- 'mindrecord_dict_data(task_id)' yields dictionary data row by row. 'task_id' is 0..N-1, if N is return value of mindrecord_task_number() - - -Tricky for parallel run -- For imagenet, one directory can be a task. -- For TFRecord with multiple files, each file can be a task. -- For TFRecord with 1 file only, it could also be split into N tasks. Task_id=K means: data row is picked only if (count % N == K) - - -## Run data generator -* run python script - ```shell - cd ${your_mindspore_home}/example/convert_to_mindrecord - python writer.py --mindrecord_script imagenet [...] - ``` diff --git a/example/convert_to_mindrecord/imagenet/__init__.py b/example/convert_to_mindrecord/imagenet/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/example/convert_to_mindrecord/imagenet/mr_api.py b/example/convert_to_mindrecord/imagenet/mr_api.py deleted file mode 100644 index e569b489b5..0000000000 --- a/example/convert_to_mindrecord/imagenet/mr_api.py +++ /dev/null @@ -1,122 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -""" -User-defined API for MindRecord writer. -Two API must be implemented, - 1. mindrecord_task_number() - # Return number of parallel tasks. return 1 if no parallel - 2. mindrecord_dict_data(task_id) - # Yield data for one task - # task_id is 0..N-1, if N is return value of mindrecord_task_number() -""" -import argparse -import os -import pickle - -######## mindrecord_schema begin ########## -mindrecord_schema = {"label": {"type": "int64"}, - "data": {"type": "bytes"}, - "file_name": {"type": "string"}} -######## mindrecord_schema end ########## - -######## Frozen code begin ########## -with open('mr_argument.pickle', 'rb') as mindrecord_argument_file_handle: - ARG_LIST = pickle.load(mindrecord_argument_file_handle) -######## Frozen code end ########## - -parser = argparse.ArgumentParser(description='Mind record imagenet example') -parser.add_argument('--label_file', type=str, default="", help='label file') -parser.add_argument('--image_dir', type=str, default="", help='images directory') - -######## Frozen code begin ########## -args = parser.parse_args(ARG_LIST) -print(args) -######## Frozen code end ########## - - -def _user_defined_private_func(): - """ - Internal function for tasks list - - Return: - tasks list - """ - if not os.path.exists(args.label_file): - raise IOError("map file {} not exists".format(args.label_file)) - - label_dict = {} - with open(args.label_file) as file_handle: - line = file_handle.readline() - while line: - labels = line.split(" ") - label_dict[labels[1]] = labels[0] - line = file_handle.readline() - # get all the dir which are n02087046, n02094114, n02109525 - dir_paths = {} - for item in label_dict: - real_path = os.path.join(args.image_dir, label_dict[item]) - if not os.path.isdir(real_path): - print("{} dir is not exist".format(real_path)) - continue - dir_paths[item] = real_path - - if not dir_paths: - print("not valid image dir in {}".format(args.image_dir)) - return {}, {} - - dir_list = [] - for label in dir_paths: - dir_list.append(label) - return dir_list, dir_paths - - -dir_list_global, dir_paths_global = _user_defined_private_func() - -def mindrecord_task_number(): - """ - Get task size. - - Return: - number of tasks - """ - return len(dir_list_global) - - -def mindrecord_dict_data(task_id): - """ - Get data dict. - - Yields: - data (dict): data row which is dict. - """ - - # get the filename, label and image binary as a dict - label = dir_list_global[task_id] - for item in os.listdir(dir_paths_global[label]): - file_name = os.path.join(dir_paths_global[label], item) - if not item.endswith("JPEG") and not item.endswith( - "jpg") and not item.endswith("jpeg"): - print("{} file is not suffix with JPEG/jpg, skip it.".format(file_name)) - continue - data = {} - data["file_name"] = str(file_name) - data["label"] = int(label) - - # get the image data - image_file = open(file_name, "rb") - image_bytes = image_file.read() - image_file.close() - data["data"] = image_bytes - yield data diff --git a/example/convert_to_mindrecord/run_imagenet.sh b/example/convert_to_mindrecord/run_imagenet.sh deleted file mode 100644 index 11f5dcff75..0000000000 --- a/example/convert_to_mindrecord/run_imagenet.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/bash -rm /tmp/imagenet/mr/* - -python writer.py --mindrecord_script imagenet \ ---mindrecord_file "/tmp/imagenet/mr/m" \ ---mindrecord_partitions 16 \ ---label_file "/tmp/imagenet/label.txt" \ ---image_dir "/tmp/imagenet/jpeg" diff --git a/example/convert_to_mindrecord/run_template.sh b/example/convert_to_mindrecord/run_template.sh deleted file mode 100644 index a4c5142c00..0000000000 --- a/example/convert_to_mindrecord/run_template.sh +++ /dev/null @@ -1,6 +0,0 @@ -#!/bin/bash -rm /tmp/template/* - -python writer.py --mindrecord_script template \ ---mindrecord_file "/tmp/template/m" \ ---mindrecord_partitions 4 diff --git a/example/convert_to_mindrecord/template/__init__.py b/example/convert_to_mindrecord/template/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/example/convert_to_mindrecord/template/mr_api.py b/example/convert_to_mindrecord/template/mr_api.py deleted file mode 100644 index 3f7d7dddf0..0000000000 --- a/example/convert_to_mindrecord/template/mr_api.py +++ /dev/null @@ -1,73 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -""" -User-defined API for MindRecord writer. -Two API must be implemented, - 1. mindrecord_task_number() - # Return number of parallel tasks. return 1 if no parallel - 2. mindrecord_dict_data(task_id) - # Yield data for one task - # task_id is 0..N-1, if N is return value of mindrecord_task_number() -""" -import argparse -import pickle - -# ## Parse argument - -with open('mr_argument.pickle', 'rb') as mindrecord_argument_file_handle: # Do NOT change this line - ARG_LIST = pickle.load(mindrecord_argument_file_handle) # Do NOT change this line -parser = argparse.ArgumentParser(description='Mind record api template') # Do NOT change this line - -# ## Your arguments below -# parser.add_argument(...) - -args = parser.parse_args(ARG_LIST) # Do NOT change this line -print(args) # Do NOT change this line - - -# ## Default mindrecord vars. Comment them unless default value has to be changed. -# mindrecord_index_fields = ['label'] -# mindrecord_header_size = 1 << 24 -# mindrecord_page_size = 1 << 25 - - -# define global vars here if necessary - - -# ####### Your code below ########## -mindrecord_schema = {"label": {"type": "int32"}} - -def mindrecord_task_number(): - """ - Get task size. - - Return: - number of tasks - """ - return 1 - - -def mindrecord_dict_data(task_id): - """ - Get data dict. - - Yields: - data (dict): data row which is dict. - """ - print("task is {}".format(task_id)) - for i in range(256): - data = {} - data['label'] = i - yield data diff --git a/example/convert_to_mindrecord/writer.py b/example/convert_to_mindrecord/writer.py deleted file mode 100644 index 0a9ad5c86a..0000000000 --- a/example/convert_to_mindrecord/writer.py +++ /dev/null @@ -1,149 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -""" -######################## write mindrecord example ######################## -Write mindrecord by data dictionary: -python writer.py --mindrecord_script /YourScriptPath ... -""" -import argparse -import os -import pickle -import time -from importlib import import_module -from multiprocessing import Pool - -from mindspore.mindrecord import FileWriter - - -def _exec_task(task_id, parallel_writer=True): - """ - Execute task with specified task id - """ - print("exec task {}, parallel: {} ...".format(task_id, parallel_writer)) - imagenet_iter = mindrecord_dict_data(task_id) - batch_size = 2048 - transform_count = 0 - while True: - data_list = [] - try: - for _ in range(batch_size): - data_list.append(imagenet_iter.__next__()) - transform_count += 1 - writer.write_raw_data(data_list, parallel_writer=parallel_writer) - print("transformed {} record...".format(transform_count)) - except StopIteration: - if data_list: - writer.write_raw_data(data_list, parallel_writer=parallel_writer) - print("transformed {} record...".format(transform_count)) - break - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description='Mind record writer') - parser.add_argument('--mindrecord_script', type=str, default="template", - help='path where script is saved') - - parser.add_argument('--mindrecord_file', type=str, default="/tmp/mindrecord", - help='written file name prefix') - - parser.add_argument('--mindrecord_partitions', type=int, default=1, - help='number of written files') - - parser.add_argument('--mindrecord_workers', type=int, default=8, - help='number of parallel workers') - - args = parser.parse_known_args() - - args, other_args = parser.parse_known_args() - - print(args) - print(other_args) - - with open('mr_argument.pickle', 'wb') as file_handle: - pickle.dump(other_args, file_handle) - - try: - mr_api = import_module(args.mindrecord_script + '.mr_api') - except ModuleNotFoundError: - raise RuntimeError("Unknown module path: {}".format(args.mindrecord_script + '.mr_api')) - - num_tasks = mr_api.mindrecord_task_number() - - print("Write mindrecord ...") - - mindrecord_dict_data = mr_api.mindrecord_dict_data - - # get number of files - writer = FileWriter(args.mindrecord_file, args.mindrecord_partitions) - - start_time = time.time() - - # set the header size - try: - header_size = mr_api.mindrecord_header_size - writer.set_header_size(header_size) - except AttributeError: - print("Default header size: {}".format(1 << 24)) - - # set the page size - try: - page_size = mr_api.mindrecord_page_size - writer.set_page_size(page_size) - except AttributeError: - print("Default page size: {}".format(1 << 25)) - - # get schema - try: - mindrecord_schema = mr_api.mindrecord_schema - except AttributeError: - raise RuntimeError("mindrecord_schema is not defined in mr_api.py.") - - # create the schema - writer.add_schema(mindrecord_schema, "mindrecord_schema") - - # add the index - try: - index_fields = mr_api.mindrecord_index_fields - writer.add_index(index_fields) - except AttributeError: - print("Default index fields: all simple fields are indexes.") - - writer.open_and_set_header() - - task_list = list(range(num_tasks)) - - # set number of workers - num_workers = args.mindrecord_workers - - if num_tasks < 1: - num_tasks = 1 - - if num_workers > num_tasks: - num_workers = num_tasks - - if num_tasks > 1: - with Pool(num_workers) as p: - p.map(_exec_task, task_list) - else: - _exec_task(0, False) - - ret = writer.commit() - - os.remove("{}".format("mr_argument.pickle")) - - end_time = time.time() - print("--------------------------------------------") - print("END. Total time: {}".format(end_time - start_time)) - print("--------------------------------------------") diff --git a/mindspore/ccsrc/mindrecord/common/shard_pybind.cc b/mindspore/ccsrc/mindrecord/common/shard_pybind.cc index 8718e9b871..338a17ac2d 100644 --- a/mindspore/ccsrc/mindrecord/common/shard_pybind.cc +++ b/mindspore/ccsrc/mindrecord/common/shard_pybind.cc @@ -75,9 +75,12 @@ void BindShardWriter(py::module *m) { .def("set_header_size", &ShardWriter::set_header_size) .def("set_page_size", &ShardWriter::set_page_size) .def("set_shard_header", &ShardWriter::SetShardHeader) - .def("write_raw_data", (MSRStatus(ShardWriter::*)(std::map> &, - vector> &, bool, bool)) & - ShardWriter::WriteRawData) + .def("write_raw_data", + (MSRStatus(ShardWriter::*)(std::map> &, vector> &, bool)) & + ShardWriter::WriteRawData) + .def("write_raw_nlp_data", (MSRStatus(ShardWriter::*)(std::map> &, + std::map> &, bool)) & + ShardWriter::WriteRawData) .def("commit", &ShardWriter::Commit); } diff --git a/mindspore/ccsrc/mindrecord/include/shard_header.h b/mindspore/ccsrc/mindrecord/include/shard_header.h index 70cfcdb6b7..ca4d3bd66f 100644 --- a/mindspore/ccsrc/mindrecord/include/shard_header.h +++ b/mindspore/ccsrc/mindrecord/include/shard_header.h @@ -121,10 +121,6 @@ class ShardHeader { std::vector SerializeHeader(); - MSRStatus PagesToFile(const std::string dump_file_name); - - MSRStatus FileToPages(const std::string dump_file_name); - private: MSRStatus InitializeHeader(const std::vector &headers); diff --git a/mindspore/ccsrc/mindrecord/include/shard_writer.h b/mindspore/ccsrc/mindrecord/include/shard_writer.h index 78a434fc97..6a22f07700 100644 --- a/mindspore/ccsrc/mindrecord/include/shard_writer.h +++ b/mindspore/ccsrc/mindrecord/include/shard_writer.h @@ -18,7 +18,6 @@ #define MINDRECORD_INCLUDE_SHARD_WRITER_H_ #include -#include #include #include #include @@ -88,7 +87,7 @@ class ShardWriter { /// \param[in] sign validate data or not /// \return MSRStatus the status of MSRStatus to judge if write successfully MSRStatus WriteRawData(std::map> &raw_data, vector> &blob_data, - bool sign = true, bool parallel_writer = false); + bool sign = true); /// \brief write raw data by group size for call from python /// \param[in] raw_data the vector of raw json data, python-handle format @@ -96,7 +95,7 @@ class ShardWriter { /// \param[in] sign validate data or not /// \return MSRStatus the status of MSRStatus to judge if write successfully MSRStatus WriteRawData(std::map> &raw_data, vector> &blob_data, - bool sign = true, bool parallel_writer = false); + bool sign = true); /// \brief write raw data by group size for call from python /// \param[in] raw_data the vector of raw json data, python-handle format @@ -104,8 +103,7 @@ class ShardWriter { /// \param[in] sign validate data or not /// \return MSRStatus the status of MSRStatus to judge if write successfully MSRStatus WriteRawData(std::map> &raw_data, - std::map> &blob_data, bool sign = true, - bool parallel_writer = false); + std::map> &blob_data, bool sign = true); private: /// \brief write shard header data to disk @@ -203,34 +201,7 @@ class ShardWriter { MSRStatus CheckDataTypeAndValue(const std::string &key, const json &value, const json &data, const int &i, std::map &err_raw_data); - /// \brief Lock writer and save pages info - int LockWriter(bool parallel_writer = false); - - /// \brief Unlock writer and save pages info - MSRStatus UnlockWriter(int fd, bool parallel_writer = false); - - /// \brief Check raw data before writing - MSRStatus WriteRawDataPreCheck(std::map> &raw_data, vector> &blob_data, - bool sign, int *schema_count, int *row_count); - - /// \brief Get full path from file name - MSRStatus GetFullPathFromFileName(const std::vector &paths); - - /// \brief Open files - MSRStatus OpenDataFiles(bool append); - - /// \brief Remove lock file - MSRStatus RemoveLockFile(); - - /// \brief Remove lock file - MSRStatus InitLockFile(); - private: - const std::string kLockFileSuffix = "_Locker"; - const std::string kPageFileSuffix = "_Pages"; - std::string lock_file_; // lock file for parallel run - std::string pages_file_; // temporary file of pages info for parallel run - int shard_count_; // number of files uint64_t header_size_; // header size uint64_t page_size_; // page size @@ -240,7 +211,7 @@ class ShardWriter { std::vector raw_data_size_; // Raw data size std::vector blob_data_size_; // Blob data size - std::vector file_paths_; // file paths + std::vector file_paths_; // file paths std::vector> file_streams_; // file handles std::shared_ptr shard_header_; // shard headers diff --git a/mindspore/ccsrc/mindrecord/io/shard_index_generator.cc b/mindspore/ccsrc/mindrecord/io/shard_index_generator.cc index dc2743cdc7..5a5cd7cbf3 100644 --- a/mindspore/ccsrc/mindrecord/io/shard_index_generator.cc +++ b/mindspore/ccsrc/mindrecord/io/shard_index_generator.cc @@ -520,16 +520,13 @@ MSRStatus ShardIndexGenerator::ExecuteTransaction(const int &shard_no, const std for (int raw_page_id : raw_page_ids) { auto sql = GenerateRawSQL(fields_); if (sql.first != SUCCESS) { - MS_LOG(ERROR) << "Generate raw SQL failed"; return FAILED; } auto data = GenerateRowData(shard_no, blob_id_to_page_id, raw_page_id, in); if (data.first != SUCCESS) { - MS_LOG(ERROR) << "Generate raw data failed"; return FAILED; } if (BindParameterExecuteSQL(db.second, sql.second, data.second) == FAILED) { - MS_LOG(ERROR) << "Execute SQL failed"; return FAILED; } MS_LOG(INFO) << "Insert " << data.second.size() << " rows to index db."; diff --git a/mindspore/ccsrc/mindrecord/io/shard_writer.cc b/mindspore/ccsrc/mindrecord/io/shard_writer.cc index ac95e622c9..864e6697d0 100644 --- a/mindspore/ccsrc/mindrecord/io/shard_writer.cc +++ b/mindspore/ccsrc/mindrecord/io/shard_writer.cc @@ -40,7 +40,17 @@ ShardWriter::~ShardWriter() { } } -MSRStatus ShardWriter::GetFullPathFromFileName(const std::vector &paths) { +MSRStatus ShardWriter::Open(const std::vector &paths, bool append) { + shard_count_ = paths.size(); + if (shard_count_ > kMaxShardCount || shard_count_ == 0) { + MS_LOG(ERROR) << "The Shard Count greater than max value or equal to 0."; + return FAILED; + } + if (schema_count_ > kMaxSchemaCount) { + MS_LOG(ERROR) << "The schema Count greater than max value."; + return FAILED; + } + // Get full path from file name for (const auto &path : paths) { if (!CheckIsValidUtf8(path)) { @@ -50,7 +60,7 @@ MSRStatus ShardWriter::GetFullPathFromFileName(const std::vector &p char resolved_path[PATH_MAX] = {0}; char buf[PATH_MAX] = {0}; if (strncpy_s(buf, PATH_MAX, common::SafeCStr(path), path.length()) != EOK) { - MS_LOG(ERROR) << "Secure func failed"; + MS_LOG(ERROR) << "Securec func failed"; return FAILED; } #if defined(_WIN32) || defined(_WIN64) @@ -72,10 +82,7 @@ MSRStatus ShardWriter::GetFullPathFromFileName(const std::vector &p #endif file_paths_.emplace_back(string(resolved_path)); } - return SUCCESS; -} -MSRStatus ShardWriter::OpenDataFiles(bool append) { // Open files for (const auto &file : file_paths_) { std::shared_ptr fs = std::make_shared(); @@ -109,67 +116,6 @@ MSRStatus ShardWriter::OpenDataFiles(bool append) { return SUCCESS; } -MSRStatus ShardWriter::RemoveLockFile() { - // Remove temporary file - int ret = std::remove(pages_file_.c_str()); - if (ret == 0) { - MS_LOG(DEBUG) << "Remove page file."; - } - - ret = std::remove(lock_file_.c_str()); - if (ret == 0) { - MS_LOG(DEBUG) << "Remove lock file."; - } - return SUCCESS; -} - -MSRStatus ShardWriter::InitLockFile() { - if (file_paths_.size() == 0) { - MS_LOG(ERROR) << "File path not initialized."; - return FAILED; - } - - lock_file_ = file_paths_[0] + kLockFileSuffix; - pages_file_ = file_paths_[0] + kPageFileSuffix; - - if (RemoveLockFile() == FAILED) { - MS_LOG(ERROR) << "Remove file failed."; - return FAILED; - } - return SUCCESS; -} - -MSRStatus ShardWriter::Open(const std::vector &paths, bool append) { - shard_count_ = paths.size(); - if (shard_count_ > kMaxShardCount || shard_count_ == 0) { - MS_LOG(ERROR) << "The Shard Count greater than max value or equal to 0."; - return FAILED; - } - if (schema_count_ > kMaxSchemaCount) { - MS_LOG(ERROR) << "The schema Count greater than max value."; - return FAILED; - } - - // Get full path from file name - if (GetFullPathFromFileName(paths) == FAILED) { - MS_LOG(ERROR) << "Get full path from file name failed."; - return FAILED; - } - - // Open files - if (OpenDataFiles(append) == FAILED) { - MS_LOG(ERROR) << "Open data files failed."; - return FAILED; - } - - // Init lock file - if (InitLockFile() == FAILED) { - MS_LOG(ERROR) << "Init lock file failed."; - return FAILED; - } - return SUCCESS; -} - MSRStatus ShardWriter::OpenForAppend(const std::string &path) { if (!IsLegalFile(path)) { return FAILED; @@ -197,28 +143,11 @@ MSRStatus ShardWriter::OpenForAppend(const std::string &path) { } MSRStatus ShardWriter::Commit() { - // Read pages file - std::ifstream page_file(pages_file_.c_str()); - if (page_file.good()) { - page_file.close(); - if (shard_header_->FileToPages(pages_file_) == FAILED) { - MS_LOG(ERROR) << "Read pages from file failed"; - return FAILED; - } - } - if (WriteShardHeader() == FAILED) { MS_LOG(ERROR) << "Write metadata failed"; return FAILED; } MS_LOG(INFO) << "Write metadata successfully."; - - // Remove lock file - if (RemoveLockFile() == FAILED) { - MS_LOG(ERROR) << "Remove lock file failed."; - return FAILED; - } - return SUCCESS; } @@ -526,65 +455,15 @@ void ShardWriter::FillArray(int start, int end, std::map> } } -int ShardWriter::LockWriter(bool parallel_writer) { - if (!parallel_writer) { - return 0; - } - const int fd = open(lock_file_.c_str(), O_WRONLY | O_CREAT, 0666); - if (fd >= 0) { - flock(fd, LOCK_EX); - } else { - MS_LOG(ERROR) << "Shard writer failed when locking file"; - return -1; - } - - // Open files - file_streams_.clear(); - for (const auto &file : file_paths_) { - std::shared_ptr fs = std::make_shared(); - fs->open(common::SafeCStr(file), std::ios::in | std::ios::out | std::ios::binary); - if (fs->fail()) { - MS_LOG(ERROR) << "File could not opened"; - return -1; - } - file_streams_.push_back(fs); - } - - if (shard_header_->FileToPages(pages_file_) == FAILED) { - MS_LOG(ERROR) << "Read pages from file failed"; - return -1; - } - return fd; -} - -MSRStatus ShardWriter::UnlockWriter(int fd, bool parallel_writer) { - if (!parallel_writer) { - return SUCCESS; - } - - if (shard_header_->PagesToFile(pages_file_) == FAILED) { - MS_LOG(ERROR) << "Write pages to file failed"; - return FAILED; - } - - for (int i = static_cast(file_streams_.size()) - 1; i >= 0; i--) { - file_streams_[i]->close(); - } - - flock(fd, LOCK_UN); - close(fd); - return SUCCESS; -} - -MSRStatus ShardWriter::WriteRawDataPreCheck(std::map> &raw_data, - std::vector> &blob_data, bool sign, int *schema_count, - int *row_count) { +MSRStatus ShardWriter::WriteRawData(std::map> &raw_data, + std::vector> &blob_data, bool sign) { // check the free disk size auto st_space = GetDiskSize(file_paths_[0], kFreeSize); if (st_space.first != SUCCESS || st_space.second < kMinFreeDiskSize) { MS_LOG(ERROR) << "IO error / there is no free disk to be used"; return FAILED; } + // Add 4-bytes dummy blob data if no any blob fields if (blob_data.size() == 0 && raw_data.size() > 0) { blob_data = std::vector>(raw_data[0].size(), std::vector(kUnsignedInt4, 0)); @@ -600,29 +479,10 @@ MSRStatus ShardWriter::WriteRawDataPreCheck(std::map MS_LOG(ERROR) << "Validate raw data failed"; return FAILED; } - *schema_count = std::get<1>(v); - *row_count = std::get<2>(v); - return SUCCESS; -} - -MSRStatus ShardWriter::WriteRawData(std::map> &raw_data, - std::vector> &blob_data, bool sign, bool parallel_writer) { - // Lock Writer if loading data parallel - int fd = LockWriter(parallel_writer); - if (fd < 0) { - MS_LOG(ERROR) << "Lock writer failed"; - return FAILED; - } // Get the count of schemas and rows - int schema_count = 0; - int row_count = 0; - - // Serialize raw data - if (WriteRawDataPreCheck(raw_data, blob_data, sign, &schema_count, &row_count) == FAILED) { - MS_LOG(ERROR) << "Check raw data failed"; - return FAILED; - } + int schema_count = std::get<1>(v); + int row_count = std::get<2>(v); if (row_count == kInt0) { MS_LOG(INFO) << "Raw data size is 0."; @@ -656,17 +516,11 @@ MSRStatus ShardWriter::WriteRawData(std::map> &raw_d } MS_LOG(INFO) << "Write " << bin_raw_data.size() << " records successfully."; - if (UnlockWriter(fd, parallel_writer) == FAILED) { - MS_LOG(ERROR) << "Unlock writer failed"; - return FAILED; - } - return SUCCESS; } MSRStatus ShardWriter::WriteRawData(std::map> &raw_data, - std::map> &blob_data, bool sign, - bool parallel_writer) { + std::map> &blob_data, bool sign) { std::map> raw_data_json; std::map> blob_data_json; @@ -700,11 +554,11 @@ MSRStatus ShardWriter::WriteRawData(std::map> MS_LOG(ERROR) << "Serialize raw data failed in write raw data"; return FAILED; } - return WriteRawData(raw_data_json, bin_blob_data, sign, parallel_writer); + return WriteRawData(raw_data_json, bin_blob_data, sign); } MSRStatus ShardWriter::WriteRawData(std::map> &raw_data, - vector> &blob_data, bool sign, bool parallel_writer) { + vector> &blob_data, bool sign) { std::map> raw_data_json; (void)std::transform(raw_data.begin(), raw_data.end(), std::inserter(raw_data_json, raw_data_json.end()), [](const std::pair> &pair) { @@ -714,7 +568,7 @@ MSRStatus ShardWriter::WriteRawData(std::map> [](const py::handle &obj) { return nlohmann::detail::ToJsonImpl(obj); }); return std::make_pair(pair.first, std::move(json_raw_data)); }); - return WriteRawData(raw_data_json, blob_data, sign, parallel_writer); + return WriteRawData(raw_data_json, blob_data, sign); } MSRStatus ShardWriter::ParallelWriteData(const std::vector> &blob_data, diff --git a/mindspore/ccsrc/mindrecord/meta/shard_header.cc b/mindspore/ccsrc/mindrecord/meta/shard_header.cc index 26008e3ca9..57b2e5fa9e 100644 --- a/mindspore/ccsrc/mindrecord/meta/shard_header.cc +++ b/mindspore/ccsrc/mindrecord/meta/shard_header.cc @@ -677,43 +677,5 @@ std::pair, MSRStatus> ShardHeader::GetStatisticByID( } return std::make_pair(statistics_.at(statistic_id), SUCCESS); } - -MSRStatus ShardHeader::PagesToFile(const std::string dump_file_name) { - // write header content to file, dump whatever is in the file before - std::ofstream page_out_handle(dump_file_name.c_str(), std::ios_base::trunc | std::ios_base::out); - if (page_out_handle.fail()) { - MS_LOG(ERROR) << "Failed in opening page file"; - return FAILED; - } - - auto pages = SerializePage(); - for (const auto &shard_pages : pages) { - page_out_handle << shard_pages << "\n"; - } - - page_out_handle.close(); - return SUCCESS; -} - -MSRStatus ShardHeader::FileToPages(const std::string dump_file_name) { - for (auto &v : pages_) { // clean pages - v.clear(); - } - // attempt to open the file contains the page in json - std::ifstream page_in_handle(dump_file_name.c_str()); - - if (!page_in_handle.good()) { - MS_LOG(INFO) << "No page file exists."; - return SUCCESS; - } - - std::string line; - while (std::getline(page_in_handle, line)) { - ParsePage(json::parse(line)); - } - - page_in_handle.close(); - return SUCCESS; -} } // namespace mindrecord } // namespace mindspore diff --git a/mindspore/mindrecord/filewriter.py b/mindspore/mindrecord/filewriter.py index 62bcc2df79..90bca48038 100644 --- a/mindspore/mindrecord/filewriter.py +++ b/mindspore/mindrecord/filewriter.py @@ -200,24 +200,13 @@ class FileWriter: raw_data.pop(i) logger.warning(v) - def open_and_set_header(self): - """ - Open writer and set header - - """ - if not self._writer.is_open: - self._writer.open(self._paths) - if not self._writer.get_shard_header(): - self._writer.set_shard_header(self._header) - - def write_raw_data(self, raw_data, parallel_writer=False): + def write_raw_data(self, raw_data): """ Write raw data and generate sequential pair of MindRecord File and \ validate data based on predefined schema by default. Args: raw_data (list[dict]): List of raw data. - parallel_writer (bool, optional): Load data parallel if it equals to True (default=False). Raises: ParamTypeError: If index field is invalid. @@ -236,7 +225,7 @@ class FileWriter: if not isinstance(each_raw, dict): raise ParamTypeError('raw_data item', 'dict') self._verify_based_on_schema(raw_data) - return self._writer.write_raw_data(raw_data, True, parallel_writer) + return self._writer.write_raw_data(raw_data, True) def set_header_size(self, header_size): """ diff --git a/mindspore/mindrecord/shardwriter.py b/mindspore/mindrecord/shardwriter.py index 0913201861..0ef23d4ce6 100644 --- a/mindspore/mindrecord/shardwriter.py +++ b/mindspore/mindrecord/shardwriter.py @@ -135,7 +135,7 @@ class ShardWriter: def get_shard_header(self): return self._header - def write_raw_data(self, data, validate=True, parallel_writer=False): + def write_raw_data(self, data, validate=True): """ Write raw data of cv dataset. @@ -145,7 +145,6 @@ class ShardWriter: Args: data (list[dict]): List of raw data. validate (bool, optional): verify data according schema if it equals to True. - parallel_writer (bool, optional): Load data parallel if it equals to True. Returns: MSRStatus, SUCCESS or FAILED. @@ -166,7 +165,7 @@ class ShardWriter: if row_raw: raw_data.append(row_raw) raw_data = {0: raw_data} if raw_data else {} - ret = self._writer.write_raw_data(raw_data, blob_data, validate, parallel_writer) + ret = self._writer.write_raw_data(raw_data, blob_data, validate) if ret != ms.MSRStatus.SUCCESS: logger.error("Failed to write dataset.") raise MRMWriteDatasetError From 9e2ec3b8d8d5b63f798465b65d7764882ad9b722 Mon Sep 17 00:00:00 2001 From: leilei_snow Date: Wed, 22 Apr 2020 18:26:08 +0800 Subject: [PATCH 108/142] check the legal value of weight_decay and loss_scale --- mindspore/nn/optim/optimizer.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/mindspore/nn/optim/optimizer.py b/mindspore/nn/optim/optimizer.py index 719e7aa55e..72593e8001 100755 --- a/mindspore/nn/optim/optimizer.py +++ b/mindspore/nn/optim/optimizer.py @@ -88,14 +88,12 @@ class Optimizer(Cell): if isinstance(weight_decay, int): weight_decay = float(weight_decay) - if not isinstance(weight_decay, float): - raise TypeError("weight_decay should be a float number!") + validator.check_float_legal_value('weight_decay', weight_decay, None) if isinstance(loss_scale, int): loss_scale = float(loss_scale) - if not isinstance(loss_scale, float): - raise TypeError("loss_scale should be a float number!") + validator.check_float_legal_value('loss_scale', loss_scale, None) if loss_scale <= 0.0: raise ValueError("Loss scale should be greater than 0, but got {}".format(loss_scale)) From 763aa1067e58ba146b6543bc88de4ccf76b8f912 Mon Sep 17 00:00:00 2001 From: ms_yan <6576637+ms_yan@user.noreply.gitee.com> Date: Wed, 22 Apr 2020 17:08:17 +0800 Subject: [PATCH 109/142] Add Error catch for giving too many input parameters --- mindspore/dataset/engine/validators.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/mindspore/dataset/engine/validators.py b/mindspore/dataset/engine/validators.py index 29bce25bd1..dabeb2d424 100644 --- a/mindspore/dataset/engine/validators.py +++ b/mindspore/dataset/engine/validators.py @@ -233,8 +233,13 @@ def make_param_dict(method, args, kwargs): params = sig.parameters keys = list(params.keys()) param_dict = dict() - for name, value in enumerate(args): - param_dict[keys[name]] = value + try: + for name, value in enumerate(args): + param_dict[keys[name]] = value + except IndexError: + raise TypeError("{0}() expected {1} arguments, but {2} were given".format( + method.__name__, len(keys) - 1, len(args) - 1)) + param_dict.update(zip(params.keys(), args)) param_dict.update(kwargs) From 933c6e4a04d5b3a161ebf15b2f8c533977f8b9e2 Mon Sep 17 00:00:00 2001 From: jjfeing Date: Wed, 22 Apr 2020 19:40:55 +0800 Subject: [PATCH 110/142] fix buffer output_desc_index --- mindspore/ccsrc/kernel/tbe/tbe_kernel_build.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mindspore/ccsrc/kernel/tbe/tbe_kernel_build.cc b/mindspore/ccsrc/kernel/tbe/tbe_kernel_build.cc index 7a521eb1cd..297c067f54 100644 --- a/mindspore/ccsrc/kernel/tbe/tbe_kernel_build.cc +++ b/mindspore/ccsrc/kernel/tbe/tbe_kernel_build.cc @@ -700,7 +700,7 @@ std::vector TbeKernelBuild::GetDescOutputIndex(const std::vector &o if (!find_reused) { desc_output_index.emplace_back(idx); } else { - desc_output_index.emplace_back(output_used_nums[idx - 1]); + desc_output_index.emplace_back(desc_output_index[idx - 1]); } reused_num += (output_use_num_item - 1); find_reused = true; @@ -717,7 +717,8 @@ bool TbeKernelBuild::GenFusionComputeOutputJson(const mindspore::CNodePtr &cnode std::vector *output_desc_list) { auto output_size = AnfAlgo::GetOutputTensorNum(cnode); if (AnfAlgo::HasNodeAttr(kAttrOutputUsedNum, cnode)) { - auto output_used_nums = AnfAlgo::GetNodeAttr>(cnode, kAttrOutputUsedNum); + // wait anther pr: auto output_used_nums = AnfAlgo::GetNodeAttr>(cnode, kAttrOutputUsedNum); + auto output_used_nums = {SizeToInt(AnfAlgo::GetNodeAttr(cnode, kAttrOutputUsedNum))}; MS_LOG(INFO) << "This node's output has been reused, node name: " << cnode->fullname_with_scope(); if (output_used_nums.size() != output_size) { MS_LOG(INFO) << "Fusion error: output tenor num(" << output_size << ")" From b6e77e5178a00059fbd8d0369d9ad92fa1dc8f4d Mon Sep 17 00:00:00 2001 From: liuxiao Date: Tue, 21 Apr 2020 15:25:35 +0800 Subject: [PATCH 111/142] Add ReluV2/ReluGradV2/ConfusionMulGrad for VM --- mindspore/ccsrc/kernel/tbe/tbe_adapter.cc | 1 + mindspore/ops/_grad/grad_nn_ops.py | 12 ++ mindspore/ops/_op_impl/tbe/__init__.py | 3 + .../ops/_op_impl/tbe/confusion_mul_grad.py | 38 +++++++ mindspore/ops/_op_impl/tbe/relu_grad_v2.py | 40 +++++++ mindspore/ops/_op_impl/tbe/relu_v2.py | 40 +++++++ mindspore/ops/operations/__init__.py | 6 +- mindspore/ops/operations/_grad_ops.py | 21 ++++ mindspore/ops/operations/array_ops.py | 2 +- mindspore/ops/operations/nn_ops.py | 105 ++++++++++++++++++ .../davinci/test_tbe_ops/test_relu_v2_grad.py | 53 +++++++++ tests/ut/python/ops/test_ops.py | 19 ++++ 12 files changed, 337 insertions(+), 3 deletions(-) create mode 100644 mindspore/ops/_op_impl/tbe/confusion_mul_grad.py create mode 100644 mindspore/ops/_op_impl/tbe/relu_grad_v2.py create mode 100644 mindspore/ops/_op_impl/tbe/relu_v2.py create mode 100644 tests/st/ops/davinci/test_tbe_ops/test_relu_v2_grad.py diff --git a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc index 17ac8742f9..44750fab4f 100644 --- a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc +++ b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc @@ -33,6 +33,7 @@ static std::map tbe_func_adapter_map = { {"re_lu6", "relu6"}, {"re_lu6_grad", "relu6_grad"}, {"re_lu", "relu"}, + {"re_luv2", "relu_v2"}, {"tensor_add", "add"}, {"reduce_mean", "reduce_mean_d"}, {"reduce_max", "reduce_max_d"}, diff --git a/mindspore/ops/_grad/grad_nn_ops.py b/mindspore/ops/_grad/grad_nn_ops.py index 887c2a7528..e43d3d5d3a 100755 --- a/mindspore/ops/_grad/grad_nn_ops.py +++ b/mindspore/ops/_grad/grad_nn_ops.py @@ -227,6 +227,18 @@ def get_bprop_relu6(self): return bprop +@bprop_getters.register(P.ReLUV2) +def get_bprop_relu_v2(self): + """Grad definition for `ReLUV2` operation.""" + input_grad = G.ReluGradV2() + + def bprop(x, out, dout): + mask = out[1] + dx = input_grad(dout[0], mask) + return (dx,) + return bprop + + @bprop_getters.register(P.HSwish) def get_bprop_hswish(self): """Grad definition for `HSwish` operation.""" diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py index 18ef92ca6e..8030aac5c6 100644 --- a/mindspore/ops/_op_impl/tbe/__init__.py +++ b/mindspore/ops/_op_impl/tbe/__init__.py @@ -33,6 +33,7 @@ from .cast import _cast_tbe from .conv2d import _conv2d_tbe from .conv2d_backprop_filter import _conv2d_backprop_filter_tbe from .conv2d_backprop_input import _conv2d_backprop_input_tbe +from .confusion_mul_grad import _confusion_mul_grad_tbe from .dropout_do_mask import _dropout_do_mask_tbe from .gelu import _gelu_tbe from .gelu_grad import _gelu_grad_tbe @@ -46,6 +47,8 @@ from .relu import _relu_tbe from .relu_grad import _relu_grad_tbe from .relu6 import _relu6_tbe from .relu6_grad import _relu6_grad_tbe +from .relu_v2 import _relu_v2_tbe +from .relu_grad_v2 import _relu_grad_v2_tbe from .softmax_cross_entropy_with_logits import _softmax_cross_entropy_with_logits_tbe from .sigmoid_cross_entropy_with_logits import _sigmoid_cross_entropy_with_logits_tbe from .sigmoid_cross_entropy_with_logits_grad import _sigmoid_cross_entropy_with_logits_grad_tbe diff --git a/mindspore/ops/_op_impl/tbe/confusion_mul_grad.py b/mindspore/ops/_op_impl/tbe/confusion_mul_grad.py new file mode 100644 index 0000000000..e49d5386f2 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/confusion_mul_grad.py @@ -0,0 +1,38 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""ConfusionMulGrad op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +confusion_mul_grad_op_info = TBERegOp("ConfusionMulGrad") \ + .fusion_type("OPAQUE") \ + .attr("axis", "required", "listInt", "all") \ + .attr("keep_dims", "required", "bool", "all") \ + .input(0, "input0", False, "required", "all") \ + .input(1, "input1", False, "required", "all") \ + .input(2, "input2", False, "required", "all") \ + .output(0, "output0", False, "required", "all") \ + .output(1, "output1", False, "required", "all") \ + .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, + DataType.F16_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, + DataType.F32_Default, DataType.F32_Default) \ + .get_op_info() + + +@op_info_register(confusion_mul_grad_op_info) +def _confusion_mul_grad_tbe(): + """ConfusionMulGrad TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/relu_grad_v2.py b/mindspore/ops/_op_impl/tbe/relu_grad_v2.py new file mode 100644 index 0000000000..93d7dede62 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/relu_grad_v2.py @@ -0,0 +1,40 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""ReluGradV2 op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +relu_grad_v2_op_info = TBERegOp("ReluGradV2") \ + .fusion_type("ELEMWISE") \ + .async_flag(False) \ + .binfile_name("relu_grad_v2.so") \ + .compute_cost(10) \ + .kernel_name("relu_grad_v2") \ + .partial_flag(True) \ + .input(0, "gradients", False, "required", "all") \ + .input(1, "mask", False, "rerequired", "all") \ + .output(0, "backprops", True, "required", "all") \ + .dtype_format(DataType.F16_5HD, DataType.U8_Default, DataType.F16_5HD) \ + .dtype_format(DataType.F32_5HD, DataType.U8_Default, DataType.F32_5HD) \ + .dtype_format(DataType.I32_5HD, DataType.U8_Default, DataType.I32_5HD) \ + .dtype_format(DataType.I8_5HD, DataType.U8_Default, DataType.I8_5HD) \ + .dtype_format(DataType.U8_5HD, DataType.U8_Default, DataType.U8_5HD) \ + .get_op_info() + + +@op_info_register(relu_grad_v2_op_info) +def _relu_grad_v2_tbe(): + """ReluGradV2 TBE register""" + return diff --git a/mindspore/ops/_op_impl/tbe/relu_v2.py b/mindspore/ops/_op_impl/tbe/relu_v2.py new file mode 100644 index 0000000000..c03858c1a7 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/relu_v2.py @@ -0,0 +1,40 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""ReluV2 op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +relu_v2_op_info = TBERegOp("ReLUV2") \ + .fusion_type("ELEMWISE") \ + .async_flag(False) \ + .binfile_name("relu_v2.so") \ + .compute_cost(10) \ + .kernel_name("relu_v2") \ + .partial_flag(True) \ + .input(0, "x", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .output(1, "mask", False, "required", "all") \ + .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.U8_Default) \ + .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.U8_Default) \ + .dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.U8_Default) \ + .dtype_format(DataType.I8_5HD, DataType.I8_5HD, DataType.U8_Default) \ + .dtype_format(DataType.U8_5HD, DataType.U8_5HD, DataType.U8_Default) \ + .get_op_info() + + +@op_info_register(relu_v2_op_info) +def _relu_v2_tbe(): + """ReluV2 TBE register""" + return diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index 80b03a04e1..c75c2031d7 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -58,8 +58,8 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm, GetNext, L2Normalize, LayerNorm, L2Loss, LogSoftmax, MaxPool, ExtractImagePatches, - AvgPool, Conv2DBackpropInput, - MaxPoolWithArgmax, OneHot, Pad, MirrorPad, PReLU, ReLU, ReLU6, HSwish, HSigmoid, + AvgPool, Conv2DBackpropInput, ConfusionMulGrad, + MaxPoolWithArgmax, OneHot, Pad, MirrorPad, PReLU, ReLU, ReLU6, ReLUV2, HSwish, HSigmoid, ResizeBilinear, Sigmoid, SigmoidCrossEntropyWithLogits, SmoothL1Loss, Softmax, @@ -101,6 +101,7 @@ __all__ = [ 'LogSoftmax', 'SoftmaxCrossEntropyWithLogits', 'ROIAlign', + 'ConfusionMulGrad', 'SparseSoftmaxCrossEntropyWithLogits', 'SGD', 'ApplyMomentum', @@ -138,6 +139,7 @@ __all__ = [ 'Split', 'ReLU', 'ReLU6', + 'ReLUV2', 'Elu', 'Erf', 'Sigmoid', diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py index 9670ddd86c..c29832dcb7 100644 --- a/mindspore/ops/operations/_grad_ops.py +++ b/mindspore/ops/operations/_grad_ops.py @@ -730,6 +730,27 @@ class ReLU6Grad(PrimitiveWithInfer): return x_dtype +class ReluGradV2(PrimitiveWithInfer): + """Performs grad of ReLUV2 operation.""" + + @prim_attr_register + def __init__(self): + self.init_prim_io_names(inputs=['gradients', 'mask'], outputs=['output']) + + def __call__(self, gradients, mask): + raise NotImplementedError + + def infer_shape(self, gradients_shape, mask_shape): + return gradients_shape + + def infer_dtype(self, gradients_dtype, mask_dtype): + args_type = {'gradients': gradients_dtype, 'mask': mask_dtype} + validator.check_args_tensor(args_type) + validator.check_typename("gradients_dtype", gradients_dtype, mstype.number_type) + validator.check_typename("mask_dtype", mask_dtype, (mstype.uint8,)) + return gradients_dtype + + class EluGrad(PrimitiveWithInfer): """Performs grad of Elu operation.""" diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index 2e03676a4a..3b32463c36 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -1329,7 +1329,7 @@ class Concat(PrimitiveWithInfer): def _get_pack_shape(x_shape, x_type, axis): """for pack output shape""" - validator.check_type("shape", x_shape, [tuple]) + validator.check_type("shape", x_shape, [tuple, list]) validator.check_integer("len of input_x shape", len(x_shape), 0, Rel.GT) validator.check_subclass("shape0", x_type[0], mstype.tensor) validator.check_integer("len of input_x0 shape", len(x_shape[0]), 0, Rel.GT) diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index f5f495364b..ba2b6f62fd 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -28,6 +28,7 @@ from ..._checkparam import Validator as validator from ..._checkparam import Rel from ...common import dtype as mstype from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register +from ..operations.math_ops import _infer_shape_reduce def _check_positive_int_or_tuple(arg_name, arg_value, prim_name, allow_four=False, ret_four=False): @@ -233,6 +234,62 @@ class ReLU6(PrimitiveWithInfer): return input_x +class ReLUV2(PrimitiveWithInfer): + r""" + Computes ReLU(Rectified Linear Unit) of input tensor element-wise. + + It returns :math:`\max(x,\ 0)` element-wise. + + Inputs: + - **input_x** (Tensor) - The input tensor should be a 4-D tensor. + + Outputs: + - **output** (Tensor) - Has the same type and shape as the `input_x`. + - **mask** (Tensor) - A tensor whose data type must be uint8. + + Examples: + >>> input_x = Tensor(np.array([[[[1, -2], [-3, 4]], [[-5, 6], [7, -8]]]]), mindspore.float32) + >>> relu_v2 = P.ReLUV2() + >>> output = relu_v2(input_x) + ([[[[1., 0.], [0., 4.]], [[0., 6.], [7., 0.]]]], + [[[[1, 0], [2, 0]], [[2, 0], [1, 0]]]]) + """ + @prim_attr_register + def __init__(self): + """init ReLUV2""" + self.init_prim_io_names(inputs=['x'], outputs=['output', 'mask']) + + def __infer__(self, input_x): + input_shape = list(input_x['shape']) + input_dtype = input_x['dtype'] + mask_shape = [] + if len(input_shape) != 4: + raise ValueError("The `input_x` should be a 4-D tensor, " + f"but got a {len(input_shape)}-D tensor whose shape is {input_shape}") + for i in enumerate(input_shape): + if i[0] == 1: + if input_dtype == mstype.uint8 and input_dtype == mstype.int8: + mask_shape.append((input_shape[1] + 31) // 32) + else: + mask_shape.append((input_shape[1] + 15) // 16) + else: + mask_shape.append(i[1]) + if input_dtype == mstype.uint8 and input_dtype == mstype.int8: + mask_shape.append(4) + else: + mask_shape.append(2) + + output_shape = (input_x['shape'], mask_shape) + validator.check_subclass("input_x", input_dtype, mstype.tensor, self.name) + validator.check_tensor_type_same({'input_x': input_dtype}, mstype.number_type, self.name) + mask_dtype = mstype.uint8 + output_dtype = (input_dtype, mask_dtype) + + return {'shape': output_shape, + 'dtype': output_dtype, + 'value': None} + + class Elu(PrimitiveWithInfer): r""" Computes exponential linear: `alpha * (exp(x) - 1)` if x < 0, `x` otherwise. @@ -2580,3 +2637,51 @@ class ExtractImagePatches(PrimitiveWithInfer): def infer_dtype(self, input_x): validator.check_tensor_type_same({"input_x": input_x}, (mstype.int8, mstype.float16, mstype.float32), self.name) return input_x + + +class ConfusionMulGrad(PrimitiveWithInfer): + """ + `output0` is the result of which input0 dot multily input1. + + `output1` is the result of which input0 dot multily input1, then reducesum it. + + Args: + axis (Union[int, tuple[int], list[int]]): The dimensions to reduce. + Default:(), reduce all dimensions. Only constant value is allowed. + keep_dims (bool): + - If true, keep these reduced dimensions and the length is 1. + - If false, don't keep these dimensions. Default:False. + + Inputs: + - **input_0** (Tensor) - The input Tensor. + - **input_1** (Tensor) - The input Tensor. + - **input_2** (Tensor) - The input Tensor. + + outputs: + - **output_0** (Tensor) - The same shape with `input0`. + - **output_1** (Tensor) + + - If axis is (), and keep_dims is false, the output is a 0-D array representing + the sum of all elements in the input array. + - If axis is int, set as 2, and keep_dims is false, + the shape of output is :math:`(x_1,x_3,...,x_R)`. + - If axis is tuple(int), set as (2,3), and keep_dims is false, + the shape of output is :math:`(x_1,x_4,...x_R)`. + """ + + @prim_attr_register + def __init__(self, axis = (), keep_dims = False): + self.init_prim_io_names(inputs = ["input0", "input1", "input2"], outputs = ["output0", "output1"]) + self.axis_ = validator.check_value_type("axis", axis, [int, tuple, list], self.name) + self.keep_dims_ = validator.check_value_type("keep_dims", keep_dims, [bool], self.name) + + def infer_shape(self, input0_shape, input1_shape, input2_shape): + outshape0 = input0_shape + outshape1 = _infer_shape_reduce(input1_shape, self.axis_, self.keep_dims_, self.name) + return outshape0, outshape1 + + def infer_dtype(self, input0_dtype, input1_dtype, input2_dtype): + validator.check_subclass("input0_dtype", input0_dtype, mstype.tensor, self.name) + validator.check_subclass("input1_dtype", input1_dtype, mstype.tensor, self.name) + validator.check_subclass("input2_dtype", input2_dtype, mstype.tensor, self.name) + return input0_dtype, input1_dtype diff --git a/tests/st/ops/davinci/test_tbe_ops/test_relu_v2_grad.py b/tests/st/ops/davinci/test_tbe_ops/test_relu_v2_grad.py new file mode 100644 index 0000000000..28bf566c2d --- /dev/null +++ b/tests/st/ops/davinci/test_tbe_ops/test_relu_v2_grad.py @@ -0,0 +1,53 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +from mindspore import Tensor +from mindspore.ops import operations as P +import mindspore.nn as nn +from mindspore.common.api import ms_function +import numpy as np +import mindspore.context as context +from mindspore.common.initializer import initializer +from mindspore.common.parameter import Parameter +from mindspore.ops.composite import GradOperation +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + +class Grad(nn.Cell): + def __init__(self, network): + super(Grad, self).__init__() + self.grad = GradOperation(name="get_all", get_all=True) + self.network = network + + @ms_function + def construct(self, input): + return self.grad(self.network)(input) + +class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.relu_v2 = P.ReLUV2() + + def construct(self, x): + return self.relu_v2(x) + +def test_net(): + x = Tensor(np.ones((2,3,3,4)).astype(np.float32)) + relu_net = Net() + relu_output = relu_net(x) + net = Grad(Net()) + output_grad = net(x) + print(relu_output[0].asnumpy()) + print(relu_output[1].asnumpy()) + print(len(output_grad)) + print(output_grad[0].asnumpy()) diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index 442c8bdec6..8b14ea2366 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -582,6 +582,10 @@ test_case_nn_ops = [ 'block': P.ReLU6(), 'desc_inputs': [[1, 3, 4, 4]], 'desc_bprop': [[1, 3, 4, 4]]}), + ('ReLUV2', { + 'block': P.ReLUV2(), + 'desc_inputs': [[1, 3, 4, 4]], + 'desc_bprop': [[1, 3, 4, 4], [1, 3, 4, 4]]}), ('ReLUGrad', { 'block': G.ReluGrad(), 'desc_inputs': [[1, 3, 4, 4], [1, 3, 4, 4]], @@ -1134,6 +1138,21 @@ test_case_other_ops = [ 'desc_inputs': [Tensor(np.array([1.1]).astype(np.float32)), Tensor(np.array([1.2]).astype(np.float32))], 'skip': ['backward']}), + ('ConfusionMulGrad_1', { + 'block': P.ConfusionMulGrad(axis = [0], keep_dims = False), + 'desc_inputs': [[3, 2], [3, 2], [3, 2]], + 'desc_bprop': [[3, 2], [2]], + 'skip': ['backward']}), + ('ConfusionMulGrad_2', { + 'block': P.ConfusionMulGrad(axis = [0], keep_dims = True), + 'desc_inputs': [[3, 2], [3, 2], [3, 2]], + 'desc_bprop': [[3, 2], [1, 2]], + 'skip': ['backward']}), + ('ConfusionMulGrad_3', { + 'block': P.ConfusionMulGrad(axis = (), keep_dims = True), + 'desc_inputs': [[2, 3, 4], [2, 3, 4], [2, 3, 4]], + 'desc_bprop': [[2, 3, 4], [1, 1, 1]], + 'skip': ['backward']}), ('HistogramSummary', { 'block': HistogramSummaryNet(), 'desc_inputs': [Tensor(np.array([1.1]).astype(np.float32)), From abd4239da9ae554a0d536ec5f281209bbde48446 Mon Sep 17 00:00:00 2001 From: xiefangqi Date: Wed, 22 Apr 2020 20:50:37 +0800 Subject: [PATCH 112/142] fix random stuck problem --- mindspore/ccsrc/dataset/CMakeLists.txt | 3 +++ .../ccsrc/dataset/engine/datasetops/map_op.cc | 22 ------------------- .../ccsrc/dataset/engine/datasetops/map_op.h | 4 ---- .../dataset/engine/datasetops/shuffle_op.cc | 7 +++++- mindspore/ccsrc/dataset/util/random.cc | 7 +++++- mindspore/ccsrc/dataset/util/services.cc | 6 ++++- 6 files changed, 20 insertions(+), 29 deletions(-) diff --git a/mindspore/ccsrc/dataset/CMakeLists.txt b/mindspore/ccsrc/dataset/CMakeLists.txt index 879a9346bc..8e9b2664dc 100644 --- a/mindspore/ccsrc/dataset/CMakeLists.txt +++ b/mindspore/ccsrc/dataset/CMakeLists.txt @@ -13,6 +13,9 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-format") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-attributes") ############################# Options ################################ +if (${CMAKE_SYSTEM_NAME} MATCHES "Windows") + add_definitions(-D _CRT_RAND_S) +endif () if (ENABLE_GPUQUE) add_definitions(-D ENABLE_GPUQUE) message(STATUS "GPU queue is enabled") diff --git a/mindspore/ccsrc/dataset/engine/datasetops/map_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/map_op.cc index 3f8d70b606..b6d603bac9 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/map_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/map_op.cc @@ -65,9 +65,6 @@ MapOp::MapOp(const std::vector &in_col_names, const std::vectorGetNextBuffer(&buff, 0)); is_eof = buff->eof(); RETURN_IF_NOT_OK(local_queues_[que_id]->Add(std::move(buff))); -#if defined(_WIN32) || defined(_WIN64) - if (is_eof) { - eof_worker_id_ = que_id; - for (int32_t id = 0; id < num_workers_; id++) { - if (id != eof_worker_id_) { - auto eof_buffer = std::make_unique(0, DataBuffer::kDeBFlagEOF); - RETURN_IF_NOT_OK(local_queues_[id]->Add(std::move(eof_buffer))); - } - } - } -#endif que_id = (que_id + 1) % num_workers_; } } @@ -173,14 +159,6 @@ Status MapOp::WorkerEntry(int32_t worker_id) { continue; } else if (in_buffer->eof()) { // Calling base class EofReceived to forward eof buffer. -#if defined(_WIN32) || defined(_Win64) - if (perf_mode_) { - if (eof_worker_id_ == worker_id) { - RETURN_IF_NOT_OK(EofReceived(worker_id)); - } - break; - } -#endif RETURN_IF_NOT_OK(EofReceived(worker_id)); break; } diff --git a/mindspore/ccsrc/dataset/engine/datasetops/map_op.h b/mindspore/ccsrc/dataset/engine/datasetops/map_op.h index 5e16bc3fed..4c9d27f9c7 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/map_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/map_op.h @@ -193,10 +193,6 @@ class MapOp : public ParallelOp { // cause additional blocking because pop calls to Connector from the threads are synchronized to enforce the order. bool perf_mode_; -#if defined(_WIN32) || defined(_WIN64) - // EOF worker id is only work on Performance mode, to record the worker id of queue which gets EOF - int32_t eof_worker_id_; -#endif // Private function for worker/thread to loop continuously. It comprises the main // logic of MapOp: getting the data from previous Op, validating user specified column names, // applying a list of TensorOps to each of the data, process the results and then diff --git a/mindspore/ccsrc/dataset/engine/datasetops/shuffle_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/shuffle_op.cc index bdf39b6a39..422c38f2f2 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/shuffle_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/shuffle_op.cc @@ -13,6 +13,9 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#if defined(_WIN32) || defined(_WIN64) +#include +#endif #include #include #include @@ -86,7 +89,9 @@ Status ShuffleOp::SelfReset() { rng_ = std::mt19937_64(shuffle_seed_); } else { #if defined(_WIN32) || defined(_WIN64) - std::random_device random_device; + unsigned int number; + rand_s(&number); + std::mt19937 random_device{static_cast(number)}; #else std::random_device random_device("/dev/urandom"); #endif diff --git a/mindspore/ccsrc/dataset/util/random.cc b/mindspore/ccsrc/dataset/util/random.cc index 2a0762c920..43b3ee4afd 100644 --- a/mindspore/ccsrc/dataset/util/random.cc +++ b/mindspore/ccsrc/dataset/util/random.cc @@ -18,6 +18,9 @@ #include "dataset/util/random.h" +#if defined(_WIN32) || defined(_WIn64) +#include +#endif #include #include #include @@ -33,7 +36,9 @@ uint32_t GetSeed() { uint32_t seed = GlobalContext::config_manager()->seed(); if (seed == std::mt19937::default_seed) { #if defined(_WIN32) || defined(_WIN64) - std::random_device random_device; + unsigned int number; + rand_s(&number); + std::mt19937 random_device{static_cast(number)}; #else std::random_device random_device("/dev/urandom"); #endif diff --git a/mindspore/ccsrc/dataset/util/services.cc b/mindspore/ccsrc/dataset/util/services.cc index ea7b11014c..a2b3f734c2 100644 --- a/mindspore/ccsrc/dataset/util/services.cc +++ b/mindspore/ccsrc/dataset/util/services.cc @@ -18,6 +18,8 @@ #include #if !defined(_WIN32) && !defined(_WIN64) #include +#else +#include #endif #include #include @@ -49,7 +51,9 @@ int Services::GetLWP() { return syscall(SYS_gettid); } std::string Services::GetUniqueID() { const std::string kStr = "abcdefghijklmnopqrstuvwxyz0123456789"; #if defined(_WIN32) || defined(_WIN64) - std::mt19937 gen{std::random_device{}()}; + unsigned int number; + rand_s(&number); + std::mt19937 gen{static_cast(number)}; #else std::mt19937 gen{std::random_device{"/dev/urandom"}()}; #endif From e1e11b841d2e7282c773b68c5d2766487b49155d Mon Sep 17 00:00:00 2001 From: buxue Date: Wed, 22 Apr 2020 16:44:13 +0800 Subject: [PATCH 113/142] fix codedex --- mindspore/ccsrc/pipeline/pipeline.cc | 2 +- mindspore/ccsrc/pipeline/pipeline_ge.cc | 2 +- mindspore/ccsrc/utils/contract.h | 2 ++ mindspore/ccsrc/utils/profile.cc | 28 +++++++++++++++---------- 4 files changed, 21 insertions(+), 13 deletions(-) diff --git a/mindspore/ccsrc/pipeline/pipeline.cc b/mindspore/ccsrc/pipeline/pipeline.cc index 5b5cae4044..fca105d13c 100644 --- a/mindspore/ccsrc/pipeline/pipeline.cc +++ b/mindspore/ccsrc/pipeline/pipeline.cc @@ -584,7 +584,7 @@ void ExecutorPy::ProcessVmArg(const py::tuple &args, const std::string &phase, V if (ms_context->backend_policy() == kMsConvert && py::isinstance(arg)) { MS_LOG(EXCEPTION) << "Args[" << i << "] is numpy array, not tensor"; } - (*arg_list).push_back(arg); + arg_list->push_back(arg); } ResourcePtr res = GetResource(phase); diff --git a/mindspore/ccsrc/pipeline/pipeline_ge.cc b/mindspore/ccsrc/pipeline/pipeline_ge.cc index e3b10b73b0..1da85b5699 100644 --- a/mindspore/ccsrc/pipeline/pipeline_ge.cc +++ b/mindspore/ccsrc/pipeline/pipeline_ge.cc @@ -462,7 +462,7 @@ void ProcessGeArg(const std::map &info, const py:: MS_LOG(EXCEPTION) << "Args convert error"; } if (converted->isa()) { - (*inputs).push_back(converted->cast()); + inputs->push_back(converted->cast()); } else { MS_LOG(EXCEPTION) << "Args " << converted->ToString() << " is not tensor"; } diff --git a/mindspore/ccsrc/utils/contract.h b/mindspore/ccsrc/utils/contract.h index fc257b3e24..6ef9928241 100644 --- a/mindspore/ccsrc/utils/contract.h +++ b/mindspore/ccsrc/utils/contract.h @@ -28,6 +28,7 @@ class ContractError : public std::logic_error { public: explicit ContractError(const std::string &msg) : std::logic_error(msg) {} explicit ContractError(const char *msg) : std::logic_error(msg) {} + ~ContractError() override = default; }; struct Signatory { @@ -60,6 +61,7 @@ class Ensures : public EnsuresAccess { } template >> Ensures(const Ensures &other) : value_(other.get()) {} + ~Ensures() = default; T get() const { return value_; } T &get() { return value_; } diff --git a/mindspore/ccsrc/utils/profile.cc b/mindspore/ccsrc/utils/profile.cc index 997cc1b56d..e9e7920e0c 100644 --- a/mindspore/ccsrc/utils/profile.cc +++ b/mindspore/ccsrc/utils/profile.cc @@ -38,26 +38,32 @@ void PrintProfile(std::ostringstream &oss, const TimeInfo &time_info, int indent void PrintTimeInfoMap(std::ostringstream &oss, const TimeInfoMap &dict, int indent = 0, std::map *sums = nullptr, const std::string &prefix = "") { - for (auto iter = dict.begin(); iter != dict.end(); ++iter) { - if (iter->second == nullptr) { + size_t count = 0; + for (const auto &iter : dict) { + count++; + if (iter.second == nullptr) { continue; } // indent by multiples of 4 spaces. - auto name = iter->first.substr(TIME_INFO_PREFIX_NUM_LEN); + if (iter.first.size() < TIME_INFO_PREFIX_NUM_LEN) { + MS_LOG(EXCEPTION) << "In TimeInfoMap, the " << count << "th string key is " << iter.first + << ", but the length is less than " << TIME_INFO_PREFIX_NUM_LEN; + } + auto name = iter.first.substr(TIME_INFO_PREFIX_NUM_LEN); oss << std::setw(indent * 4) << "" - << "[" << name << "]: " << iter->second->time_; - if (iter->second->dict_ != nullptr) { - oss << ", [" << iter->second->dict_->size() << "]"; + << "[" << name << "]: " << iter.second->time_; + if (iter.second->dict_ != nullptr) { + oss << ", [" << iter.second->dict_->size() << "]"; } oss << "\n"; std::string newPrefix = prefix; - if (iter->first.find("Cycle ") == std::string::npos) { - newPrefix = prefix.empty() ? iter->first : prefix + "." + iter->first; + if (iter.first.find("Cycle ") == std::string::npos) { + newPrefix = prefix.empty() ? iter.first : prefix + "." + iter.first; } - PrintProfile(oss, *iter->second, indent + 1, sums, newPrefix); - if (iter->second->dict_ == nullptr) { - (*sums)[newPrefix] += iter->second->time_; + PrintProfile(oss, *iter.second, indent + 1, sums, newPrefix); + if (iter.second->dict_ == nullptr) { + (*sums)[newPrefix] += iter.second->time_; } } } From 8c3d2a0c7c1305413696c5d6e344f8d150ebe6a2 Mon Sep 17 00:00:00 2001 From: biffex Date: Wed, 22 Apr 2020 21:20:32 +0800 Subject: [PATCH 114/142] fix prim hash function --- mindspore/ccsrc/ir/primitive.h | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/mindspore/ccsrc/ir/primitive.h b/mindspore/ccsrc/ir/primitive.h index d16a524f69..73941c1058 100644 --- a/mindspore/ccsrc/ir/primitive.h +++ b/mindspore/ccsrc/ir/primitive.h @@ -152,10 +152,7 @@ struct PrimitiveEqual { }; struct PrimitiveHasher { - std::size_t operator()(PrimitivePtr const &prim) const { - std::size_t hash = std::hash()(prim->name()); - return hash; - } + std::size_t operator()(PrimitivePtr const &prim) const { return prim->Hash(); } }; } // namespace mindspore #endif // MINDSPORE_CCSRC_IR_PRIMITIVE_H_ From 353b3b7944b877f3a35d43de8e48001ba6d0cf17 Mon Sep 17 00:00:00 2001 From: jonwe Date: Wed, 22 Apr 2020 11:09:28 -0400 Subject: [PATCH 115/142] optimize mindrecord writer performance --- example/convert_to_mindrecord/README.md | 46 ++++ .../imagenet/__init__.py | 0 .../convert_to_mindrecord/imagenet/mr_api.py | 122 +++++++++++ example/convert_to_mindrecord/run_imagenet.sh | 8 + example/convert_to_mindrecord/run_template.sh | 6 + .../template/__init__.py | 0 .../convert_to_mindrecord/template/mr_api.py | 73 +++++++ example/convert_to_mindrecord/writer.py | 152 ++++++++++++++ .../ccsrc/mindrecord/common/shard_pybind.cc | 9 +- .../ccsrc/mindrecord/include/shard_header.h | 4 + .../ccsrc/mindrecord/include/shard_writer.h | 37 +++- .../mindrecord/io/shard_index_generator.cc | 3 + mindspore/ccsrc/mindrecord/io/shard_writer.cc | 198 ++++++++++++++++-- .../ccsrc/mindrecord/meta/shard_header.cc | 38 ++++ mindspore/mindrecord/filewriter.py | 15 +- mindspore/mindrecord/shardwriter.py | 5 +- 16 files changed, 681 insertions(+), 35 deletions(-) create mode 100644 example/convert_to_mindrecord/README.md create mode 100644 example/convert_to_mindrecord/imagenet/__init__.py create mode 100644 example/convert_to_mindrecord/imagenet/mr_api.py create mode 100644 example/convert_to_mindrecord/run_imagenet.sh create mode 100644 example/convert_to_mindrecord/run_template.sh create mode 100644 example/convert_to_mindrecord/template/__init__.py create mode 100644 example/convert_to_mindrecord/template/mr_api.py create mode 100644 example/convert_to_mindrecord/writer.py diff --git a/example/convert_to_mindrecord/README.md b/example/convert_to_mindrecord/README.md new file mode 100644 index 0000000000..8d3b25e311 --- /dev/null +++ b/example/convert_to_mindrecord/README.md @@ -0,0 +1,46 @@ +# MindRecord generating guidelines + + + +- [MindRecord generating guidelines](#mindrecord-generating-guidelines) + - [Create work space](#create-work-space) + - [Implement data generator](#implement-data-generator) + - [Run data generator](#run-data-generator) + + + +## Create work space + +Assume the dataset name is 'xyz' +* Create work space from template + ```shell + cd ${your_mindspore_home}/example/convert_to_mindrecord + cp -r template xyz + ``` + +## Implement data generator + +Edit dictionary data generator +* Edit file + ```shell + cd ${your_mindspore_home}/example/convert_to_mindrecord + vi xyz/mr_api.py + ``` + + Two API, 'mindrecord_task_number' and 'mindrecord_dict_data', must be implemented +- 'mindrecord_task_number()' returns number of tasks. Return 1 if data row is generated serially. Return N if generator can be split into N parallel-run tasks. +- 'mindrecord_dict_data(task_id)' yields dictionary data row by row. 'task_id' is 0..N-1, if N is return value of mindrecord_task_number() + + +Tricky for parallel run +- For imagenet, one directory can be a task. +- For TFRecord with multiple files, each file can be a task. +- For TFRecord with 1 file only, it could also be split into N tasks. Task_id=K means: data row is picked only if (count % N == K) + + +## Run data generator +* run python script + ```shell + cd ${your_mindspore_home}/example/convert_to_mindrecord + python writer.py --mindrecord_script imagenet [...] + ``` diff --git a/example/convert_to_mindrecord/imagenet/__init__.py b/example/convert_to_mindrecord/imagenet/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/example/convert_to_mindrecord/imagenet/mr_api.py b/example/convert_to_mindrecord/imagenet/mr_api.py new file mode 100644 index 0000000000..e569b489b5 --- /dev/null +++ b/example/convert_to_mindrecord/imagenet/mr_api.py @@ -0,0 +1,122 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +User-defined API for MindRecord writer. +Two API must be implemented, + 1. mindrecord_task_number() + # Return number of parallel tasks. return 1 if no parallel + 2. mindrecord_dict_data(task_id) + # Yield data for one task + # task_id is 0..N-1, if N is return value of mindrecord_task_number() +""" +import argparse +import os +import pickle + +######## mindrecord_schema begin ########## +mindrecord_schema = {"label": {"type": "int64"}, + "data": {"type": "bytes"}, + "file_name": {"type": "string"}} +######## mindrecord_schema end ########## + +######## Frozen code begin ########## +with open('mr_argument.pickle', 'rb') as mindrecord_argument_file_handle: + ARG_LIST = pickle.load(mindrecord_argument_file_handle) +######## Frozen code end ########## + +parser = argparse.ArgumentParser(description='Mind record imagenet example') +parser.add_argument('--label_file', type=str, default="", help='label file') +parser.add_argument('--image_dir', type=str, default="", help='images directory') + +######## Frozen code begin ########## +args = parser.parse_args(ARG_LIST) +print(args) +######## Frozen code end ########## + + +def _user_defined_private_func(): + """ + Internal function for tasks list + + Return: + tasks list + """ + if not os.path.exists(args.label_file): + raise IOError("map file {} not exists".format(args.label_file)) + + label_dict = {} + with open(args.label_file) as file_handle: + line = file_handle.readline() + while line: + labels = line.split(" ") + label_dict[labels[1]] = labels[0] + line = file_handle.readline() + # get all the dir which are n02087046, n02094114, n02109525 + dir_paths = {} + for item in label_dict: + real_path = os.path.join(args.image_dir, label_dict[item]) + if not os.path.isdir(real_path): + print("{} dir is not exist".format(real_path)) + continue + dir_paths[item] = real_path + + if not dir_paths: + print("not valid image dir in {}".format(args.image_dir)) + return {}, {} + + dir_list = [] + for label in dir_paths: + dir_list.append(label) + return dir_list, dir_paths + + +dir_list_global, dir_paths_global = _user_defined_private_func() + +def mindrecord_task_number(): + """ + Get task size. + + Return: + number of tasks + """ + return len(dir_list_global) + + +def mindrecord_dict_data(task_id): + """ + Get data dict. + + Yields: + data (dict): data row which is dict. + """ + + # get the filename, label and image binary as a dict + label = dir_list_global[task_id] + for item in os.listdir(dir_paths_global[label]): + file_name = os.path.join(dir_paths_global[label], item) + if not item.endswith("JPEG") and not item.endswith( + "jpg") and not item.endswith("jpeg"): + print("{} file is not suffix with JPEG/jpg, skip it.".format(file_name)) + continue + data = {} + data["file_name"] = str(file_name) + data["label"] = int(label) + + # get the image data + image_file = open(file_name, "rb") + image_bytes = image_file.read() + image_file.close() + data["data"] = image_bytes + yield data diff --git a/example/convert_to_mindrecord/run_imagenet.sh b/example/convert_to_mindrecord/run_imagenet.sh new file mode 100644 index 0000000000..11f5dcff75 --- /dev/null +++ b/example/convert_to_mindrecord/run_imagenet.sh @@ -0,0 +1,8 @@ +#!/bin/bash +rm /tmp/imagenet/mr/* + +python writer.py --mindrecord_script imagenet \ +--mindrecord_file "/tmp/imagenet/mr/m" \ +--mindrecord_partitions 16 \ +--label_file "/tmp/imagenet/label.txt" \ +--image_dir "/tmp/imagenet/jpeg" diff --git a/example/convert_to_mindrecord/run_template.sh b/example/convert_to_mindrecord/run_template.sh new file mode 100644 index 0000000000..a4c5142c00 --- /dev/null +++ b/example/convert_to_mindrecord/run_template.sh @@ -0,0 +1,6 @@ +#!/bin/bash +rm /tmp/template/* + +python writer.py --mindrecord_script template \ +--mindrecord_file "/tmp/template/m" \ +--mindrecord_partitions 4 diff --git a/example/convert_to_mindrecord/template/__init__.py b/example/convert_to_mindrecord/template/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/example/convert_to_mindrecord/template/mr_api.py b/example/convert_to_mindrecord/template/mr_api.py new file mode 100644 index 0000000000..3f7d7dddf0 --- /dev/null +++ b/example/convert_to_mindrecord/template/mr_api.py @@ -0,0 +1,73 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +User-defined API for MindRecord writer. +Two API must be implemented, + 1. mindrecord_task_number() + # Return number of parallel tasks. return 1 if no parallel + 2. mindrecord_dict_data(task_id) + # Yield data for one task + # task_id is 0..N-1, if N is return value of mindrecord_task_number() +""" +import argparse +import pickle + +# ## Parse argument + +with open('mr_argument.pickle', 'rb') as mindrecord_argument_file_handle: # Do NOT change this line + ARG_LIST = pickle.load(mindrecord_argument_file_handle) # Do NOT change this line +parser = argparse.ArgumentParser(description='Mind record api template') # Do NOT change this line + +# ## Your arguments below +# parser.add_argument(...) + +args = parser.parse_args(ARG_LIST) # Do NOT change this line +print(args) # Do NOT change this line + + +# ## Default mindrecord vars. Comment them unless default value has to be changed. +# mindrecord_index_fields = ['label'] +# mindrecord_header_size = 1 << 24 +# mindrecord_page_size = 1 << 25 + + +# define global vars here if necessary + + +# ####### Your code below ########## +mindrecord_schema = {"label": {"type": "int32"}} + +def mindrecord_task_number(): + """ + Get task size. + + Return: + number of tasks + """ + return 1 + + +def mindrecord_dict_data(task_id): + """ + Get data dict. + + Yields: + data (dict): data row which is dict. + """ + print("task is {}".format(task_id)) + for i in range(256): + data = {} + data['label'] = i + yield data diff --git a/example/convert_to_mindrecord/writer.py b/example/convert_to_mindrecord/writer.py new file mode 100644 index 0000000000..d34f1fb1c7 --- /dev/null +++ b/example/convert_to_mindrecord/writer.py @@ -0,0 +1,152 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +######################## write mindrecord example ######################## +Write mindrecord by data dictionary: +python writer.py --mindrecord_script /YourScriptPath ... +""" +import argparse +import os +import pickle +import time +from importlib import import_module +from multiprocessing import Pool + +from mindspore.mindrecord import FileWriter + + +def _exec_task(task_id, parallel_writer=True): + """ + Execute task with specified task id + """ + print("exec task {}, parallel: {} ...".format(task_id, parallel_writer)) + imagenet_iter = mindrecord_dict_data(task_id) + batch_size = 2048 + transform_count = 0 + while True: + data_list = [] + try: + for _ in range(batch_size): + data_list.append(imagenet_iter.__next__()) + transform_count += 1 + writer.write_raw_data(data_list, parallel_writer=parallel_writer) + print("transformed {} record...".format(transform_count)) + except StopIteration: + if data_list: + writer.write_raw_data(data_list, parallel_writer=parallel_writer) + print("transformed {} record...".format(transform_count)) + break + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='Mind record writer') + parser.add_argument('--mindrecord_script', type=str, default="template", + help='path where script is saved') + + parser.add_argument('--mindrecord_file', type=str, default="/tmp/mindrecord", + help='written file name prefix') + + parser.add_argument('--mindrecord_partitions', type=int, default=1, + help='number of written files') + + parser.add_argument('--mindrecord_workers', type=int, default=8, + help='number of parallel workers') + + args = parser.parse_known_args() + + args, other_args = parser.parse_known_args() + + print(args) + print(other_args) + + with open('mr_argument.pickle', 'wb') as file_handle: + pickle.dump(other_args, file_handle) + + try: + mr_api = import_module(args.mindrecord_script + '.mr_api') + except ModuleNotFoundError: + raise RuntimeError("Unknown module path: {}".format(args.mindrecord_script + '.mr_api')) + + num_tasks = mr_api.mindrecord_task_number() + + print("Write mindrecord ...") + + mindrecord_dict_data = mr_api.mindrecord_dict_data + + # get number of files + writer = FileWriter(args.mindrecord_file, args.mindrecord_partitions) + + start_time = time.time() + + # set the header size + try: + header_size = mr_api.mindrecord_header_size + writer.set_header_size(header_size) + except AttributeError: + print("Default header size: {}".format(1 << 24)) + + # set the page size + try: + page_size = mr_api.mindrecord_page_size + writer.set_page_size(page_size) + except AttributeError: + print("Default page size: {}".format(1 << 25)) + + # get schema + try: + mindrecord_schema = mr_api.mindrecord_schema + except AttributeError: + raise RuntimeError("mindrecord_schema is not defined in mr_api.py.") + + # create the schema + writer.add_schema(mindrecord_schema, "mindrecord_schema") + + # add the index + try: + index_fields = mr_api.mindrecord_index_fields + writer.add_index(index_fields) + except AttributeError: + print("Default index fields: all simple fields are indexes.") + + writer.open_and_set_header() + + task_list = list(range(num_tasks)) + + # set number of workers + num_workers = args.mindrecord_workers + + if num_tasks < 1: + num_tasks = 1 + + if num_workers > num_tasks: + num_workers = num_tasks + + if os.name == 'nt': + for window_task_id in task_list: + _exec_task(window_task_id, False) + elif num_tasks > 1: + with Pool(num_workers) as p: + p.map(_exec_task, task_list) + else: + _exec_task(0, False) + + ret = writer.commit() + + os.remove("{}".format("mr_argument.pickle")) + + end_time = time.time() + print("--------------------------------------------") + print("END. Total time: {}".format(end_time - start_time)) + print("--------------------------------------------") diff --git a/mindspore/ccsrc/mindrecord/common/shard_pybind.cc b/mindspore/ccsrc/mindrecord/common/shard_pybind.cc index 338a17ac2d..8718e9b871 100644 --- a/mindspore/ccsrc/mindrecord/common/shard_pybind.cc +++ b/mindspore/ccsrc/mindrecord/common/shard_pybind.cc @@ -75,12 +75,9 @@ void BindShardWriter(py::module *m) { .def("set_header_size", &ShardWriter::set_header_size) .def("set_page_size", &ShardWriter::set_page_size) .def("set_shard_header", &ShardWriter::SetShardHeader) - .def("write_raw_data", - (MSRStatus(ShardWriter::*)(std::map> &, vector> &, bool)) & - ShardWriter::WriteRawData) - .def("write_raw_nlp_data", (MSRStatus(ShardWriter::*)(std::map> &, - std::map> &, bool)) & - ShardWriter::WriteRawData) + .def("write_raw_data", (MSRStatus(ShardWriter::*)(std::map> &, + vector> &, bool, bool)) & + ShardWriter::WriteRawData) .def("commit", &ShardWriter::Commit); } diff --git a/mindspore/ccsrc/mindrecord/include/shard_header.h b/mindspore/ccsrc/mindrecord/include/shard_header.h index ca4d3bd66f..70cfcdb6b7 100644 --- a/mindspore/ccsrc/mindrecord/include/shard_header.h +++ b/mindspore/ccsrc/mindrecord/include/shard_header.h @@ -121,6 +121,10 @@ class ShardHeader { std::vector SerializeHeader(); + MSRStatus PagesToFile(const std::string dump_file_name); + + MSRStatus FileToPages(const std::string dump_file_name); + private: MSRStatus InitializeHeader(const std::vector &headers); diff --git a/mindspore/ccsrc/mindrecord/include/shard_writer.h b/mindspore/ccsrc/mindrecord/include/shard_writer.h index 6a22f07700..78a434fc97 100644 --- a/mindspore/ccsrc/mindrecord/include/shard_writer.h +++ b/mindspore/ccsrc/mindrecord/include/shard_writer.h @@ -18,6 +18,7 @@ #define MINDRECORD_INCLUDE_SHARD_WRITER_H_ #include +#include #include #include #include @@ -87,7 +88,7 @@ class ShardWriter { /// \param[in] sign validate data or not /// \return MSRStatus the status of MSRStatus to judge if write successfully MSRStatus WriteRawData(std::map> &raw_data, vector> &blob_data, - bool sign = true); + bool sign = true, bool parallel_writer = false); /// \brief write raw data by group size for call from python /// \param[in] raw_data the vector of raw json data, python-handle format @@ -95,7 +96,7 @@ class ShardWriter { /// \param[in] sign validate data or not /// \return MSRStatus the status of MSRStatus to judge if write successfully MSRStatus WriteRawData(std::map> &raw_data, vector> &blob_data, - bool sign = true); + bool sign = true, bool parallel_writer = false); /// \brief write raw data by group size for call from python /// \param[in] raw_data the vector of raw json data, python-handle format @@ -103,7 +104,8 @@ class ShardWriter { /// \param[in] sign validate data or not /// \return MSRStatus the status of MSRStatus to judge if write successfully MSRStatus WriteRawData(std::map> &raw_data, - std::map> &blob_data, bool sign = true); + std::map> &blob_data, bool sign = true, + bool parallel_writer = false); private: /// \brief write shard header data to disk @@ -201,7 +203,34 @@ class ShardWriter { MSRStatus CheckDataTypeAndValue(const std::string &key, const json &value, const json &data, const int &i, std::map &err_raw_data); + /// \brief Lock writer and save pages info + int LockWriter(bool parallel_writer = false); + + /// \brief Unlock writer and save pages info + MSRStatus UnlockWriter(int fd, bool parallel_writer = false); + + /// \brief Check raw data before writing + MSRStatus WriteRawDataPreCheck(std::map> &raw_data, vector> &blob_data, + bool sign, int *schema_count, int *row_count); + + /// \brief Get full path from file name + MSRStatus GetFullPathFromFileName(const std::vector &paths); + + /// \brief Open files + MSRStatus OpenDataFiles(bool append); + + /// \brief Remove lock file + MSRStatus RemoveLockFile(); + + /// \brief Remove lock file + MSRStatus InitLockFile(); + private: + const std::string kLockFileSuffix = "_Locker"; + const std::string kPageFileSuffix = "_Pages"; + std::string lock_file_; // lock file for parallel run + std::string pages_file_; // temporary file of pages info for parallel run + int shard_count_; // number of files uint64_t header_size_; // header size uint64_t page_size_; // page size @@ -211,7 +240,7 @@ class ShardWriter { std::vector raw_data_size_; // Raw data size std::vector blob_data_size_; // Blob data size - std::vector file_paths_; // file paths + std::vector file_paths_; // file paths std::vector> file_streams_; // file handles std::shared_ptr shard_header_; // shard headers diff --git a/mindspore/ccsrc/mindrecord/io/shard_index_generator.cc b/mindspore/ccsrc/mindrecord/io/shard_index_generator.cc index 5a5cd7cbf3..dc2743cdc7 100644 --- a/mindspore/ccsrc/mindrecord/io/shard_index_generator.cc +++ b/mindspore/ccsrc/mindrecord/io/shard_index_generator.cc @@ -520,13 +520,16 @@ MSRStatus ShardIndexGenerator::ExecuteTransaction(const int &shard_no, const std for (int raw_page_id : raw_page_ids) { auto sql = GenerateRawSQL(fields_); if (sql.first != SUCCESS) { + MS_LOG(ERROR) << "Generate raw SQL failed"; return FAILED; } auto data = GenerateRowData(shard_no, blob_id_to_page_id, raw_page_id, in); if (data.first != SUCCESS) { + MS_LOG(ERROR) << "Generate raw data failed"; return FAILED; } if (BindParameterExecuteSQL(db.second, sql.second, data.second) == FAILED) { + MS_LOG(ERROR) << "Execute SQL failed"; return FAILED; } MS_LOG(INFO) << "Insert " << data.second.size() << " rows to index db."; diff --git a/mindspore/ccsrc/mindrecord/io/shard_writer.cc b/mindspore/ccsrc/mindrecord/io/shard_writer.cc index 864e6697d0..2fb5db5503 100644 --- a/mindspore/ccsrc/mindrecord/io/shard_writer.cc +++ b/mindspore/ccsrc/mindrecord/io/shard_writer.cc @@ -40,17 +40,7 @@ ShardWriter::~ShardWriter() { } } -MSRStatus ShardWriter::Open(const std::vector &paths, bool append) { - shard_count_ = paths.size(); - if (shard_count_ > kMaxShardCount || shard_count_ == 0) { - MS_LOG(ERROR) << "The Shard Count greater than max value or equal to 0."; - return FAILED; - } - if (schema_count_ > kMaxSchemaCount) { - MS_LOG(ERROR) << "The schema Count greater than max value."; - return FAILED; - } - +MSRStatus ShardWriter::GetFullPathFromFileName(const std::vector &paths) { // Get full path from file name for (const auto &path : paths) { if (!CheckIsValidUtf8(path)) { @@ -60,7 +50,7 @@ MSRStatus ShardWriter::Open(const std::vector &paths, bool append) char resolved_path[PATH_MAX] = {0}; char buf[PATH_MAX] = {0}; if (strncpy_s(buf, PATH_MAX, common::SafeCStr(path), path.length()) != EOK) { - MS_LOG(ERROR) << "Securec func failed"; + MS_LOG(ERROR) << "Secure func failed"; return FAILED; } #if defined(_WIN32) || defined(_WIN64) @@ -82,7 +72,10 @@ MSRStatus ShardWriter::Open(const std::vector &paths, bool append) #endif file_paths_.emplace_back(string(resolved_path)); } + return SUCCESS; +} +MSRStatus ShardWriter::OpenDataFiles(bool append) { // Open files for (const auto &file : file_paths_) { std::shared_ptr fs = std::make_shared(); @@ -116,6 +109,67 @@ MSRStatus ShardWriter::Open(const std::vector &paths, bool append) return SUCCESS; } +MSRStatus ShardWriter::RemoveLockFile() { + // Remove temporary file + int ret = std::remove(pages_file_.c_str()); + if (ret == 0) { + MS_LOG(DEBUG) << "Remove page file."; + } + + ret = std::remove(lock_file_.c_str()); + if (ret == 0) { + MS_LOG(DEBUG) << "Remove lock file."; + } + return SUCCESS; +} + +MSRStatus ShardWriter::InitLockFile() { + if (file_paths_.size() == 0) { + MS_LOG(ERROR) << "File path not initialized."; + return FAILED; + } + + lock_file_ = file_paths_[0] + kLockFileSuffix; + pages_file_ = file_paths_[0] + kPageFileSuffix; + + if (RemoveLockFile() == FAILED) { + MS_LOG(ERROR) << "Remove file failed."; + return FAILED; + } + return SUCCESS; +} + +MSRStatus ShardWriter::Open(const std::vector &paths, bool append) { + shard_count_ = paths.size(); + if (shard_count_ > kMaxShardCount || shard_count_ == 0) { + MS_LOG(ERROR) << "The Shard Count greater than max value or equal to 0."; + return FAILED; + } + if (schema_count_ > kMaxSchemaCount) { + MS_LOG(ERROR) << "The schema Count greater than max value."; + return FAILED; + } + + // Get full path from file name + if (GetFullPathFromFileName(paths) == FAILED) { + MS_LOG(ERROR) << "Get full path from file name failed."; + return FAILED; + } + + // Open files + if (OpenDataFiles(append) == FAILED) { + MS_LOG(ERROR) << "Open data files failed."; + return FAILED; + } + + // Init lock file + if (InitLockFile() == FAILED) { + MS_LOG(ERROR) << "Init lock file failed."; + return FAILED; + } + return SUCCESS; +} + MSRStatus ShardWriter::OpenForAppend(const std::string &path) { if (!IsLegalFile(path)) { return FAILED; @@ -143,11 +197,28 @@ MSRStatus ShardWriter::OpenForAppend(const std::string &path) { } MSRStatus ShardWriter::Commit() { + // Read pages file + std::ifstream page_file(pages_file_.c_str()); + if (page_file.good()) { + page_file.close(); + if (shard_header_->FileToPages(pages_file_) == FAILED) { + MS_LOG(ERROR) << "Read pages from file failed"; + return FAILED; + } + } + if (WriteShardHeader() == FAILED) { MS_LOG(ERROR) << "Write metadata failed"; return FAILED; } MS_LOG(INFO) << "Write metadata successfully."; + + // Remove lock file + if (RemoveLockFile() == FAILED) { + MS_LOG(ERROR) << "Remove lock file failed."; + return FAILED; + } + return SUCCESS; } @@ -455,15 +526,75 @@ void ShardWriter::FillArray(int start, int end, std::map> } } -MSRStatus ShardWriter::WriteRawData(std::map> &raw_data, - std::vector> &blob_data, bool sign) { +int ShardWriter::LockWriter(bool parallel_writer) { + if (!parallel_writer) { + return 0; + } + +#if defined(_WIN32) || defined(_WIN64) + MS_LOG(DEBUG) << "Lock file done by python."; + const int fd = 0; +#else + const int fd = open(lock_file_.c_str(), O_WRONLY | O_CREAT, 0666); + if (fd >= 0) { + flock(fd, LOCK_EX); + } else { + MS_LOG(ERROR) << "Shard writer failed when locking file"; + return -1; + } +#endif + + // Open files + file_streams_.clear(); + for (const auto &file : file_paths_) { + std::shared_ptr fs = std::make_shared(); + fs->open(common::SafeCStr(file), std::ios::in | std::ios::out | std::ios::binary); + if (fs->fail()) { + MS_LOG(ERROR) << "File could not opened"; + return -1; + } + file_streams_.push_back(fs); + } + + if (shard_header_->FileToPages(pages_file_) == FAILED) { + MS_LOG(ERROR) << "Read pages from file failed"; + return -1; + } + return fd; +} + +MSRStatus ShardWriter::UnlockWriter(int fd, bool parallel_writer) { + if (!parallel_writer) { + return SUCCESS; + } + + if (shard_header_->PagesToFile(pages_file_) == FAILED) { + MS_LOG(ERROR) << "Write pages to file failed"; + return FAILED; + } + + for (int i = static_cast(file_streams_.size()) - 1; i >= 0; i--) { + file_streams_[i]->close(); + } + +#if defined(_WIN32) || defined(_WIN64) + MS_LOG(DEBUG) << "Unlock file done by python."; +#else + flock(fd, LOCK_UN); + close(fd); +#endif + return SUCCESS; +} + +MSRStatus ShardWriter::WriteRawDataPreCheck(std::map> &raw_data, + std::vector> &blob_data, bool sign, int *schema_count, + int *row_count) { // check the free disk size auto st_space = GetDiskSize(file_paths_[0], kFreeSize); if (st_space.first != SUCCESS || st_space.second < kMinFreeDiskSize) { MS_LOG(ERROR) << "IO error / there is no free disk to be used"; return FAILED; } - // Add 4-bytes dummy blob data if no any blob fields if (blob_data.size() == 0 && raw_data.size() > 0) { blob_data = std::vector>(raw_data[0].size(), std::vector(kUnsignedInt4, 0)); @@ -479,10 +610,29 @@ MSRStatus ShardWriter::WriteRawData(std::map> &raw_d MS_LOG(ERROR) << "Validate raw data failed"; return FAILED; } + *schema_count = std::get<1>(v); + *row_count = std::get<2>(v); + return SUCCESS; +} + +MSRStatus ShardWriter::WriteRawData(std::map> &raw_data, + std::vector> &blob_data, bool sign, bool parallel_writer) { + // Lock Writer if loading data parallel + int fd = LockWriter(parallel_writer); + if (fd < 0) { + MS_LOG(ERROR) << "Lock writer failed"; + return FAILED; + } // Get the count of schemas and rows - int schema_count = std::get<1>(v); - int row_count = std::get<2>(v); + int schema_count = 0; + int row_count = 0; + + // Serialize raw data + if (WriteRawDataPreCheck(raw_data, blob_data, sign, &schema_count, &row_count) == FAILED) { + MS_LOG(ERROR) << "Check raw data failed"; + return FAILED; + } if (row_count == kInt0) { MS_LOG(INFO) << "Raw data size is 0."; @@ -516,11 +666,17 @@ MSRStatus ShardWriter::WriteRawData(std::map> &raw_d } MS_LOG(INFO) << "Write " << bin_raw_data.size() << " records successfully."; + if (UnlockWriter(fd, parallel_writer) == FAILED) { + MS_LOG(ERROR) << "Unlock writer failed"; + return FAILED; + } + return SUCCESS; } MSRStatus ShardWriter::WriteRawData(std::map> &raw_data, - std::map> &blob_data, bool sign) { + std::map> &blob_data, bool sign, + bool parallel_writer) { std::map> raw_data_json; std::map> blob_data_json; @@ -554,11 +710,11 @@ MSRStatus ShardWriter::WriteRawData(std::map> MS_LOG(ERROR) << "Serialize raw data failed in write raw data"; return FAILED; } - return WriteRawData(raw_data_json, bin_blob_data, sign); + return WriteRawData(raw_data_json, bin_blob_data, sign, parallel_writer); } MSRStatus ShardWriter::WriteRawData(std::map> &raw_data, - vector> &blob_data, bool sign) { + vector> &blob_data, bool sign, bool parallel_writer) { std::map> raw_data_json; (void)std::transform(raw_data.begin(), raw_data.end(), std::inserter(raw_data_json, raw_data_json.end()), [](const std::pair> &pair) { @@ -568,7 +724,7 @@ MSRStatus ShardWriter::WriteRawData(std::map> [](const py::handle &obj) { return nlohmann::detail::ToJsonImpl(obj); }); return std::make_pair(pair.first, std::move(json_raw_data)); }); - return WriteRawData(raw_data_json, blob_data, sign); + return WriteRawData(raw_data_json, blob_data, sign, parallel_writer); } MSRStatus ShardWriter::ParallelWriteData(const std::vector> &blob_data, diff --git a/mindspore/ccsrc/mindrecord/meta/shard_header.cc b/mindspore/ccsrc/mindrecord/meta/shard_header.cc index 57b2e5fa9e..26008e3ca9 100644 --- a/mindspore/ccsrc/mindrecord/meta/shard_header.cc +++ b/mindspore/ccsrc/mindrecord/meta/shard_header.cc @@ -677,5 +677,43 @@ std::pair, MSRStatus> ShardHeader::GetStatisticByID( } return std::make_pair(statistics_.at(statistic_id), SUCCESS); } + +MSRStatus ShardHeader::PagesToFile(const std::string dump_file_name) { + // write header content to file, dump whatever is in the file before + std::ofstream page_out_handle(dump_file_name.c_str(), std::ios_base::trunc | std::ios_base::out); + if (page_out_handle.fail()) { + MS_LOG(ERROR) << "Failed in opening page file"; + return FAILED; + } + + auto pages = SerializePage(); + for (const auto &shard_pages : pages) { + page_out_handle << shard_pages << "\n"; + } + + page_out_handle.close(); + return SUCCESS; +} + +MSRStatus ShardHeader::FileToPages(const std::string dump_file_name) { + for (auto &v : pages_) { // clean pages + v.clear(); + } + // attempt to open the file contains the page in json + std::ifstream page_in_handle(dump_file_name.c_str()); + + if (!page_in_handle.good()) { + MS_LOG(INFO) << "No page file exists."; + return SUCCESS; + } + + std::string line; + while (std::getline(page_in_handle, line)) { + ParsePage(json::parse(line)); + } + + page_in_handle.close(); + return SUCCESS; +} } // namespace mindrecord } // namespace mindspore diff --git a/mindspore/mindrecord/filewriter.py b/mindspore/mindrecord/filewriter.py index 90bca48038..62bcc2df79 100644 --- a/mindspore/mindrecord/filewriter.py +++ b/mindspore/mindrecord/filewriter.py @@ -200,13 +200,24 @@ class FileWriter: raw_data.pop(i) logger.warning(v) - def write_raw_data(self, raw_data): + def open_and_set_header(self): + """ + Open writer and set header + + """ + if not self._writer.is_open: + self._writer.open(self._paths) + if not self._writer.get_shard_header(): + self._writer.set_shard_header(self._header) + + def write_raw_data(self, raw_data, parallel_writer=False): """ Write raw data and generate sequential pair of MindRecord File and \ validate data based on predefined schema by default. Args: raw_data (list[dict]): List of raw data. + parallel_writer (bool, optional): Load data parallel if it equals to True (default=False). Raises: ParamTypeError: If index field is invalid. @@ -225,7 +236,7 @@ class FileWriter: if not isinstance(each_raw, dict): raise ParamTypeError('raw_data item', 'dict') self._verify_based_on_schema(raw_data) - return self._writer.write_raw_data(raw_data, True) + return self._writer.write_raw_data(raw_data, True, parallel_writer) def set_header_size(self, header_size): """ diff --git a/mindspore/mindrecord/shardwriter.py b/mindspore/mindrecord/shardwriter.py index 0ef23d4ce6..0913201861 100644 --- a/mindspore/mindrecord/shardwriter.py +++ b/mindspore/mindrecord/shardwriter.py @@ -135,7 +135,7 @@ class ShardWriter: def get_shard_header(self): return self._header - def write_raw_data(self, data, validate=True): + def write_raw_data(self, data, validate=True, parallel_writer=False): """ Write raw data of cv dataset. @@ -145,6 +145,7 @@ class ShardWriter: Args: data (list[dict]): List of raw data. validate (bool, optional): verify data according schema if it equals to True. + parallel_writer (bool, optional): Load data parallel if it equals to True. Returns: MSRStatus, SUCCESS or FAILED. @@ -165,7 +166,7 @@ class ShardWriter: if row_raw: raw_data.append(row_raw) raw_data = {0: raw_data} if raw_data else {} - ret = self._writer.write_raw_data(raw_data, blob_data, validate) + ret = self._writer.write_raw_data(raw_data, blob_data, validate, parallel_writer) if ret != ms.MSRStatus.SUCCESS: logger.error("Failed to write dataset.") raise MRMWriteDatasetError From 6acae622dc423d77cbf9207664d2fd9a49da86a0 Mon Sep 17 00:00:00 2001 From: Alexey Shevlyakov Date: Mon, 20 Apr 2020 15:22:13 -0400 Subject: [PATCH 116/142] fix random seed behaviour --- .../ccsrc/dataset/api/python_bindings.cc | 13 -- .../dataset/kernels/image/CMakeLists.txt | 4 +- .../ccsrc/dataset/kernels/image/cut_out_op.cc | 7 +- .../ccsrc/dataset/kernels/image/cut_out_op.h | 1 + .../image/distort_bounding_box_crop_op.cc | 117 ------------------ .../image/distort_bounding_box_crop_op.h | 72 ----------- .../dataset/kernels/image/image_utils.cc | 80 ++---------- .../ccsrc/dataset/kernels/image/image_utils.h | 10 +- 8 files changed, 15 insertions(+), 289 deletions(-) delete mode 100644 mindspore/ccsrc/dataset/kernels/image/distort_bounding_box_crop_op.cc delete mode 100644 mindspore/ccsrc/dataset/kernels/image/distort_bounding_box_crop_op.h diff --git a/mindspore/ccsrc/dataset/api/python_bindings.cc b/mindspore/ccsrc/dataset/api/python_bindings.cc index 76e971fd2b..ea2e8352da 100644 --- a/mindspore/ccsrc/dataset/api/python_bindings.cc +++ b/mindspore/ccsrc/dataset/api/python_bindings.cc @@ -24,7 +24,6 @@ #endif #include "dataset/kernels/image/cut_out_op.h" #include "dataset/kernels/image/decode_op.h" -#include "dataset/kernels/image/distort_bounding_box_crop_op.h" #include "dataset/kernels/image/hwc_to_chw_op.h" #include "dataset/kernels/image/image_utils.h" #include "dataset/kernels/image/normalize_op.h" @@ -369,18 +368,6 @@ void bindTensorOps3(py::module *m) { } void bindTensorOps4(py::module *m) { - (void)py::class_>( - *m, "DistortBoundingBoxCropOp", - "Tensor operator to crop an image randomly as long as the cropped image has sufficient " - "overlap with any one bounding box associated with original image" - "Takes aspect ratio of the generated crop box, the intersection ratio of crop box and bounding box," - "crop ratio lower and upper bounds" - "Optional parameters: number of attempts for crop, number of attempts of crop box generation") - .def(py::init(), py::arg("aspect_ratio"), py::arg("intersect_ratio"), - py::arg("crop_ratio_lower_bound"), py::arg("crop_ratio_upper_bound"), - py::arg("max_attempts") = DistortBoundingBoxCropOp::kDefMaxAttempts, - py::arg("box_gen_attempts") = DistortBoundingBoxCropOp::kDefBoxGenAttempts); - (void)py::class_>( *m, "TypeCastOp", "Tensor operator to type cast data to a specified type.") .def(py::init(), py::arg("data_type")) diff --git a/mindspore/ccsrc/dataset/kernels/image/CMakeLists.txt b/mindspore/ccsrc/dataset/kernels/image/CMakeLists.txt index 33e681337c..43b68d8933 100644 --- a/mindspore/ccsrc/dataset/kernels/image/CMakeLists.txt +++ b/mindspore/ccsrc/dataset/kernels/image/CMakeLists.txt @@ -3,7 +3,6 @@ if (WIN32) center_crop_op.cc cut_out_op.cc decode_op.cc - distort_bounding_box_crop_op.cc hwc_to_chw_op.cc image_utils.cc normalize_op.cc @@ -27,7 +26,6 @@ else() change_mode_op.cc cut_out_op.cc decode_op.cc - distort_bounding_box_crop_op.cc hwc_to_chw_op.cc image_utils.cc normalize_op.cc @@ -45,4 +43,4 @@ else() resize_op.cc uniform_aug_op.cc ) -endif() +endif() \ No newline at end of file diff --git a/mindspore/ccsrc/dataset/kernels/image/cut_out_op.cc b/mindspore/ccsrc/dataset/kernels/image/cut_out_op.cc index 9327d785fc..74d9df5d6b 100644 --- a/mindspore/ccsrc/dataset/kernels/image/cut_out_op.cc +++ b/mindspore/ccsrc/dataset/kernels/image/cut_out_op.cc @@ -33,7 +33,8 @@ const uint8_t CutOutOp::kDefFillB = 0; // constructor CutOutOp::CutOutOp(int32_t box_height, int32_t box_width, int32_t num_patches, bool random_color, uint8_t fill_r, uint8_t fill_g, uint8_t fill_b) - : box_height_(box_height), + : rnd_(GetSeed()), + box_height_(box_height), box_width_(box_width), num_patches_(num_patches), random_color_(random_color), @@ -46,8 +47,8 @@ Status CutOutOp::Compute(const std::shared_ptr &input, std::shared_ptr inputCV = CVTensor::AsCVTensor(input); // cut out will clip the erasing area if the box is near the edge of the image and the boxes are black - RETURN_IF_NOT_OK( - Erase(inputCV, output, box_height_, box_width_, num_patches_, false, random_color_, fill_r_, fill_g_, fill_b_)); + RETURN_IF_NOT_OK(Erase(inputCV, output, box_height_, box_width_, num_patches_, false, random_color_, &rnd_, fill_r_, + fill_g_, fill_b_)); return Status::OK(); } } // namespace dataset diff --git a/mindspore/ccsrc/dataset/kernels/image/cut_out_op.h b/mindspore/ccsrc/dataset/kernels/image/cut_out_op.h index 9a76572a54..2198f23e44 100644 --- a/mindspore/ccsrc/dataset/kernels/image/cut_out_op.h +++ b/mindspore/ccsrc/dataset/kernels/image/cut_out_op.h @@ -62,6 +62,7 @@ class CutOutOp : public TensorOp { Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; private: + std::mt19937 rnd_; int32_t box_height_; int32_t box_width_; int32_t num_patches_; diff --git a/mindspore/ccsrc/dataset/kernels/image/distort_bounding_box_crop_op.cc b/mindspore/ccsrc/dataset/kernels/image/distort_bounding_box_crop_op.cc deleted file mode 100644 index a28f2bb6fd..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/distort_bounding_box_crop_op.cc +++ /dev/null @@ -1,117 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "dataset/kernels/image/distort_bounding_box_crop_op.h" -#include -#include "dataset/core/cv_tensor.h" -#include "dataset/kernels/image/image_utils.h" -#include "dataset/util/random.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -const int32_t DistortBoundingBoxCropOp::kDefMaxAttempts = 100; -const int32_t DistortBoundingBoxCropOp::kDefBoxGenAttempts = 10; - -DistortBoundingBoxCropOp::DistortBoundingBoxCropOp(float aspect_ratio, float intersect_ratio, float crop_ratio_lb, - float crop_ratio_ub, int32_t max_attempts, int32_t box_gen_attempts) - : max_attempts_(max_attempts), - box_gen_attempts_(box_gen_attempts), - aspect_ratio_(aspect_ratio), - intersect_ratio_(intersect_ratio), - crop_ratio_lb_(crop_ratio_lb), - crop_ratio_ub_(crop_ratio_ub) { - seed_ = GetSeed(); - rnd_.seed(seed_); -} - -Status DistortBoundingBoxCropOp::Compute(const std::vector> &input, - std::vector> *output) { - IO_CHECK_VECTOR(input, output); - if (input.size() != NumInput()) - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Number of inputs is not 5"); - - CHECK_FAIL_RETURN_UNEXPECTED(input[1]->shape().Size() >= 1, "The shape of the second tensor is abnormal"); - int64_t num_boxes = 0; - for (uint64_t i = 1; i < input.size(); i++) { - if (i == 1) num_boxes = input[i]->shape()[0]; - if (num_boxes != input[i]->shape()[0]) - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "Numbers of boxes do not match"); - - if (input[i]->type() != DataType::DE_FLOAT32) - return Status(StatusCode::kUnexpectedError, __LINE__, __FILE__, "boxes' type is not DE_FLOAT21"); - } - - // assume input Tensor vector in the order of [img, bbox_y_min, bbox_y_max, bbox_x_min, bbox_x_max] - CHECK_FAIL_RETURN_UNEXPECTED(input[0]->shape().Size() >= 2, "The shape of the first tensor is abnormal"); - int h_in = input[0]->shape()[0]; - int w_in = input[0]->shape()[1]; - - std::vector bounding_boxes; - for (int64_t i = 0; i < num_boxes; ++i) { - // bbox coordinates are floats relative to the image width and height - float y_min, y_max, x_min, x_max; - RETURN_IF_NOT_OK(input[1]->GetItemAt(&y_min, {i})); - RETURN_IF_NOT_OK(input[2]->GetItemAt(&y_max, {i})); - RETURN_IF_NOT_OK(input[3]->GetItemAt(&x_min, {i})); - RETURN_IF_NOT_OK(input[4]->GetItemAt(&x_max, {i})); - bounding_boxes.emplace_back(static_cast(x_min * w_in), static_cast(y_min * h_in), - static_cast((x_max - x_min) * w_in), static_cast((y_max - y_min) * h_in)); - } - cv::Rect output_box; - bool should_crop = false; - - // go over iterations, if no satisfying box found we return the original image - for (int32_t t = 0; t < max_attempts_; ++t) { - // try to generate random box - RETURN_IF_NOT_OK(GenerateRandomCropBox(h_in, w_in, aspect_ratio_, crop_ratio_lb_, crop_ratio_ub_, - box_gen_attempts_, // int maxIter, should not be needed here - &output_box, seed_)); - RETURN_IF_NOT_OK(CheckOverlapConstraint(output_box, - bounding_boxes, // have to change, should take tensor or add bbox logic - intersect_ratio_, &should_crop)); - if (should_crop) { - // found a box to crop - break; - } - } - // essentially we have to check this again at the end to return original tensor - if (should_crop) { - std::shared_ptr out; - RETURN_IF_NOT_OK(Crop(input[0], &out, output_box.x, output_box.y, output_box.width, output_box.height)); - output->push_back(out); - } else { - output->push_back(input[0]); - } - return Status::OK(); -} - -Status DistortBoundingBoxCropOp::OutputShape(const std::vector &inputs, - std::vector &outputs) { - RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); - outputs.clear(); - TensorShape out = TensorShape{-1, -1}; - if (inputs[0].Rank() == 2) outputs.emplace_back(out); - if (inputs[0].Rank() == 3) outputs.emplace_back(out.AppendDim(inputs[0][2])); - if (!outputs.empty()) return Status::OK(); - return Status(StatusCode::kUnexpectedError, "Input has a wrong shape"); -} -Status DistortBoundingBoxCropOp::OutputType(const std::vector &inputs, std::vector &outputs) { - RETURN_IF_NOT_OK(TensorOp::OutputType(inputs, outputs)); - outputs[0] = inputs[0]; - return Status::OK(); -} -} // namespace dataset -} // namespace mindspore diff --git a/mindspore/ccsrc/dataset/kernels/image/distort_bounding_box_crop_op.h b/mindspore/ccsrc/dataset/kernels/image/distort_bounding_box_crop_op.h deleted file mode 100644 index 749c166d59..0000000000 --- a/mindspore/ccsrc/dataset/kernels/image/distort_bounding_box_crop_op.h +++ /dev/null @@ -1,72 +0,0 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef DATASET_KERNELS_IMAGE_DISTORT_BOUNDING_BOX_CROP_OP_H_ -#define DATASET_KERNELS_IMAGE_DISTORT_BOUNDING_BOX_CROP_OP_H_ - -#include -#include -#include -#include "dataset/core/tensor.h" -#include "dataset/kernels/tensor_op.h" -#include "dataset/util/status.h" - -namespace mindspore { -namespace dataset { -class DistortBoundingBoxCropOp : public TensorOp { - public: - // Default values, also used by python_bindings.cc - static const int32_t kDefMaxAttempts; - static const int32_t kDefBoxGenAttempts; - - // Constructor for DistortBoundingBoxCropOp - // @param max_attempts tries before the crop happens - // @param box_gen_attempts crop box generation attempts - // @param aspect_ratio aspect ratio of the generated crop box - // @param intersect_ratio area overlap ratio, condition for crop only if area over lap between the generated - // crop box has sufficient overlap with any 1 bounding box - // @param crop_ratio_lb the crop ratio lower bound - // @param crop_ratio_ub the crop ratio upper bound - // @param seed - DistortBoundingBoxCropOp(float aspect_ratio, float intersect_ratio, float crop_ratio_lb, float crop_ratio_ub, - int32_t max_attempts = kDefMaxAttempts, int32_t box_gen_attempts = kDefBoxGenAttempts); - - ~DistortBoundingBoxCropOp() override = default; - - void Print(std::ostream &out) const override { - out << "DistortBoundingBoxCropOp: " << max_attempts_ << " " << intersect_ratio_; - } - - Status Compute(const std::vector> &input, - std::vector> *output) override; - - uint32_t NumInput() override { return 5; } - Status OutputShape(const std::vector &inputs, std::vector &outputs) override; - Status OutputType(const std::vector &inputs, std::vector &outputs) override; - - private: - int32_t max_attempts_; - int32_t box_gen_attempts_; - float aspect_ratio_; - float intersect_ratio_; - float crop_ratio_lb_; - float crop_ratio_ub_; - std::mt19937 rnd_; - uint32_t seed_; -}; -} // namespace dataset -} // namespace mindspore - -#endif // DATASET_KERNELS_IMAGE_DISTORT_BOUNDING_BOX_CROP_OP_H_ diff --git a/mindspore/ccsrc/dataset/kernels/image/image_utils.cc b/mindspore/ccsrc/dataset/kernels/image/image_utils.cc index 63c9bb2641..e4570b876d 100644 --- a/mindspore/ccsrc/dataset/kernels/image/image_utils.cc +++ b/mindspore/ccsrc/dataset/kernels/image/image_utils.cc @@ -636,76 +636,10 @@ Status AdjustHue(const std::shared_ptr &input, std::shared_ptr * return Status::OK(); } -Status GenerateRandomCropBox(int input_height, int input_width, float ratio, float lb, float ub, int max_itr, - cv::Rect *crop_box, uint32_t seed) { - try { - std::mt19937 rnd; - rnd.seed(GetSeed()); - if (input_height <= 0 || input_width <= 0 || ratio <= 0.0 || lb <= 0.0 || lb > ub) { - RETURN_STATUS_UNEXPECTED("Invalid inputs GenerateRandomCropBox"); - } - std::uniform_real_distribution rd_crop_ratio(lb, ub); - float crop_ratio; - int crop_width, crop_height; - bool crop_success = false; - int64_t input_area = input_height * input_width; - for (auto i = 0; i < max_itr; i++) { - crop_ratio = rd_crop_ratio(rnd); - crop_width = static_cast(std::round(std::sqrt(input_area * static_cast(crop_ratio) / ratio))); - crop_height = static_cast(std::round(crop_width * ratio)); - if (crop_width <= input_width && crop_height <= input_height) { - crop_success = true; - break; - } - } - if (crop_success == false) { - ratio = static_cast(input_height) / input_width; - crop_ratio = rd_crop_ratio(rnd); - crop_width = static_cast(std::lround(std::sqrt(input_area * static_cast(crop_ratio) / ratio))); - crop_height = static_cast(std::lround(crop_width * ratio)); - crop_height = (crop_height > input_height) ? input_height : crop_height; - crop_width = (crop_width > input_width) ? input_width : crop_width; - } - std::uniform_int_distribution<> rd_x(0, input_width - crop_width); - std::uniform_int_distribution<> rd_y(0, input_height - crop_height); - *crop_box = cv::Rect(rd_x(rnd), rd_y(rnd), crop_width, crop_height); - return Status::OK(); - } catch (const cv::Exception &e) { - RETURN_STATUS_UNEXPECTED("error in GenerateRandomCropBox."); - } -} - -Status CheckOverlapConstraint(const cv::Rect &crop_box, const std::vector &bounding_boxes, - float min_intersect_ratio, bool *is_satisfied) { - try { - // not satisfied if the crop box contains no pixel - if (crop_box.area() < 1.0) { - *is_satisfied = false; - } - for (const auto &b_box : bounding_boxes) { - const float b_box_area = b_box.area(); - // not satisfied if the bounding box contains no pixel - if (b_box_area < 1.0) { - continue; - } - const float intersect_ratio = (crop_box & b_box).area() / b_box_area; - if (intersect_ratio >= min_intersect_ratio) { - *is_satisfied = true; - break; - } - } - return Status::OK(); - } catch (const cv::Exception &e) { - RETURN_STATUS_UNEXPECTED("error in CheckOverlapConstraint."); - } -} - Status Erase(const std::shared_ptr &input, std::shared_ptr *output, int32_t box_height, - int32_t box_width, int32_t num_patches, bool bounded, bool random_color, uint8_t fill_r, uint8_t fill_g, - uint8_t fill_b) { + int32_t box_width, int32_t num_patches, bool bounded, bool random_color, std::mt19937 *rnd, uint8_t fill_r, + uint8_t fill_g, uint8_t fill_b) { try { - std::mt19937 rnd; - rnd.seed(GetSeed()); std::shared_ptr input_cv = CVTensor::AsCVTensor(input); if (input_cv->mat().data == nullptr || (input_cv->Rank() != 3 && input_cv->shape()[2] != 3)) { RETURN_STATUS_UNEXPECTED("bad CV Tensor input for erase"); @@ -731,8 +665,8 @@ Status Erase(const std::shared_ptr &input, std::shared_ptr *outp // rows in cv mat refers to the height of the cropped box // we determine h_start and w_start using two different distributions as erasing is used by two different // image augmentations. The bounds are also different in each case. - int32_t h_start = (bounded) ? height_distribution_bound(rnd) : (height_distribution_unbound(rnd) - box_height); - int32_t w_start = (bounded) ? width_distribution_bound(rnd) : (width_distribution_unbound(rnd) - box_width); + int32_t h_start = (bounded) ? height_distribution_bound(*rnd) : (height_distribution_unbound(*rnd) - box_height); + int32_t w_start = (bounded) ? width_distribution_bound(*rnd) : (width_distribution_unbound(*rnd) - box_width); int32_t max_width = (w_start + box_width > image_w) ? image_w : w_start + box_width; int32_t max_height = (h_start + box_height > image_h) ? image_h : h_start + box_height; @@ -744,9 +678,9 @@ Status Erase(const std::shared_ptr &input, std::shared_ptr *outp for (int x = h_start; x < max_height; x++) { if (random_color) { // fill each box with a random value - input_img.at(cv::Point(y, x))[0] = static_cast(normal_distribution(rnd)); - input_img.at(cv::Point(y, x))[1] = static_cast(normal_distribution(rnd)); - input_img.at(cv::Point(y, x))[2] = static_cast(normal_distribution(rnd)); + input_img.at(cv::Point(y, x))[0] = static_cast(normal_distribution(*rnd)); + input_img.at(cv::Point(y, x))[1] = static_cast(normal_distribution(*rnd)); + input_img.at(cv::Point(y, x))[2] = static_cast(normal_distribution(*rnd)); } else { input_img.at(cv::Point(y, x))[0] = fill_r; input_img.at(cv::Point(y, x))[1] = fill_g; diff --git a/mindspore/ccsrc/dataset/kernels/image/image_utils.h b/mindspore/ccsrc/dataset/kernels/image/image_utils.h index 51090fb9ea..394323974a 100644 --- a/mindspore/ccsrc/dataset/kernels/image/image_utils.h +++ b/mindspore/ccsrc/dataset/kernels/image/image_utils.h @@ -196,12 +196,6 @@ Status AdjustSaturation(const std::shared_ptr &input, std::shared_ptr &input, std::shared_ptr *output, const float &hue); -Status GenerateRandomCropBox(int input_height, int input_width, float ratio, float lb, float ub, int max_itr, - cv::Rect *crop_box, uint32_t seed = std::mt19937::default_seed); - -Status CheckOverlapConstraint(const cv::Rect &crop_box, const std::vector &bounding_boxes, - float min_intersect_ratio, bool *is_satisfied); - // Masks out a random section from the image with set dimension // @param input: input Tensor // @param output: cutOut Tensor @@ -214,8 +208,8 @@ Status CheckOverlapConstraint(const cv::Rect &crop_box, const std::vector &input, std::shared_ptr *output, int32_t box_height, - int32_t box_width, int32_t num_patches, bool bounded, bool random_color, uint8_t fill_r = 0, - uint8_t fill_g = 0, uint8_t fill_b = 0); + int32_t box_width, int32_t num_patches, bool bounded, bool random_color, std::mt19937 *rnd, + uint8_t fill_r = 0, uint8_t fill_g = 0, uint8_t fill_b = 0); // Pads the input image and puts the padded image in the output // @param input: input Tensor From 6d3709ebd10c96195917d5c8e343747c2cc859f1 Mon Sep 17 00:00:00 2001 From: zhaojichen Date: Wed, 22 Apr 2020 21:00:32 -0400 Subject: [PATCH 117/142] fix batchnorm bug --- mindspore/nn/layer/normalization.py | 66 ++++++++++++++++++++++++++++- 1 file changed, 64 insertions(+), 2 deletions(-) diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py index 6456a3603d..07d08f9b2f 100644 --- a/mindspore/nn/layer/normalization.py +++ b/mindspore/nn/layer/normalization.py @@ -33,7 +33,6 @@ class _BatchNorm(Cell): @cell_attr_register def __init__(self, num_features, - group=1, eps=1e-5, momentum=0.9, affine=True, @@ -41,7 +40,8 @@ class _BatchNorm(Cell): beta_init='zeros', moving_mean_init='zeros', moving_var_init='ones', - use_batch_statistics=True): + use_batch_statistics=True, + group=1): super(_BatchNorm, self).__init__() if num_features < 1: raise ValueError("num_features must be at least 1") @@ -214,6 +214,25 @@ class BatchNorm1d(_BatchNorm): >>> input = Tensor(np.random.randint(0, 255, [3, 16]), mindspore.float32) >>> net(input) """ + def __init__(self, + num_features, + eps=1e-5, + momentum=0.9, + affine=True, + gamma_init='ones', + beta_init='zeros', + moving_mean_init='zeros', + moving_var_init='ones', + use_batch_statistics=True): + super(BatchNorm1d, self).__init__(num_features, + eps, + momentum, + affine, + gamma_init, + beta_init, + moving_mean_init, + moving_var_init, + use_batch_statistics) def _check_data_dim(self, x): if x.dim() != 2: pass @@ -266,6 +285,25 @@ class BatchNorm2d(_BatchNorm): >>> input = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32) >>> net(input) """ + def __init__(self, + num_features, + eps=1e-5, + momentum=0.9, + affine=True, + gamma_init='ones', + beta_init='zeros', + moving_mean_init='zeros', + moving_var_init='ones', + use_batch_statistics=True): + super(BatchNorm2d, self).__init__(num_features, + eps, + momentum, + affine, + gamma_init, + beta_init, + moving_mean_init, + moving_var_init, + use_batch_statistics) def _check_data_dim(self, x): if x.dim() != 4: pass @@ -316,6 +354,30 @@ class GlobalBatchNorm(_BatchNorm): >>> input = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32) >>> global_bn_op(input) """ + def __init__(self, + num_features, + eps=1e-5, + momentum=0.9, + affine=True, + gamma_init='ones', + beta_init='zeros', + moving_mean_init='zeros', + moving_var_init='ones', + use_batch_statistics=True, + group=1): + super(GlobalBatchNorm, self).__init__(num_features, + eps, + momentum, + affine, + gamma_init, + beta_init, + moving_mean_init, + moving_var_init, + use_batch_statistics, + group) + self.group = check_int_positive(group) + if self.group <=1: + raise ValueError("the number of group must be greater than 1.") def _check_data_dim(self, x): if x.dim == 0: pass From a974e65bb6de05c94b35976760516ab0ea088f62 Mon Sep 17 00:00:00 2001 From: zhaojichen Date: Wed, 22 Apr 2020 21:06:14 -0400 Subject: [PATCH 118/142] fix batchnorm bug --- mindspore/nn/layer/normalization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py index 07d08f9b2f..a25f640ccc 100644 --- a/mindspore/nn/layer/normalization.py +++ b/mindspore/nn/layer/normalization.py @@ -376,7 +376,7 @@ class GlobalBatchNorm(_BatchNorm): use_batch_statistics, group) self.group = check_int_positive(group) - if self.group <=1: + if self.group <= 1: raise ValueError("the number of group must be greater than 1.") def _check_data_dim(self, x): if x.dim == 0: From 68b45246512b529c455a1eb37f48c1a1e3e57cea Mon Sep 17 00:00:00 2001 From: zhaojichen Date: Wed, 22 Apr 2020 21:34:04 -0400 Subject: [PATCH 119/142] fix batchnorm bug --- mindspore/nn/layer/normalization.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py index a25f640ccc..b9e9d6ebb7 100644 --- a/mindspore/nn/layer/normalization.py +++ b/mindspore/nn/layer/normalization.py @@ -366,15 +366,15 @@ class GlobalBatchNorm(_BatchNorm): use_batch_statistics=True, group=1): super(GlobalBatchNorm, self).__init__(num_features, - eps, - momentum, - affine, - gamma_init, - beta_init, - moving_mean_init, - moving_var_init, - use_batch_statistics, - group) + eps, + momentum, + affine, + gamma_init, + beta_init, + moving_mean_init, + moving_var_init, + use_batch_statistics, + group) self.group = check_int_positive(group) if self.group <= 1: raise ValueError("the number of group must be greater than 1.") From 9d6ff7ffd04777d1bf30f8b2ee23a6b54b3f2db3 Mon Sep 17 00:00:00 2001 From: zhaojichen Date: Wed, 22 Apr 2020 21:41:15 -0400 Subject: [PATCH 120/142] fix batchnorm bug --- mindspore/nn/layer/normalization.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/mindspore/nn/layer/normalization.py b/mindspore/nn/layer/normalization.py index b9e9d6ebb7..3ef2381ba1 100644 --- a/mindspore/nn/layer/normalization.py +++ b/mindspore/nn/layer/normalization.py @@ -366,15 +366,15 @@ class GlobalBatchNorm(_BatchNorm): use_batch_statistics=True, group=1): super(GlobalBatchNorm, self).__init__(num_features, - eps, - momentum, - affine, - gamma_init, - beta_init, - moving_mean_init, - moving_var_init, - use_batch_statistics, - group) + eps, + momentum, + affine, + gamma_init, + beta_init, + moving_mean_init, + moving_var_init, + use_batch_statistics, + group) self.group = check_int_positive(group) if self.group <= 1: raise ValueError("the number of group must be greater than 1.") From 672244e0ac33648e997ab16d1a7c26b120c15b55 Mon Sep 17 00:00:00 2001 From: liubuyu Date: Wed, 22 Apr 2020 14:43:19 +0800 Subject: [PATCH 121/142] add keep_bn_fp32 parameter --- .../pre_activate/ascend/ir_fusion/mul_addn_fusion.cc | 6 +++--- mindspore/nn/optim/optimizer.py | 2 +- mindspore/train/model.py | 12 +++++++++--- .../gtest_input/pre_activate/mul_addn_fusion_test.py | 2 +- 4 files changed, 14 insertions(+), 8 deletions(-) diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/mul_addn_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/mul_addn_fusion.cc index 83c58ab547..a5e4675c8f 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fusion/mul_addn_fusion.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fusion/mul_addn_fusion.cc @@ -34,7 +34,7 @@ CNodePtr CreateFusionNode(const FuncGraphPtr &graph, const CNodePtr &mul, const auto prim = std::make_shared(kFusedMulAddNOpName); std::vector inputs = {NewValueNode(prim)}; inputs.push_back(mul->input(kMulInputNum - lossscale_input_index)); - inputs.push_back(addn->input(1)); + inputs.push_back(addn->input(2)); // scalar input should be 3rd input inputs.push_back(mul->input(lossscale_input_index)); auto fusion_node = graph->NewCNode(inputs); @@ -51,7 +51,7 @@ const BaseRef MulAddNFusion::DefinePattern() const { VarPtr Z = std::make_shared(); VectorRef mul({prim::kPrimMul, X, Z}); - VectorRef addn({prim::kPrimAddN, Y, mul}); + VectorRef addn({prim::kPrimAddN, mul, Y}); return addn; } @@ -65,7 +65,7 @@ const AnfNodePtr MulAddNFusion::Process(const FuncGraphPtr &graph, const AnfNode if (addn == nullptr || addn->inputs().size() != kAddNInputNum) { return nullptr; } - auto mul_anf = addn->input(2); + auto mul_anf = addn->input(1); if (mul_anf == nullptr) { return nullptr; } diff --git a/mindspore/nn/optim/optimizer.py b/mindspore/nn/optim/optimizer.py index 72593e8001..bab539461e 100755 --- a/mindspore/nn/optim/optimizer.py +++ b/mindspore/nn/optim/optimizer.py @@ -177,7 +177,7 @@ apply_decay = C.MultitypeFuncGraph("apply_decay") def _tensor_apply_decay(weight_decay, if_apply, weight, gradient): """Get grad with weight_decay.""" if if_apply: - return op_add((gradient, weight * weight_decay)) + return op_add((weight * weight_decay, gradient)) return gradient diff --git a/mindspore/train/model.py b/mindspore/train/model.py index f4d1a324d1..698105889a 100755 --- a/mindspore/train/model.py +++ b/mindspore/train/model.py @@ -62,6 +62,7 @@ class Model: loss_scale_manager (Union[None, LossScaleManager]): If None, not scale the loss, or else scale the loss by LossScaleManager. If it is set, overwrite the level setting. It's a eyword argument. e.g. Use `loss_scale_manager=None` to set the value. + keep_batchnorm_fp32 (bool): Keep Batchnorm run in `float32`. If set, overwrite the level setting. Default: True. Examples: >>> class Net(nn.Cell): @@ -96,7 +97,10 @@ class Model: self._optimizer = optimizer self._loss_scale_manager = None self._loss_scale_manager_set = False + self._keep_bn_fp32 = True self._check_kwargs(kwargs) + if 'keep_batchnorm_fp32' in kwargs: + self._keep_bn_fp32 = kwargs['keep_batchnorm_fp32'] if 'loss_scale_manager' in kwargs: self._loss_scale_manager = kwargs['loss_scale_manager'] self._loss_scale_manager_set = True @@ -112,7 +116,7 @@ class Model: def _check_kwargs(self, kwargs): for arg in kwargs: - if arg not in ['loss_scale_manager']: + if arg not in ['loss_scale_manager', 'keep_batchnorm_fp32']: raise ValueError(f"Unsupport arg '{arg}'") def _build_train_network(self): @@ -124,12 +128,14 @@ class Model: self._optimizer, self._loss_fn, level=self._amp_level, - loss_scale_manager=self._loss_scale_manager) + loss_scale_manager=self._loss_scale_manager, + keep_batchnorm_fp32=self._keep_bn_fp32) else: network = amp.build_train_network(network, self._optimizer, self._loss_fn, - level=self._amp_level) + level=self._amp_level, + keep_batchnorm_fp32=self._keep_bn_fp32) elif self._loss_fn: network = nn.WithLossCell(network, self._loss_fn) # If need to check if loss_fn is not None, but optimizer is None diff --git a/tests/ut/cpp/python_input/gtest_input/pre_activate/mul_addn_fusion_test.py b/tests/ut/cpp/python_input/gtest_input/pre_activate/mul_addn_fusion_test.py index e5b0a15387..8ce64109c6 100644 --- a/tests/ut/cpp/python_input/gtest_input/pre_activate/mul_addn_fusion_test.py +++ b/tests/ut/cpp/python_input/gtest_input/pre_activate/mul_addn_fusion_test.py @@ -42,7 +42,7 @@ def test_mul_addn_fusion(tag): @fns def before(a, b): res = mul(scalar, a) - res = addn((b, res)) + res = addn((res, b)) return res @fns From cc7b05e3ce1e49037d1e1115670b3b2079a57c26 Mon Sep 17 00:00:00 2001 From: VectorSL Date: Thu, 23 Apr 2020 09:54:48 +0800 Subject: [PATCH 122/142] fix cudnn type error --- mindspore/ccsrc/kernel/gpu/math/tensoradd_gpu_kernel.cc | 3 +++ mindspore/ccsrc/kernel/gpu/math/tensoradd_gpu_kernel.h | 3 +++ mindspore/ccsrc/kernel/gpu/nn/bias_add_grad_gpu_kenel.h | 2 +- mindspore/train/amp.py | 2 +- 4 files changed, 8 insertions(+), 2 deletions(-) diff --git a/mindspore/ccsrc/kernel/gpu/math/tensoradd_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/math/tensoradd_gpu_kernel.cc index 1b7318c511..69716e9165 100644 --- a/mindspore/ccsrc/kernel/gpu/math/tensoradd_gpu_kernel.cc +++ b/mindspore/ccsrc/kernel/gpu/math/tensoradd_gpu_kernel.cc @@ -26,5 +26,8 @@ MS_REG_GPU_KERNEL_ONE( TensorAdd, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), TensorAddGpuFwdKernel, half) +MS_REG_GPU_KERNEL_ONE( + TensorAdd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + TensorAddGpuFwdKernel, int) } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/math/tensoradd_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/math/tensoradd_gpu_kernel.h index a203567aa8..4dfbf4c3d4 100644 --- a/mindspore/ccsrc/kernel/gpu/math/tensoradd_gpu_kernel.h +++ b/mindspore/ccsrc/kernel/gpu/math/tensoradd_gpu_kernel.h @@ -71,6 +71,9 @@ class TensorAddGpuFwdKernel : public GpuKernel { bool Init(const CNodePtr &kernel_node) { InitResource(); cudnn_data_type_ = kCudnnDtypeMap[TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))]; + if (cudnn_data_type_ == CUDNN_DATA_INT32) { + cudnn_data_type_ = CUDNN_DATA_FLOAT; + } size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); if (input_num != 2) { MS_LOG(ERROR) << "Input number is " << input_num << ", but cudnnAddTensor needs 2 inputs."; diff --git a/mindspore/ccsrc/kernel/gpu/nn/bias_add_grad_gpu_kenel.h b/mindspore/ccsrc/kernel/gpu/nn/bias_add_grad_gpu_kenel.h index fd73f378d8..5c7153a172 100644 --- a/mindspore/ccsrc/kernel/gpu/nn/bias_add_grad_gpu_kenel.h +++ b/mindspore/ccsrc/kernel/gpu/nn/bias_add_grad_gpu_kenel.h @@ -101,7 +101,7 @@ class BiasAddGradGpuKernel : public GpuKernel { cudnnSetTensorNdDescriptorEx(db_desc_, CUDNN_TENSOR_NCHW, cudnn_data_type_, SizeToInt(cudnn_dims), db_dims.get()), "cudnnSetTensorNdDescriptor failed"); CHECK_CUDNN_RET_WITH_EXCEPT( - cudnnSetReduceTensorDescriptor(op_desc_, CUDNN_REDUCE_TENSOR_ADD, cudnn_data_type_, CUDNN_NOT_PROPAGATE_NAN, + cudnnSetReduceTensorDescriptor(op_desc_, CUDNN_REDUCE_TENSOR_ADD, CUDNN_DATA_FLOAT, CUDNN_NOT_PROPAGATE_NAN, CUDNN_REDUCE_TENSOR_NO_INDICES, CUDNN_32BIT_INDICES), "cudnnSetReduceTensorDescriptor failed"); diff --git a/mindspore/train/amp.py b/mindspore/train/amp.py index 66e08874b2..917b4c3359 100644 --- a/mindspore/train/amp.py +++ b/mindspore/train/amp.py @@ -151,7 +151,7 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs): loss_scale = loss_scale_manager.get_loss_scale() update_cell = loss_scale_manager.get_update_cell() if update_cell is not None: - if not context.get_context("enable_ge"): + if not (context.get_context("enable_ge") or (context.get_context("device_target") == "GPU")): raise ValueError("Only `loss_scale_manager=None` and " "`loss_scale_manager=FixedLossScaleManager(drop_overflow_update=False)`" "are supported in current version. If you use `O2` option, please" From bd13f9ba339e2cdffdaf9a2d192a6aff5ca92e85 Mon Sep 17 00:00:00 2001 From: chang zherui <760161589@qq.com> Date: Thu, 23 Apr 2020 15:17:06 +0800 Subject: [PATCH 123/142] modify ResizeNearestNeighborV2D --- .../gpu/nn/fused_batch_norm_gpu_kernel.cc | 2 - .../gpu/nn/fused_batch_norm_gpu_kernel.h | 3 - mindspore/ccsrc/transform/op_declare.cc | 59 +++++++++---------- mindspore/ops/_grad/grad_nn_ops.py | 4 +- mindspore/ops/operations/_grad_ops.py | 4 +- mindspore/ops/operations/nn_ops.py | 8 +-- tests/ut/python/ops/test_ops.py | 2 +- 7 files changed, 35 insertions(+), 47 deletions(-) diff --git a/mindspore/ccsrc/kernel/gpu/nn/fused_batch_norm_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/fused_batch_norm_gpu_kernel.cc index 4ddc710a4c..91747d24d8 100644 --- a/mindspore/ccsrc/kernel/gpu/nn/fused_batch_norm_gpu_kernel.cc +++ b/mindspore/ccsrc/kernel/gpu/nn/fused_batch_norm_gpu_kernel.cc @@ -55,7 +55,6 @@ MS_REG_GPU_KERNEL_ONE(BatchNorm, .AddOutputAttr(kNumberTypeFloat32) .AddOutputAttr(kNumberTypeFloat32) .AddOutputAttr(kNumberTypeFloat32) - .AddOutputAttr(kNumberTypeFloat32) .AddOutputAttr(kNumberTypeFloat32), FusedBatchNormGpuKernel, float) MS_REG_GPU_KERNEL_ONE(BatchNorm, @@ -69,7 +68,6 @@ MS_REG_GPU_KERNEL_ONE(BatchNorm, .AddOutputAttr(kNumberTypeFloat16) .AddOutputAttr(kNumberTypeFloat16) .AddOutputAttr(kNumberTypeFloat16) - .AddOutputAttr(kNumberTypeFloat16) .AddOutputAttr(kNumberTypeFloat16), FusedBatchNormGpuKernel, half) } // namespace kernel diff --git a/mindspore/ccsrc/kernel/gpu/nn/fused_batch_norm_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/fused_batch_norm_gpu_kernel.h index 6f0c59e29a..5ca85f8e63 100644 --- a/mindspore/ccsrc/kernel/gpu/nn/fused_batch_norm_gpu_kernel.h +++ b/mindspore/ccsrc/kernel/gpu/nn/fused_batch_norm_gpu_kernel.h @@ -156,9 +156,6 @@ class FusedBatchNormGpuKernel : public GpuKernel { output_size_list_.push_back(para_size); // running variance output_size_list_.push_back(para_size); // save mean output_size_list_.push_back(para_size); // save variance - if (!is_train_) { - output_size_list_.push_back(para_size); // reserve - } return; } diff --git a/mindspore/ccsrc/transform/op_declare.cc b/mindspore/ccsrc/transform/op_declare.cc index d6ca3f4cbe..299ac4f44d 100644 --- a/mindspore/ccsrc/transform/op_declare.cc +++ b/mindspore/ccsrc/transform/op_declare.cc @@ -154,14 +154,14 @@ ATTR_MAP(BatchNorm) = {{"data_format", ATTR_DESC(data_format, AnyTraits())}, {"epsilon", ATTR_DESC(epsilon, AnyTraits())}, {"is_training", ATTR_DESC(is_training, AnyTraits())}}; @@ -266,11 +266,6 @@ INPUT_MAP(GatherV2) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(indices)}, {3, INPUT_D ATTR_MAP(GatherV2) = EMPTY_ATTR_MAP; OUTPUT_MAP(GatherV2) = {{0, OUTPUT_DESC(y)}}; -// ReduceSum -INPUT_MAP(ReduceSum) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(axes)}}; -ATTR_MAP(ReduceSum) = {{"keep_dims", ATTR_DESC(keep_dims, AnyTraits())}}; -OUTPUT_MAP(ReduceSum) = {{0, OUTPUT_DESC(y)}}; - // ReduceSumD INPUT_MAP(ReduceSumD) = {{1, INPUT_DESC(x)}}; INPUT_ATTR_MAP(ReduceSumD) = { @@ -451,17 +446,17 @@ INPUT_MAP(Iou) = {{1, INPUT_DESC(bboxes)}, {2, INPUT_DESC(gtboxes)}}; ATTR_MAP(Iou) = {{"mode", ATTR_DESC(mode, AnyTraits())}}; OUTPUT_MAP(Iou) = {{0, OUTPUT_DESC(overlap)}}; -// ResizeNearestNeighborD -INPUT_MAP(ResizeNearestNeighborD) = {{1, INPUT_DESC(x)}}; -ATTR_MAP(ResizeNearestNeighborD) = { +// ResizeNearestNeighborV2D +INPUT_MAP(ResizeNearestNeighborV2D) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(ResizeNearestNeighborV2D) = { {"size", ATTR_DESC(size, AnyTraits>(), AnyTraits>())}, {"align_corners", ATTR_DESC(align_corners, AnyTraits())}}; -OUTPUT_MAP(ResizeNearestNeighborD) = {{0, OUTPUT_DESC(y)}}; +OUTPUT_MAP(ResizeNearestNeighborV2D) = {{0, OUTPUT_DESC(y)}}; -// ResizeNearestNeighborGrad -INPUT_MAP(ResizeNearestNeighborGrad) = {{1, INPUT_DESC(grads)}, {2, INPUT_DESC(size)}}; -ATTR_MAP(ResizeNearestNeighborGrad) = {{"align_corners", ATTR_DESC(align_corners, AnyTraits())}}; -OUTPUT_MAP(ResizeNearestNeighborGrad) = {{0, OUTPUT_DESC(y)}}; +// ResizeNearestNeighborV2Grad +INPUT_MAP(ResizeNearestNeighborV2Grad) = {{1, INPUT_DESC(grads)}, {2, INPUT_DESC(size)}}; +ATTR_MAP(ResizeNearestNeighborV2Grad) = {{"align_corners", ATTR_DESC(align_corners, AnyTraits())}}; +OUTPUT_MAP(ResizeNearestNeighborV2Grad) = {{0, OUTPUT_DESC(y)}}; // ApplyAdam INPUT_MAP(ApplyAdam) = {{1, INPUT_DESC(var)}, {2, INPUT_DESC(m)}, {3, INPUT_DESC(v)}, @@ -486,17 +481,17 @@ INPUT_MAP(Relu6Grad) = {{1, INPUT_DESC(gradients)}, {2, INPUT_DESC(features)}}; ATTR_MAP(Relu6Grad) = EMPTY_ATTR_MAP; OUTPUT_MAP(Relu6Grad) = {{0, OUTPUT_DESC(backprops)}}; -// ResizeBilinearGrad -INPUT_MAP(ResizeBilinearGrad) = {{1, INPUT_DESC(grads)}, {2, INPUT_DESC(original_image)}}; -ATTR_MAP(ResizeBilinearGrad) = {{"align_corners", ATTR_DESC(align_corners, AnyTraits())}}; -OUTPUT_MAP(ResizeBilinearGrad) = {{0, OUTPUT_DESC(y)}}; +// ResizeBilinearV2Grad +INPUT_MAP(ResizeBilinearV2Grad) = {{1, INPUT_DESC(grads)}, {2, INPUT_DESC(original_image)}}; +ATTR_MAP(ResizeBilinearV2Grad) = {{"align_corners", ATTR_DESC(align_corners, AnyTraits())}}; +OUTPUT_MAP(ResizeBilinearV2Grad) = {{0, OUTPUT_DESC(y)}}; -// ResizeBilinearD -INPUT_MAP(ResizeBilinearD) = {{1, INPUT_DESC(x)}}; -ATTR_MAP(ResizeBilinearD) = { +// ResizeBilinearV2D +INPUT_MAP(ResizeBilinearV2D) = {{1, INPUT_DESC(x)}}; +ATTR_MAP(ResizeBilinearV2D) = { {"size", ATTR_DESC(size, AnyTraits>(), AnyTraits>())}, {"align_corners", ATTR_DESC(align_corners, AnyTraits())}}; -OUTPUT_MAP(ResizeBilinearD) = {{0, OUTPUT_DESC(y)}}; +OUTPUT_MAP(ResizeBilinearV2D) = {{0, OUTPUT_DESC(y)}}; // ZerosLike INPUT_MAP(ZerosLike) = {{1, INPUT_DESC(x)}}; @@ -609,10 +604,12 @@ ATTR_MAP(ArgMinWithValue) = {{"axis", ATTR_DESC(dimension, AnyTraits())}, {"keep_dims", ATTR_DESC(keep_dims, AnyTraits())}}; OUTPUT_MAP(ArgMinWithValue) = {{0, OUTPUT_DESC(indice)}, {1, OUTPUT_DESC(values)}}; -// ReduceAll -INPUT_MAP(ReduceAll) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(axes)}}; -ATTR_MAP(ReduceAll) = {{"keep_dims", ATTR_DESC(keep_dims, AnyTraits())}}; -OUTPUT_MAP(ReduceAll) = {{0, OUTPUT_DESC(y)}} +// ReduceAllD +INPUT_MAP(ReduceAllD) = {{1, INPUT_DESC(x)}}; +INPUT_ATTR_MAP(ReduceAllD) = { + {2, ATTR_DESC(axis, AnyTraits>(), AnyTraits>())}}; +ATTR_MAP(ReduceAllD) = {{"keep_dims", ATTR_DESC(keep_dims, AnyTraits())}}; +OUTPUT_MAP(ReduceAllD) = {{0, OUTPUT_DESC(y)}}; // ReduceMeanD INPUT_MAP(ReduceMeanD) = {{1, INPUT_DESC(x)}}; diff --git a/mindspore/ops/_grad/grad_nn_ops.py b/mindspore/ops/_grad/grad_nn_ops.py index e43d3d5d3a..6db059a7bb 100755 --- a/mindspore/ops/_grad/grad_nn_ops.py +++ b/mindspore/ops/_grad/grad_nn_ops.py @@ -356,12 +356,10 @@ def get_bprop_batch_norm(self): if is_training: saved_reserve_1 = out[3] saved_reserve_2 = out[4] - saved_reserve_3 = out[5] else: saved_reserve_1 = mean saved_reserve_2 = variance - saved_reserve_3 = variance - out = input_grad(dout[0], x, scale, saved_reserve_1, saved_reserve_2, saved_reserve_3) + out = input_grad(dout[0], x, scale, saved_reserve_1, saved_reserve_2) dx = out[0] dscale = out[1] dbias = out[2] diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py index c29832dcb7..9f277908ed 100644 --- a/mindspore/ops/operations/_grad_ops.py +++ b/mindspore/ops/operations/_grad_ops.py @@ -69,11 +69,11 @@ class BatchNormGrad(PrimitiveWithInfer): self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT) self.add_prim_attr('data_format', "NCHW") - def infer_shape(self, y_backprop_shape, x_shape, scale_shape, reserve_1_shape, reserve_2_shape, reserve_3_shape): + def infer_shape(self, y_backprop_shape, x_shape, scale_shape, reserve_1_shape, reserve_2_shape): validator.check("BatchNorm y_backprop_shape", y_backprop_shape, "BatchNorm x_shape", x_shape) return (x_shape, scale_shape, scale_shape, reserve_1_shape, reserve_2_shape) - def infer_dtype(self, y_backprop_type, x_type, scale_type, reserve_1_type, reserve_2_type, reserve_3_type): + def infer_dtype(self, y_backprop_type, x_type, scale_type, reserve_1_type, reserve_2_type): return (x_type, scale_type, scale_type, reserve_1_type, reserve_2_type) diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 49145fb072..93359c7dd9 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -537,7 +537,6 @@ class BatchNorm(PrimitiveWithInfer): - **updated_bias** (Tensor) - Tensor of shape :math:`(C,)`. - **reserve_space_1** (Tensor) - Tensor of shape :math:`(C,)`. - **reserve_space_2** (Tensor) - Tensor of shape :math:`(C,)`. - - **reserve_space_3** (Tensor) - Tensor of shape :math:`(C,)`. """ @prim_attr_register @@ -546,8 +545,7 @@ class BatchNorm(PrimitiveWithInfer): validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, self.name) self.add_prim_attr('data_format', "NCHW") self.init_prim_io_names(inputs=['x', 'scale', 'offset', 'mean', 'variance'], - outputs=['y', 'batch_mean', 'batch_variance', 'reserve_space_1', 'reserve_space_2', - 'reserve_space_3']) + outputs=['y', 'batch_mean', 'batch_variance', 'reserve_space_1', 'reserve_space_2']) def infer_shape(self, input_x, scale, bias, mean, variance): validator.check_integer("scale rank", len(scale), 1, Rel.EQ, self.name) @@ -557,7 +555,7 @@ class BatchNorm(PrimitiveWithInfer): validator.check_integer("mean rank", len(mean), 1, Rel.EQ, self.name) validator.check("mean shape", mean, "variance shape", variance, Rel.EQ, self.name) validator.check("mean shape", mean, "scale shape", scale, Rel.EQ, self.name) - return (input_x, scale, scale, scale, scale, scale) + return (input_x, scale, scale, scale, scale) def infer_dtype(self, input_x, scale, bias, mean, variance): validator.check_tensor_type_same({"input_x": input_x}, [mstype.float16, mstype.float32], self.name) @@ -570,7 +568,7 @@ class BatchNorm(PrimitiveWithInfer): else: args_moving = {"mean": mean, "variance": variance} validator.check_tensor_type_same(args_moving, [mstype.float16, mstype.float32], self.name) - return (input_x, scale, bias, input_x, input_x, input_x) + return (input_x, scale, bias, input_x, input_x) class Conv2D(PrimitiveWithInfer): diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py index 8b14ea2366..1dea7b6502 100755 --- a/tests/ut/python/ops/test_ops.py +++ b/tests/ut/python/ops/test_ops.py @@ -671,7 +671,7 @@ test_case_nn_ops = [ 'skip': []}), ('BatchNormGrad', { 'block': G.BatchNormGrad(), - 'desc_inputs': [[128, 64, 32, 32], [128, 64, 32, 32], [64], [64], [64], [64]], + 'desc_inputs': [[128, 64, 32, 32], [128, 64, 32, 32], [64], [64], [64]], 'desc_bprop': [[128, 64, 32, 32], [64], [64], [64], [64]], 'skip': ['backward']}), ('ApplyMomentum', { From 0949ea19afb02ecfa986b0cc29d950141bf7dd28 Mon Sep 17 00:00:00 2001 From: chang zherui <760161589@qq.com> Date: Thu, 23 Apr 2020 19:58:32 +0800 Subject: [PATCH 124/142] modify ReduceAllD --- mindspore/ccsrc/transform/op_declare.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mindspore/ccsrc/transform/op_declare.cc b/mindspore/ccsrc/transform/op_declare.cc index 299ac4f44d..718751e9f5 100644 --- a/mindspore/ccsrc/transform/op_declare.cc +++ b/mindspore/ccsrc/transform/op_declare.cc @@ -154,6 +154,7 @@ ATTR_MAP(BatchNorm) = {{"data_format", ATTR_DESC(data_format, AnyTraits>(), AnyTraits>())}}; + {2, ATTR_DESC(axes, AnyTraits>(), AnyTraits>())}}; ATTR_MAP(ReduceAllD) = {{"keep_dims", ATTR_DESC(keep_dims, AnyTraits())}}; OUTPUT_MAP(ReduceAllD) = {{0, OUTPUT_DESC(y)}}; From c6f5efaab279e78d36c79e34738968f4d44442eb Mon Sep 17 00:00:00 2001 From: chang zherui <760161589@qq.com> Date: Thu, 23 Apr 2020 21:20:14 +0800 Subject: [PATCH 125/142] modify ge --- graphengine | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphengine b/graphengine index cfc99f95f7..43a715bc46 160000 --- a/graphengine +++ b/graphengine @@ -1 +1 @@ -Subproject commit cfc99f95f722918025b0eaeb93440d92265f09fe +Subproject commit 43a715bc461fd70b7837051a2f47f0a1b19c5859 From 1587eef5299ade3496aaaf79b64a2735af1da7b2 Mon Sep 17 00:00:00 2001 From: chang zherui <760161589@qq.com> Date: Thu, 23 Apr 2020 21:29:43 +0800 Subject: [PATCH 126/142] modify Conv2DBackpropFilterD pad --- mindspore/ccsrc/transform/op_declare.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mindspore/ccsrc/transform/op_declare.cc b/mindspore/ccsrc/transform/op_declare.cc index 718751e9f5..3e6b029f6b 100644 --- a/mindspore/ccsrc/transform/op_declare.cc +++ b/mindspore/ccsrc/transform/op_declare.cc @@ -742,7 +742,7 @@ INPUT_ATTR_MAP(Conv2DBackpropFilterD) = { {3, ATTR_DESC(filter_size, AnyTraits>(), AnyTraits>())}}; ATTR_MAP(Conv2DBackpropFilterD) = { {"pad_list", ATTR_DESC(pads, AnyTraits>(), AnyTraits>())}, - {"stride", ATTR_DESC(strides, AnyTraits>(), AnyTraits>())}, + {"stride", ATTR_DESC(strides, "pad", AnyTraits>())}, {"dilation", ATTR_DESC(dilations, AnyTraits>(), AnyTraits>())}, }; OUTPUT_MAP(Conv2DBackpropFilterD) = {{0, OUTPUT_DESC(y)}}; From 1c39c52f72f35b5ed16623b1ec75c6fcbadb2733 Mon Sep 17 00:00:00 2001 From: chang zherui <760161589@qq.com> Date: Thu, 23 Apr 2020 21:39:03 +0800 Subject: [PATCH 127/142] asd --- mindspore/ccsrc/transform/op_declare.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mindspore/ccsrc/transform/op_declare.cc b/mindspore/ccsrc/transform/op_declare.cc index 3e6b029f6b..718751e9f5 100644 --- a/mindspore/ccsrc/transform/op_declare.cc +++ b/mindspore/ccsrc/transform/op_declare.cc @@ -742,7 +742,7 @@ INPUT_ATTR_MAP(Conv2DBackpropFilterD) = { {3, ATTR_DESC(filter_size, AnyTraits>(), AnyTraits>())}}; ATTR_MAP(Conv2DBackpropFilterD) = { {"pad_list", ATTR_DESC(pads, AnyTraits>(), AnyTraits>())}, - {"stride", ATTR_DESC(strides, "pad", AnyTraits>())}, + {"stride", ATTR_DESC(strides, AnyTraits>(), AnyTraits>())}, {"dilation", ATTR_DESC(dilations, AnyTraits>(), AnyTraits>())}, }; OUTPUT_MAP(Conv2DBackpropFilterD) = {{0, OUTPUT_DESC(y)}}; From ee9eb86901c3f0a163582ed0e63b30ac2baec1f5 Mon Sep 17 00:00:00 2001 From: zhaozhenlong Date: Fri, 24 Apr 2020 11:07:18 +0800 Subject: [PATCH 128/142] Conv2d and backprops --- mindspore/ccsrc/transform/op_declare.cc | 6 ++ mindspore/ops/operations/_grad_ops.py | 109 ++++++++++++++++------ mindspore/ops/operations/nn_ops.py | 119 +++++++++++++++++------- 3 files changed, 171 insertions(+), 63 deletions(-) diff --git a/mindspore/ccsrc/transform/op_declare.cc b/mindspore/ccsrc/transform/op_declare.cc index 718751e9f5..b8aec05c3b 100644 --- a/mindspore/ccsrc/transform/op_declare.cc +++ b/mindspore/ccsrc/transform/op_declare.cc @@ -722,6 +722,8 @@ ATTR_MAP(Conv2D) = { {"stride", ATTR_DESC(strides, AnyTraits>(), AnyTraits>())}, {"pad_list", ATTR_DESC(pads, AnyTraits>(), AnyTraits>())}, {"dilation", ATTR_DESC(dilations, AnyTraits>(), AnyTraits>())}, + {"data_format", ATTR_DESC(data_format, AnyTraits())}, + {"group", ATTR_DESC(groups, AnyTraits())}, }; OUTPUT_MAP(Conv2D) = {{0, OUTPUT_DESC(y)}}; @@ -733,6 +735,8 @@ ATTR_MAP(Conv2DBackpropInputD) = { {"pad_list", ATTR_DESC(pads, AnyTraits>(), AnyTraits>())}, {"stride", ATTR_DESC(strides, AnyTraits>(), AnyTraits>())}, {"dilation", ATTR_DESC(dilations, AnyTraits>(), AnyTraits>())}, + {"data_format", ATTR_DESC(data_format, AnyTraits())}, + {"group", ATTR_DESC(groups, AnyTraits())}, }; OUTPUT_MAP(Conv2DBackpropInputD) = {{0, OUTPUT_DESC(y)}}; @@ -744,6 +748,8 @@ ATTR_MAP(Conv2DBackpropFilterD) = { {"pad_list", ATTR_DESC(pads, AnyTraits>(), AnyTraits>())}, {"stride", ATTR_DESC(strides, AnyTraits>(), AnyTraits>())}, {"dilation", ATTR_DESC(dilations, AnyTraits>(), AnyTraits>())}, + {"data_format", ATTR_DESC(data_format, AnyTraits())}, + {"group", ATTR_DESC(groups, AnyTraits())}, }; OUTPUT_MAP(Conv2DBackpropFilterD) = {{0, OUTPUT_DESC(y)}}; diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py index 9f277908ed..7c118994c6 100644 --- a/mindspore/ops/operations/_grad_ops.py +++ b/mindspore/ops/operations/_grad_ops.py @@ -137,9 +137,9 @@ class ConcatOffset(PrimitiveWithInfer): return out -class Conv2DBackpropFilter(PrimitiveWithInfer): +class Conv2DBackpropInput(PrimitiveWithInfer): """ - Computes the gradients of convolution with respect to the filter. + Computes the gradients of convolution with respect to the input. Args: out_channel (int): The dimensionality of the output space. @@ -147,9 +147,9 @@ class Conv2DBackpropFilter(PrimitiveWithInfer): pad_mode (str): "valid", "same", "pad" the mode to fill padding. Default: "valid". pad (int): The pad value to fill. Default: 0. mode (int): 0 Math convolutiuon, 1 cross-correlation convolution , - 2 deconvolution, 3 depthwise convolution. Default: 1. - stride (tuple): The stride to apply conv filter. Default: (1, 1). - dilation (tuple): Specifies the dilation rate to use for dilated convolution. Default: (1, 1, 1, 1). + 2 deconvolution, 3 depthwise convolution. Default: 1. + stride (Union[int. tuple[int]]): The stride to apply conv filter. Default: 1. + dilation (Union[int. tuple[int]]): Specifies the dilation rate to use for dilated convolution. Default: 1. group (int): Splits input into groups. Default: 1. Returns: @@ -162,38 +162,91 @@ class Conv2DBackpropFilter(PrimitiveWithInfer): kernel_size, pad_mode="valid", pad=0, - pad_list=(0, 0, 0, 0), + pad_list=None, mode=1, - stride=(1, 1), - dilation=(1, 1, 1, 1), + stride=1, + dilation=1, group=1): - """init Convolution""" - self.init_prim_io_names(inputs=['out_backprop', 'input', 'filter_sizes'], outputs=['output']) - self.out_channel = out_channel - self.kernel_size = kernel_size - self.mode = mode + """init Conv2DBackpropInput""" + self.init_prim_io_names(inputs=['out_backprop', 'filter', 'input_sizes'], outputs=['output']) + self.out_channel = validator.check_integer('out_channel', out_channel, 0, Rel.GT) + self.kernel_size = validator.check_type('kernel_size', kernel_size, (int, tuple)) + if isinstance(kernel_size, int): + self.kernel_size = (kernel_size, kernel_size) + if len(self.kernel_size) != 2 or (not isinstance(self.kernel_size[0], int)) or \ + (not isinstance(self.kernel_size[1], int)) or \ + self.kernel_size[0] < 1 or self.kernel_size[1] < 1: + raise ValueError(f"The \'kernel_size\' of \'Conv2DBackpropInput\' should be an positive int number or " + f"a tuple of two positive int numbers, but got {kernel_size}") + self.stride = validator.check_type('stride', stride, (int, tuple)) + if isinstance(stride, int): + self.stride = (stride, stride) + elif isinstance(stride, tuple) and len(stride) == 4: + self.stride = (stride[2], stride[3]) + if len(self.stride) != 2 or (not isinstance(self.stride[0], int)) or (not isinstance(self.stride[1], int)) or \ + self.stride[0] < 1 or self.stride[1] < 1: + raise ValueError(f"The \'stride\' of \'Conv2DBackpropInput\' should be an positive int number or " + f"a tuple of two or four positive int numbers, but got {stride}") + self.add_prim_attr('stride', self.stride) + self.dilation = validator.check_type('dilation', dilation, (tuple, int)) + if isinstance(dilation, int): + self.dilation = (1, 1, dilation, dilation) + elif len(dilation) == 2: + self.dilation = (1, 1, dilation[0], dilation[1]) + if len(self.dilation) != 4 or (not isinstance(self.dilation[0], int) or self.dilation[0] < 1) or \ + (not isinstance(self.dilation[1], int) or self.dilation[1] < 1) or \ + (not isinstance(self.dilation[2], int) or self.dilation[2] < 1) or \ + (not isinstance(self.dilation[3], int) or self.dilation[3] < 1): + raise ValueError(f"The \'dilation\' of \'Conv2DBackpropInput\' should be an positive int number or " + f"a tuple of two or four positive int numbers, but got {dilation}") + self.add_prim_attr('dilation', self.dilation) + validator.equal('type of pad', type(pad), 'not bool', not isinstance(pad, bool)) + validator.equal('type of pad', type(pad), 'int', isinstance(pad, int)) + self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad']) + self.pad = validator.check_pad_value_by_mode(self.__class__.__name__, pad_mode, pad) + self.mode = validator.check_integer('mode', mode, 1, Rel.EQ) + self.group = validator.check_integer('group', group, 0, Rel.GT) pad_mode = pad_mode.upper() self.add_prim_attr('pad_mode', pad_mode) - self.pad = pad - if isinstance(stride, tuple) and len(stride) == 4: - self.stride = (stride[2], stride[3]) - self.add_prim_attr('stride', self.stride) - self.dilation = dilation - self.group = group self.add_prim_attr('data_format', "NCHW") + if pad_list: + self.pad_lsit = (validator.check_integer('pad_list', x, 0, Rel.GE) for x in pad_list) - def __infer__(self, doutput, x, w_size): - w_size_v = w_size['value'] - validator.check_type('w_size', w_size_v, [tuple]) - for i, dim_len in enumerate(w_size_v): - validator.check_type("w_size[%d]" % i, dim_len, [int]) - validator.check_typename('x_dtype', x['dtype'], [mstype.int8, mstype.int32, mstype.float16, mstype.float32]) - validator.check_two_types_same('doutput_dtype', doutput['dtype'], 'x_dtype', x['dtype']) + def __infer__(self, doutput, w, x_size): + x_size_v = x_size['value'] + validator.check_type('x_size', x_size_v, [tuple]) + for i, dim_len in enumerate(x_size_v): + validator.check_type("x_size[%d]" % i, dim_len, [int]) + validator.check_typename('w_dtype', w['dtype'], [mstype.int8, mstype.int32, mstype.float16, mstype.float32]) + validator.check_two_types_same('doutput_dtype', doutput['dtype'], 'w_dtype', w['dtype']) + + # infer shape + dout_shape = doutput['shape'] + kernel_h = self.kernel_size[0] + kernel_w = self.kernel_size[1] + stride_h = self.stride[0] + stride_w = self.stride[1] + # default pad mode is valid + pad_list = (0, 0, 0, 0) + if self.pad_list: + pad_list = tuple(self.pad_list) + elif self.pad_mode == "SAME": + pad_needed_h = max(0, (dout_shape[2] - 1) * stride_h + kernel_h - x_size_v[2]) + pad_top = math.floor(pad_needed_h / 2) + pad_bottom = pad_needed_h - pad_top + + pad_needed_w = max(0, (dout_shape[3] - 1) * stride_w + kernel_w - x_size_v[3]) + pad_left = math.floor(pad_needed_w / 2) + pad_right = pad_needed_w - pad_left + pad_list = (pad_top, pad_bottom, pad_left, pad_right) + elif self.pad_mode == 'PAD': + pad_list = (self.pad,) * 4 + self.add_prim_attr('pad_list', pad_list) out = { 'value': None, - 'shape': w_size_v, + 'shape': x_size_v, 'dtype': doutput['dtype'], - } + } return out diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 93359c7dd9..00af0155ae 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -637,28 +637,53 @@ class Conv2D(PrimitiveWithInfer): group=1): """init Conv2D""" self.init_prim_io_names(inputs=['x', 'w'], outputs=['output']) - self.kernel_size = _check_positive_int_or_tuple('kernel_size', kernel_size, self.name) - self.stride = _check_positive_int_or_tuple('stride', stride, self.name) + self.kernel_size = validator.check_type('kernel_size', kernel_size, (int, tuple)) + if isinstance(kernel_size, int): + self.kernel_size = (kernel_size, kernel_size) + if len(self.kernel_size) != 2 or (not isinstance(self.kernel_size[0], int)) or \ + (not isinstance(self.kernel_size[1], int)) or \ + self.kernel_size[0] < 1 or self.kernel_size[1] < 1: + raise ValueError(f"The \'kernel_size\' of \'Conv2D\' should be an positive int number or " + f"a tuple of two positive int numbers, but got {kernel_size}") + self.stride = validator.check_type('stride', stride, (int, tuple)) + if isinstance(stride, int): + self.stride = (stride, stride) + if len(self.stride) != 2 or (not isinstance(self.stride[0], int)) or \ + (not isinstance(self.stride[1], int)) or \ + self.stride[0] < 1 or self.stride[1] < 1: + raise ValueError(f"The \'stride\' of \'Conv2D\' should be an positive int number or " + f"a tuple of two positive int numbers, but got {stride}") self.add_prim_attr('stride', (1, 1, self.stride[0], self.stride[1])) - self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name, allow_four=True, ret_four=True) + self.dilation = validator.check_type('dilation', dilation, (tuple, int)) + if isinstance(dilation, int): + self.dilation = (1, 1, dilation, dilation) + elif len(dilation) == 2: + self.dilation = (1, 1, dilation[0], dilation[1]) + if len(self.dilation) != 4 or (not isinstance(self.dilation[0], int) or self.dilation[0] < 1) or \ + (not isinstance(self.dilation[1], int) or self.dilation[1] < 1) or \ + (not isinstance(self.dilation[2], int) or self.dilation[2] < 1) or \ + (not isinstance(self.dilation[3], int) or self.dilation[3] < 1): + raise ValueError(f"The \'dilation\' of \'Conv2D\' should be an positive int number or " + f"a tuple of two or four positive int numbers, but got {dilation}") self.add_prim_attr('dilation', self.dilation) - validator.check_value_type('pad', pad, (int,), self.name) - self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad'], self.name) - self.pad = validator.check_pad_value_by_mode(pad_mode, pad, self.name) + validator.equal('type of pad', type(pad), 'not bool', not isinstance(pad, bool)) + validator.equal('type of pad', type(pad), 'int', isinstance(pad, int)) + self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad']) + self.pad = validator.check_pad_value_by_mode(self.__class__.__name__, pad_mode, pad) if self.pad_mode == 'pad': - validator.check_integer('pad', self.pad, 0, Rel.GE, self.name) + validator.check_integer('pad', self.pad, 0, Rel.GE) - self.mode = validator.check_integer('mode', mode, 1, Rel.EQ, self.name) + self.mode = validator.check_integer('mode', mode, 1, Rel.EQ) self.add_prim_attr('data_format', "NCHW") - self.out_channel = validator.check_integer('out_channel', out_channel, 0, Rel.GT, self.name) - self.group = validator.check_integer('group', group, 0, Rel.GT, self.name) + self.out_channel = validator.check_integer('out_channel', out_channel, 0, Rel.GT) + self.group = validator.check_integer('group', group, 0, Rel.GT) def infer_shape(self, x_shape, w_shape): - validator.check_integer("weight rank", len(w_shape), 4, Rel.EQ, self.name) - validator.check_integer("x rank", len(x_shape), 4, Rel.EQ, self.name) - validator.check("x_shape[1] / group", x_shape[1] // self.group, "w_shape[1]", w_shape[1], Rel.EQ, self.name) - validator.check('out_channel', self.out_channel, 'w_shape[0]', w_shape[0], Rel.EQ, self.name) - validator.check('kernel_size', self.kernel_size, 'w_shape[2:4]', tuple(w_shape[2:4]), Rel.EQ, self.name) + validator.check_integer("weight_shape", len(w_shape), 4, Rel.EQ) + validator.check_integer("x_shape", len(x_shape), 4, Rel.EQ) + validator.check_param_equal("x_shape[1]", x_shape[1] // self.group, "w_shape[1]", w_shape[1]) + validator.check_param_equal('out_channel', self.out_channel, 'w_shape[0]', w_shape[0]) + validator.check_param_equal('kernel_size', self.kernel_size, 'w_shape[2:4]', tuple(w_shape[2:4])) kernel_size_h = w_shape[2] kernel_size_w = w_shape[3] @@ -700,9 +725,10 @@ class Conv2D(PrimitiveWithInfer): return out_shape def infer_dtype(self, x_dtype, w_dtype): - args = {'x': x_dtype, 'w': w_dtype} - valid_types = [mstype.int8, mstype.int32, mstype.float16, mstype.float32] - validator.check_tensor_type_same(args, valid_types, self.name) + args = {'x_dtype': x_dtype, 'w_dtype': w_dtype} + validator.check_subclass('input', x_dtype, mstype.tensor) + validator.check_subclass('weight', w_dtype, mstype.tensor) + validator.check_type_same(args, [mstype.int8, mstype.int32, mstype.float16, mstype.float32]) return x_dtype @@ -1082,33 +1108,56 @@ class Conv2DBackpropInput(PrimitiveWithInfer): group=1): """init Conv2DBackpropInput""" self.init_prim_io_names(inputs=['out_backprop', 'filter', 'input_sizes'], outputs=['output']) - self.out_channel = validator.check_integer('out_channel', out_channel, 0, Rel.GT, self.name) - self.kernel_size = _check_positive_int_or_tuple('kernel_size', kernel_size, self.name) - self.stride = _check_positive_int_or_tuple('stride', stride, self.name, allow_four=True, ret_four=False) + self.out_channel = validator.check_integer('out_channel', out_channel, 0, Rel.GT) + self.kernel_size = validator.check_type('kernel_size', kernel_size, (int, tuple)) + if isinstance(kernel_size, int): + self.kernel_size = (kernel_size, kernel_size) + if len(self.kernel_size) != 2 or (not isinstance(self.kernel_size[0], int)) or \ + (not isinstance(self.kernel_size[1], int)) or \ + self.kernel_size[0] < 1 or self.kernel_size[1] < 1: + raise ValueError(f"The \'kernel_size\' of \'Conv2DBackpropInput\' should be an positive int number or " + f"a tuple of two positive int numbers, but got {kernel_size}") + self.stride = validator.check_type('stride', stride, (int, tuple)) + if isinstance(stride, int): + self.stride = (stride, stride) + elif isinstance(stride, tuple) and len(stride) == 4: + self.stride = (stride[2], stride[3]) + if len(self.stride) != 2 or (not isinstance(self.stride[0], int)) or (not isinstance(self.stride[1], int)) or \ + self.stride[0] < 1 or self.stride[1] < 1: + raise ValueError(f"The \'stride\' of \'Conv2DBackpropInput\' should be an positive int number or " + f"a tuple of two or four positive int numbers, but got {stride}") self.add_prim_attr('stride', self.stride) - self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name, allow_four=True, ret_four=True) + self.dilation = validator.check_type('dilation', dilation, (tuple, int)) + if isinstance(dilation, int): + self.dilation = (1, 1, dilation, dilation) + elif len(dilation) == 2: + self.dilation = (1, 1, dilation[0], dilation[1]) + if len(self.dilation) != 4 or (not isinstance(self.dilation[0], int) or self.dilation[0] < 1) or \ + (not isinstance(self.dilation[1], int) or self.dilation[1] < 1) or \ + (not isinstance(self.dilation[2], int) or self.dilation[2] < 1) or \ + (not isinstance(self.dilation[3], int) or self.dilation[3] < 1): + raise ValueError(f"The \'dilation\' of \'Conv2DBackpropInput\' should be an positive int number or " + f"a tuple of two or four positive int numbers, but got {dilation}") self.add_prim_attr('dilation', self.dilation) - validator.check_value_type('pad', pad, (int,), self.name) - self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad'], self.name) - self.pad = validator.check_pad_value_by_mode(pad_mode, pad, self.name) + validator.equal('type of pad', type(pad), 'not bool', not isinstance(pad, bool)) + validator.equal('type of pad', type(pad), 'int', isinstance(pad, int)) + self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad']) + self.pad = validator.check_pad_value_by_mode(self.__class__.__name__, pad_mode, pad) + self.mode = validator.check_integer('mode', mode, 1, Rel.EQ) + self.group = validator.check_integer('group', group, 0, Rel.GT) pad_mode = pad_mode.upper() self.add_prim_attr('pad_mode', pad_mode) - self.mode = validator.check_integer('mode', mode, 1, Rel.EQ, self.name) - self.group = validator.check_integer('group', group, 0, Rel.GT, self.name) self.add_prim_attr('data_format', "NCHW") if pad_list: - for x in pad_list: - validator.check_integer('element of pad_list', x, 0, Rel.GE, self.name) - self.pad_list = pad_list + self.pad_lsit = (validator.check_integer('pad_list', x, 0, Rel.GE) for x in pad_list) def __infer__(self, doutput, w, x_size): x_size_v = x_size['value'] - validator.check_value_type('x_size', x_size_v, [tuple], self.name) + validator.check_type('x_size', x_size_v, [tuple]) for i, dim_len in enumerate(x_size_v): - validator.check_value_type("x_size[%d]" % i, dim_len, [int], self.name) - args = {'doutput': doutput['dtype'], 'w': w['dtype']} - valid_types = [mstype.int8, mstype.int32, mstype.float16, mstype.float32] - validator.check_tensor_type_same(args, valid_types, self.name) + validator.check_type("x_size[%d]" % i, dim_len, [int]) + validator.check_typename('w_dtype', w['dtype'], [mstype.int8, mstype.int32, mstype.float16, mstype.float32]) + validator.check_two_types_same('doutput_dtype', doutput['dtype'], 'w_dtype', w['dtype']) # infer shape dout_shape = doutput['shape'] From f3ed97a7f028b0ba5e3f9a106af3d04cfe3119a7 Mon Sep 17 00:00:00 2001 From: zhaozhenlong Date: Fri, 24 Apr 2020 11:46:06 +0800 Subject: [PATCH 129/142] fix conv2d --- mindspore/ops/operations/nn_ops.py | 121 +++++++++-------------------- 1 file changed, 36 insertions(+), 85 deletions(-) diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 00af0155ae..9750549dc5 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -637,53 +637,28 @@ class Conv2D(PrimitiveWithInfer): group=1): """init Conv2D""" self.init_prim_io_names(inputs=['x', 'w'], outputs=['output']) - self.kernel_size = validator.check_type('kernel_size', kernel_size, (int, tuple)) - if isinstance(kernel_size, int): - self.kernel_size = (kernel_size, kernel_size) - if len(self.kernel_size) != 2 or (not isinstance(self.kernel_size[0], int)) or \ - (not isinstance(self.kernel_size[1], int)) or \ - self.kernel_size[0] < 1 or self.kernel_size[1] < 1: - raise ValueError(f"The \'kernel_size\' of \'Conv2D\' should be an positive int number or " - f"a tuple of two positive int numbers, but got {kernel_size}") - self.stride = validator.check_type('stride', stride, (int, tuple)) - if isinstance(stride, int): - self.stride = (stride, stride) - if len(self.stride) != 2 or (not isinstance(self.stride[0], int)) or \ - (not isinstance(self.stride[1], int)) or \ - self.stride[0] < 1 or self.stride[1] < 1: - raise ValueError(f"The \'stride\' of \'Conv2D\' should be an positive int number or " - f"a tuple of two positive int numbers, but got {stride}") + self.kernel_size = _check_positive_int_or_tuple('kernel_size', kernel_size, self.name) + self.stride = _check_positive_int_or_tuple('stride', stride, self.name) self.add_prim_attr('stride', (1, 1, self.stride[0], self.stride[1])) - self.dilation = validator.check_type('dilation', dilation, (tuple, int)) - if isinstance(dilation, int): - self.dilation = (1, 1, dilation, dilation) - elif len(dilation) == 2: - self.dilation = (1, 1, dilation[0], dilation[1]) - if len(self.dilation) != 4 or (not isinstance(self.dilation[0], int) or self.dilation[0] < 1) or \ - (not isinstance(self.dilation[1], int) or self.dilation[1] < 1) or \ - (not isinstance(self.dilation[2], int) or self.dilation[2] < 1) or \ - (not isinstance(self.dilation[3], int) or self.dilation[3] < 1): - raise ValueError(f"The \'dilation\' of \'Conv2D\' should be an positive int number or " - f"a tuple of two or four positive int numbers, but got {dilation}") + self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name, allow_four=True, ret_four=True) self.add_prim_attr('dilation', self.dilation) - validator.equal('type of pad', type(pad), 'not bool', not isinstance(pad, bool)) - validator.equal('type of pad', type(pad), 'int', isinstance(pad, int)) - self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad']) - self.pad = validator.check_pad_value_by_mode(self.__class__.__name__, pad_mode, pad) + validator.check_value_type('pad', pad, (int,), self.name) + self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad'], self.name) + self.pad = validator.check_pad_value_by_mode(pad_mode, pad, self.name) if self.pad_mode == 'pad': - validator.check_integer('pad', self.pad, 0, Rel.GE) + validator.check_integer('pad', self.pad, 0, Rel.GE, self.name) - self.mode = validator.check_integer('mode', mode, 1, Rel.EQ) + self.mode = validator.check_integer('mode', mode, 1, Rel.EQ, self.name) self.add_prim_attr('data_format', "NCHW") - self.out_channel = validator.check_integer('out_channel', out_channel, 0, Rel.GT) - self.group = validator.check_integer('group', group, 0, Rel.GT) + self.out_channel = validator.check_integer('out_channel', out_channel, 0, Rel.GT, self.name) + self.group = validator.check_integer('group', group, 0, Rel.GT, self.name) def infer_shape(self, x_shape, w_shape): - validator.check_integer("weight_shape", len(w_shape), 4, Rel.EQ) - validator.check_integer("x_shape", len(x_shape), 4, Rel.EQ) - validator.check_param_equal("x_shape[1]", x_shape[1] // self.group, "w_shape[1]", w_shape[1]) - validator.check_param_equal('out_channel', self.out_channel, 'w_shape[0]', w_shape[0]) - validator.check_param_equal('kernel_size', self.kernel_size, 'w_shape[2:4]', tuple(w_shape[2:4])) + validator.check_integer("weight rank", len(w_shape), 4, Rel.EQ, self.name) + validator.check_integer("x rank", len(x_shape), 4, Rel.EQ, self.name) + validator.check("x_shape[1] / group", x_shape[1] // self.group, "w_shape[1]", w_shape[1], Rel.EQ, self.name) + validator.check('out_channel', self.out_channel, 'w_shape[0]', w_shape[0], Rel.EQ, self.name) + validator.check('kernel_size', self.kernel_size, 'w_shape[2:4]', tuple(w_shape[2:4]), Rel.EQ, self.name) kernel_size_h = w_shape[2] kernel_size_w = w_shape[3] @@ -725,10 +700,9 @@ class Conv2D(PrimitiveWithInfer): return out_shape def infer_dtype(self, x_dtype, w_dtype): - args = {'x_dtype': x_dtype, 'w_dtype': w_dtype} - validator.check_subclass('input', x_dtype, mstype.tensor) - validator.check_subclass('weight', w_dtype, mstype.tensor) - validator.check_type_same(args, [mstype.int8, mstype.int32, mstype.float16, mstype.float32]) + args = {'x': x_dtype, 'w': w_dtype} + valid_types = [mstype.int8, mstype.int32, mstype.float16, mstype.float32] + validator.check_tensor_type_same(args, valid_types, self.name) return x_dtype @@ -1108,56 +1082,33 @@ class Conv2DBackpropInput(PrimitiveWithInfer): group=1): """init Conv2DBackpropInput""" self.init_prim_io_names(inputs=['out_backprop', 'filter', 'input_sizes'], outputs=['output']) - self.out_channel = validator.check_integer('out_channel', out_channel, 0, Rel.GT) - self.kernel_size = validator.check_type('kernel_size', kernel_size, (int, tuple)) - if isinstance(kernel_size, int): - self.kernel_size = (kernel_size, kernel_size) - if len(self.kernel_size) != 2 or (not isinstance(self.kernel_size[0], int)) or \ - (not isinstance(self.kernel_size[1], int)) or \ - self.kernel_size[0] < 1 or self.kernel_size[1] < 1: - raise ValueError(f"The \'kernel_size\' of \'Conv2DBackpropInput\' should be an positive int number or " - f"a tuple of two positive int numbers, but got {kernel_size}") - self.stride = validator.check_type('stride', stride, (int, tuple)) - if isinstance(stride, int): - self.stride = (stride, stride) - elif isinstance(stride, tuple) and len(stride) == 4: - self.stride = (stride[2], stride[3]) - if len(self.stride) != 2 or (not isinstance(self.stride[0], int)) or (not isinstance(self.stride[1], int)) or \ - self.stride[0] < 1 or self.stride[1] < 1: - raise ValueError(f"The \'stride\' of \'Conv2DBackpropInput\' should be an positive int number or " - f"a tuple of two or four positive int numbers, but got {stride}") + self.out_channel = validator.check_integer('out_channel', out_channel, 0, Rel.GT, self.name) + self.kernel_size = _check_positive_int_or_tuple('kernel_size', kernel_size, self.name) + self.stride = _check_positive_int_or_tuple('stride', stride, self.name, allow_four=True, ret_four=False) self.add_prim_attr('stride', self.stride) - self.dilation = validator.check_type('dilation', dilation, (tuple, int)) - if isinstance(dilation, int): - self.dilation = (1, 1, dilation, dilation) - elif len(dilation) == 2: - self.dilation = (1, 1, dilation[0], dilation[1]) - if len(self.dilation) != 4 or (not isinstance(self.dilation[0], int) or self.dilation[0] < 1) or \ - (not isinstance(self.dilation[1], int) or self.dilation[1] < 1) or \ - (not isinstance(self.dilation[2], int) or self.dilation[2] < 1) or \ - (not isinstance(self.dilation[3], int) or self.dilation[3] < 1): - raise ValueError(f"The \'dilation\' of \'Conv2DBackpropInput\' should be an positive int number or " - f"a tuple of two or four positive int numbers, but got {dilation}") + self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name, allow_four=True, ret_four=True) self.add_prim_attr('dilation', self.dilation) - validator.equal('type of pad', type(pad), 'not bool', not isinstance(pad, bool)) - validator.equal('type of pad', type(pad), 'int', isinstance(pad, int)) - self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad']) - self.pad = validator.check_pad_value_by_mode(self.__class__.__name__, pad_mode, pad) - self.mode = validator.check_integer('mode', mode, 1, Rel.EQ) - self.group = validator.check_integer('group', group, 0, Rel.GT) + validator.check_value_type('pad', pad, (int,), self.name) + self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad'], self.name) + self.pad = validator.check_pad_value_by_mode(pad_mode, pad, self.name) pad_mode = pad_mode.upper() self.add_prim_attr('pad_mode', pad_mode) + self.mode = validator.check_integer('mode', mode, 1, Rel.EQ, self.name) + self.group = validator.check_integer('group', group, 0, Rel.GT, self.name) self.add_prim_attr('data_format', "NCHW") if pad_list: - self.pad_lsit = (validator.check_integer('pad_list', x, 0, Rel.GE) for x in pad_list) + for x in pad_list: + validator.check_integer('element of pad_list', x, 0, Rel.GE, self.name) + self.pad_list = pad_list def __infer__(self, doutput, w, x_size): x_size_v = x_size['value'] - validator.check_type('x_size', x_size_v, [tuple]) + validator.check_value_type('x_size', x_size_v, [tuple], self.name) for i, dim_len in enumerate(x_size_v): - validator.check_type("x_size[%d]" % i, dim_len, [int]) - validator.check_typename('w_dtype', w['dtype'], [mstype.int8, mstype.int32, mstype.float16, mstype.float32]) - validator.check_two_types_same('doutput_dtype', doutput['dtype'], 'w_dtype', w['dtype']) + validator.check_value_type("x_size[%d]" % i, dim_len, [int], self.name) + args = {'doutput': doutput['dtype'], 'w': w['dtype']} + valid_types = [mstype.int8, mstype.int32, mstype.float16, mstype.float32] + validator.check_tensor_type_same(args, valid_types, self.name) # infer shape dout_shape = doutput['shape'] @@ -1677,7 +1628,7 @@ class LayerNorm(Primitive): `Layer Normalization `_. .. math:: - y = \frac{x - mean]}{\sqrt{variance + \epsilon}} * \gamma + \beta + y = \frac{x - mean}{\sqrt{variance + \epsilon}} * \gamma + \beta where :math:`\gamma` is scale, :math:`\beta` is bias, :math:`\epsilon` is epsilon. From 366b6d6803dac354d80da5562676a4581bf29a58 Mon Sep 17 00:00:00 2001 From: zhaozhenlong Date: Fri, 24 Apr 2020 11:57:33 +0800 Subject: [PATCH 130/142] fix conv2d 2 --- mindspore/ops/operations/_grad_ops.py | 107 +++++++------------------- 1 file changed, 27 insertions(+), 80 deletions(-) diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py index 7c118994c6..b54d515b7a 100644 --- a/mindspore/ops/operations/_grad_ops.py +++ b/mindspore/ops/operations/_grad_ops.py @@ -137,9 +137,9 @@ class ConcatOffset(PrimitiveWithInfer): return out -class Conv2DBackpropInput(PrimitiveWithInfer): +class Conv2DBackpropFilter(PrimitiveWithInfer): """ - Computes the gradients of convolution with respect to the input. + Computes the gradients of convolution with respect to the filter. Args: out_channel (int): The dimensionality of the output space. @@ -147,9 +147,9 @@ class Conv2DBackpropInput(PrimitiveWithInfer): pad_mode (str): "valid", "same", "pad" the mode to fill padding. Default: "valid". pad (int): The pad value to fill. Default: 0. mode (int): 0 Math convolutiuon, 1 cross-correlation convolution , - 2 deconvolution, 3 depthwise convolution. Default: 1. - stride (Union[int. tuple[int]]): The stride to apply conv filter. Default: 1. - dilation (Union[int. tuple[int]]): Specifies the dilation rate to use for dilated convolution. Default: 1. + 2 deconvolution, 3 depthwise convolution. Default: 1. + stride (tuple): The stride to apply conv filter. Default: (1, 1). + dilation (tuple): Specifies the dilation rate to use for dilated convolution. Default: (1, 1, 1, 1). group (int): Splits input into groups. Default: 1. Returns: @@ -162,89 +162,36 @@ class Conv2DBackpropInput(PrimitiveWithInfer): kernel_size, pad_mode="valid", pad=0, - pad_list=None, + pad_list=(0, 0, 0, 0), mode=1, - stride=1, - dilation=1, + stride=(1, 1), + dilation=(1, 1, 1, 1), group=1): - """init Conv2DBackpropInput""" - self.init_prim_io_names(inputs=['out_backprop', 'filter', 'input_sizes'], outputs=['output']) - self.out_channel = validator.check_integer('out_channel', out_channel, 0, Rel.GT) - self.kernel_size = validator.check_type('kernel_size', kernel_size, (int, tuple)) - if isinstance(kernel_size, int): - self.kernel_size = (kernel_size, kernel_size) - if len(self.kernel_size) != 2 or (not isinstance(self.kernel_size[0], int)) or \ - (not isinstance(self.kernel_size[1], int)) or \ - self.kernel_size[0] < 1 or self.kernel_size[1] < 1: - raise ValueError(f"The \'kernel_size\' of \'Conv2DBackpropInput\' should be an positive int number or " - f"a tuple of two positive int numbers, but got {kernel_size}") - self.stride = validator.check_type('stride', stride, (int, tuple)) - if isinstance(stride, int): - self.stride = (stride, stride) - elif isinstance(stride, tuple) and len(stride) == 4: - self.stride = (stride[2], stride[3]) - if len(self.stride) != 2 or (not isinstance(self.stride[0], int)) or (not isinstance(self.stride[1], int)) or \ - self.stride[0] < 1 or self.stride[1] < 1: - raise ValueError(f"The \'stride\' of \'Conv2DBackpropInput\' should be an positive int number or " - f"a tuple of two or four positive int numbers, but got {stride}") - self.add_prim_attr('stride', self.stride) - self.dilation = validator.check_type('dilation', dilation, (tuple, int)) - if isinstance(dilation, int): - self.dilation = (1, 1, dilation, dilation) - elif len(dilation) == 2: - self.dilation = (1, 1, dilation[0], dilation[1]) - if len(self.dilation) != 4 or (not isinstance(self.dilation[0], int) or self.dilation[0] < 1) or \ - (not isinstance(self.dilation[1], int) or self.dilation[1] < 1) or \ - (not isinstance(self.dilation[2], int) or self.dilation[2] < 1) or \ - (not isinstance(self.dilation[3], int) or self.dilation[3] < 1): - raise ValueError(f"The \'dilation\' of \'Conv2DBackpropInput\' should be an positive int number or " - f"a tuple of two or four positive int numbers, but got {dilation}") - self.add_prim_attr('dilation', self.dilation) - validator.equal('type of pad', type(pad), 'not bool', not isinstance(pad, bool)) - validator.equal('type of pad', type(pad), 'int', isinstance(pad, int)) - self.pad_mode = validator.check_string('pad_mode', pad_mode, ['valid', 'same', 'pad']) - self.pad = validator.check_pad_value_by_mode(self.__class__.__name__, pad_mode, pad) - self.mode = validator.check_integer('mode', mode, 1, Rel.EQ) - self.group = validator.check_integer('group', group, 0, Rel.GT) + """init Convolution""" + self.init_prim_io_names(inputs=['out_backprop', 'input', 'filter_sizes'], outputs=['output']) + self.out_channel = out_channel + self.kernel_size = kernel_size + self.mode = mode pad_mode = pad_mode.upper() self.add_prim_attr('pad_mode', pad_mode) + self.pad = pad + if isinstance(stride, tuple) and len(stride) == 4: + self.stride = (stride[2], stride[3]) + self.add_prim_attr('stride', self.stride) + self.dilation = dilation + self.group = group self.add_prim_attr('data_format', "NCHW") - if pad_list: - self.pad_lsit = (validator.check_integer('pad_list', x, 0, Rel.GE) for x in pad_list) - def __infer__(self, doutput, w, x_size): - x_size_v = x_size['value'] - validator.check_type('x_size', x_size_v, [tuple]) - for i, dim_len in enumerate(x_size_v): - validator.check_type("x_size[%d]" % i, dim_len, [int]) - validator.check_typename('w_dtype', w['dtype'], [mstype.int8, mstype.int32, mstype.float16, mstype.float32]) - validator.check_two_types_same('doutput_dtype', doutput['dtype'], 'w_dtype', w['dtype']) - - # infer shape - dout_shape = doutput['shape'] - kernel_h = self.kernel_size[0] - kernel_w = self.kernel_size[1] - stride_h = self.stride[0] - stride_w = self.stride[1] - # default pad mode is valid - pad_list = (0, 0, 0, 0) - if self.pad_list: - pad_list = tuple(self.pad_list) - elif self.pad_mode == "SAME": - pad_needed_h = max(0, (dout_shape[2] - 1) * stride_h + kernel_h - x_size_v[2]) - pad_top = math.floor(pad_needed_h / 2) - pad_bottom = pad_needed_h - pad_top - - pad_needed_w = max(0, (dout_shape[3] - 1) * stride_w + kernel_w - x_size_v[3]) - pad_left = math.floor(pad_needed_w / 2) - pad_right = pad_needed_w - pad_left - pad_list = (pad_top, pad_bottom, pad_left, pad_right) - elif self.pad_mode == 'PAD': - pad_list = (self.pad,) * 4 - self.add_prim_attr('pad_list', pad_list) + def __infer__(self, doutput, x, w_size): + w_size_v = w_size['value'] + validator.check_value_type('w_size', w_size_v, [tuple], self.name) + for i, dim_len in enumerate(w_size_v): + validator.check_value_type("w_size[%d]" % i, dim_len, [int], self.name) + args = {"x": x['dtype'], "doutput": doutput['dtype']} + validator.check_tensor_type_same(args, [mstype.int8, mstype.int32, mstype.float16, mstype.float32], self.name) out = { 'value': None, - 'shape': x_size_v, + 'shape': w_size_v, 'dtype': doutput['dtype'], } return out From 61c41614539bbb7d8ccea232e4ab45c993148c3a Mon Sep 17 00:00:00 2001 From: zhaozhenlong Date: Fri, 24 Apr 2020 14:31:07 +0800 Subject: [PATCH 131/142] fix depthwise conv2d --- mindspore/ccsrc/transform/op_declare.cc | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/mindspore/ccsrc/transform/op_declare.cc b/mindspore/ccsrc/transform/op_declare.cc index b8aec05c3b..377403cc89 100644 --- a/mindspore/ccsrc/transform/op_declare.cc +++ b/mindspore/ccsrc/transform/op_declare.cc @@ -756,9 +756,9 @@ OUTPUT_MAP(Conv2DBackpropFilterD) = {{0, OUTPUT_DESC(y)}}; // DepthwiseConv2D INPUT_MAP(DepthwiseConv2D) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(filter)}}; ATTR_MAP(DepthwiseConv2D) = { - {"stride", ATTR_DESC(strides, "pad", AnyTraits>())}, + {"stride", ATTR_DESC(strides, AnyTraits>(), AnyTraits>())}, {"pads", ATTR_DESC(pads, AnyTraits>(), AnyTraits>())}, - {"dilation", ATTR_DESC(dilations, "pad", AnyTraits>())}, + {"dilation", ATTR_DESC(dilations, AnyTraits>(), AnyTraits>())}, {"data_format", ATTR_DESC(data_format, AnyTraits())}, }; OUTPUT_MAP(DepthwiseConv2D) = {{0, OUTPUT_DESC(y)}}; @@ -768,9 +768,9 @@ INPUT_MAP(DepthwiseConv2DBackpropInputD) = {{2, INPUT_DESC(filter)}, {3, INPUT_D INPUT_ATTR_MAP(DepthwiseConv2DBackpropInputD) = { {1, ATTR_DESC(input_size, AnyTraits>(), AnyTraits>())}}; ATTR_MAP(DepthwiseConv2DBackpropInputD) = { - {"stride", ATTR_DESC(strides, "pad", AnyTraits>())}, + {"stride", ATTR_DESC(strides, AnyTraits>(), AnyTraits>())}, {"pads", ATTR_DESC(pads, AnyTraits>(), AnyTraits>())}, - {"dilation", ATTR_DESC(dilations, "pad", AnyTraits>())}, + {"dilation", ATTR_DESC(dilations, AnyTraits>(), AnyTraits>())}, }; OUTPUT_MAP(DepthwiseConv2DBackpropInputD) = {{0, OUTPUT_DESC(input_grad)}}; @@ -779,9 +779,9 @@ INPUT_MAP(DepthwiseConv2DBackpropFilterD) = {{1, INPUT_DESC(input)}, {3, INPUT_D INPUT_ATTR_MAP(DepthwiseConv2DBackpropFilterD) = { {2, ATTR_DESC(filter_size, AnyTraits>(), AnyTraits>())}}; ATTR_MAP(DepthwiseConv2DBackpropFilterD) = { - {"stride", ATTR_DESC(strides, "pad", AnyTraits>())}, + {"stride", ATTR_DESC(strides, AnyTraits>(), AnyTraits>())}, {"pads", ATTR_DESC(pads, AnyTraits>(), AnyTraits>())}, - {"dilation", ATTR_DESC(dilations, "pad", AnyTraits>())}, + {"dilation", ATTR_DESC(dilations, AnyTraits>(), AnyTraits>())}, }; OUTPUT_MAP(DepthwiseConv2DBackpropFilterD) = {{0, OUTPUT_DESC(filter_grad)}}; From 75b49640433ee3a6ba8bbfdb6bca270c3cd94270 Mon Sep 17 00:00:00 2001 From: zhaozhenlong Date: Fri, 24 Apr 2020 14:51:02 +0800 Subject: [PATCH 132/142] update fanrui validator related --- mindspore/_checkparam.py | 14 +- mindspore/ops/operations/_grad_ops.py | 269 ++++++++++++-------------- 2 files changed, 126 insertions(+), 157 deletions(-) diff --git a/mindspore/_checkparam.py b/mindspore/_checkparam.py index 3543f58cf5..707ca748b4 100644 --- a/mindspore/_checkparam.py +++ b/mindspore/_checkparam.py @@ -206,8 +206,8 @@ class Validator: def _check_tensor_type(arg): arg_key, arg_val = arg elem_type = arg_val - type_names = [] if not elem_type in valid_values: + type_names = [] for t in valid_values: type_names.append(str(t)) types_info = '[' + ", ".join(type_names) + ']' @@ -304,10 +304,10 @@ class Validator: type_names = [get_typename(t) for t in valid_types] msg_prefix = f'For \'{prim_name}\' the' if prim_name else 'The' if len(valid_types) == 1: - raise ValueError(f'{msg_prefix} type of `{arg_name}` should be {type_names[0]},' - f' but got {get_typename(arg_type)}.') - raise ValueError(f'{msg_prefix} type of `{arg_name}` should be one of {type_names},' - f' but got {get_typename(arg_type)}.') + raise TypeError(f'{msg_prefix} type of `{arg_name}` should be {type_names[0]},' + f' but got {get_typename(arg_type)}.') + raise TypeError(f'{msg_prefix} type of `{arg_name}` should be one of {type_names},' + f' but got {get_typename(arg_type)}.') @staticmethod def check_float_legal_value(arg_name, arg_value, prim_name): @@ -417,8 +417,8 @@ class ParamValidator: """func for raising error message when check failed""" type_names = [t.__name__ for t in valid_types] num_types = len(valid_types) - raise ValueError(f'The type of `{arg_name}` should be {"one of " if num_types > 1 else ""}' - f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.') + raise TypeError(f'The type of `{arg_name}` should be {"one of " if num_types > 1 else ""}' + f'{type_names if num_types > 1 else type_names[0]}, but got {type(arg_value).__name__}.') if isinstance(arg_value, type(mstype.tensor)): arg_value = arg_value.element_type() diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py index b54d515b7a..07857ca27b 100644 --- a/mindspore/ops/operations/_grad_ops.py +++ b/mindspore/ops/operations/_grad_ops.py @@ -18,8 +18,7 @@ from ..._c_expression import signature_rw as sig_rw from ..._c_expression import signature_kind as sig_kind from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register -from ..._checkparam import ParamValidator as validator -from ..._checkparam import Rel, check_int_positive, check_bool +from ..._checkparam import Validator as validator, Rel from .._utils import _get_concat_offset from ...common import dtype as mstype @@ -51,12 +50,12 @@ class ACosGrad(PrimitiveWithInfer): """init ACosGrad""" def infer_shape(self, x, dout): - validator.check_param_equal("x", x, "dout", dout) + validator.check("x shape", x, "dout shape", dout, Rel.EQ, self.name) return x def infer_dtype(self, x, dout): args = {"x": x, "dout": dout} - validator.check_type_same(args, mstype.number_type) + validator.check_tensor_type_same(args, mstype.number_type, self.name) return x @@ -65,15 +64,15 @@ class BatchNormGrad(PrimitiveWithInfer): @prim_attr_register def __init__(self, is_training=False, epsilon=1e-5): - self.is_training = validator.check_type('is_training', is_training, (bool,)) - self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT) + self.is_training = validator.check_value_type('is_training', is_training, (bool,), self.name) + self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, self.name) self.add_prim_attr('data_format', "NCHW") - def infer_shape(self, y_backprop_shape, x_shape, scale_shape, reserve_1_shape, reserve_2_shape): + def infer_shape(self, y_backprop_shape, x_shape, scale_shape, reserve_1_shape, reserve_2_shape, reserve_3_shape): validator.check("BatchNorm y_backprop_shape", y_backprop_shape, "BatchNorm x_shape", x_shape) return (x_shape, scale_shape, scale_shape, reserve_1_shape, reserve_2_shape) - def infer_dtype(self, y_backprop_type, x_type, scale_type, reserve_1_type, reserve_2_type): + def infer_dtype(self, y_backprop_type, x_type, scale_type, reserve_1_type, reserve_2_type, reserve_3_type): return (x_type, scale_type, scale_type, reserve_1_type, reserve_2_type) @@ -93,19 +92,19 @@ class BinaryCrossEntropyGrad(PrimitiveWithInfer): """Computes gradients for `BinaryCrossEntropy` operation.""" @prim_attr_register def __init__(self, reduction='mean'): - self.reduction = validator.check_string('reduction', reduction, ['none', 'mean', 'sum']) + self.reduction = validator.check_string('reduction', reduction, ['none', 'mean', 'sum'], self.name) def infer_shape(self, x_shape, y_shape, doutput_shape, weight_shape): - validator.check_param_equal('x_shape', x_shape, 'y_shape', y_shape) + validator.check('x_shape', x_shape, 'y_shape', y_shape, Rel.EQ, self.name) if weight_shape: - validator.check_param_equal('y_shape', y_shape, 'weight_shape', weight_shape) + validator.check('y_shape', y_shape, 'weight_shape', weight_shape, Rel.EQ, self.name) return x_shape def infer_dtype(self, x_type, y_type, doutput_type, weight_type): args = {'x_type': x_type, 'y_type': y_type, 'doutput_type': doutput_type} - validator.check_type_same(args, (mstype.float16, mstype.float32)) + validator.check_tensor_type_same(args, (mstype.float16, mstype.float32), self.name) if weight_type: - validator.check_two_types_same('x_type', x_type, 'weight_type', weight_type) + validator.check('x_type', x_type, 'weight_type', weight_type, Rel.EQ, TypeError) return x_type @@ -120,7 +119,7 @@ class ConcatOffset(PrimitiveWithInfer): axis = self.axis x_shp = input_x['shape'] x_type = input_x['dtype'] - offset, _, axis = _get_concat_offset(x_shp, x_type, axis) + offset, _, axis = _get_concat_offset(x_shp, x_type, axis, self.name) self.add_prim_attr('T', x_type[0].element_type()) offset_values = [] for i in range(len(x_shp)): @@ -250,8 +249,8 @@ class DepthwiseConv2dNativeBackpropFilter(PrimitiveWithInfer): def __infer__(self, x, w_size, dout): w_size_v = w_size['value'] - args = {'x_dtype': x['dtype'], 'dout_type': dout['dtype']} - validator.check_type_same(args, mstype.number_type) + args = {'x': x['dtype'], 'dout': dout['dtype']} + validator.check_tensor_type_same(args, mstype.number_type, self.name) out = { 'value': None, 'shape': w_size_v, @@ -310,8 +309,8 @@ class DepthwiseConv2dNativeBackpropInput(PrimitiveWithInfer): raise NotImplementedError def __infer__(self, x_size, w, dout): - args = {'w_dtype': w['dtype'], 'dout_type': dout['dtype']} - validator.check_type_same(args, mstype.number_type) + args = {'w': w['dtype'], 'dout': dout['dtype']} + validator.check_tensor_type_same(args, mstype.number_type, self.name) x_size_v = x_size['value'] out = { 'value': None, @@ -333,7 +332,7 @@ class FlattenGrad(PrimitiveWithInfer): 'value': None, 'shape': args[1]['value'], 'dtype': args[0]['dtype'], - } + } return out @@ -360,9 +359,9 @@ class GeluGrad(PrimitiveWithInfer): return x_shape def infer_dtype(self, y_backprop_dtype, x_dtype, y_dtype): - validator.check_typename("y_backprop_dtype", y_backprop_dtype, (mstype.float16, mstype.float32)) - validator.check_typename("x_dtype", x_dtype, (mstype.float16, mstype.float32)) - validator.check_typename("y_dtype", y_dtype, (mstype.float16, mstype.float32)) + validator.check_tensor_type_same({"y_backprop": y_backprop_dtype}, (mstype.float16, mstype.float32), self.name) + validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name) + validator.check_tensor_type_same({"y": y_dtype}, (mstype.float16, mstype.float32), self.name) return x_dtype @@ -373,56 +372,36 @@ class _PoolGrad(PrimitiveWithInfer): def __init__(self, ksize, strides, padding="VALID"): self.init_prim_io_names(inputs=['x_origin', 'out_origin', 'grad'], outputs=['output']) - validator.check_type('ksize', ksize, [int, tuple]) - validator.check_type('strides', strides, [int, tuple]) - self.padding = validator.check_string('padding', padding.upper(), ['VALID', 'SAME']) + validator.check_value_type('ksize', ksize, [int, tuple], self.name) + validator.check_value_type('strides', strides, [int, tuple], self.name) + self.padding = validator.check_string('padding', padding.upper(), ['VALID', 'SAME'], self.name) self.add_prim_attr("padding", self.padding) self.is_maxpoolgradwithargmax = (self.name == "MaxPoolGradWithArgmax") if not self.is_maxpoolgradwithargmax: self.add_prim_attr('data_format', "NCHW") - if isinstance(ksize, int): - validator.check_integer("ksize", ksize, 1, Rel.GE) - if self.is_maxpoolgradwithargmax: - self.ksize = (1, ksize, ksize, 1) - else: - self.ksize = (1, 1, ksize, ksize) - else: - ksize_error = ValueError(f"The 'ksize' passed to operator {self.name} should be an positive int number" - f"or a tuple of two or four positive int numbers, but got {ksize}") - if len(ksize) != 2 and len(ksize) != 4: - raise ksize_error - for ksize_val in ksize: - if not isinstance(ksize_val, int) or (ksize_val <= 0): - raise ksize_error - if len(ksize) == 2 and self.is_maxpoolgradwithargmax: - self.ksize = (1, ksize[0], ksize[1], 1) - elif len(ksize) == 2 and not self.is_maxpoolgradwithargmax: - self.ksize = (1, 1, ksize[0], ksize[1]) + def _grad_check_int_or_tuple(arg_name, arg_val, is_argmax): + validator.check_value_type(arg_name, arg_val, (int, tuple), self.name) + error_msg = ValueError(f"For '{self.name}' the '{arg_name}' should be an positive int number " + f"or a tuple of two or four positive int numbers, but got {arg_val}") + if isinstance(arg_val, int): + ret = (1, arg_val, arg_val, 1) if is_argmax else (1, 1, arg_val, arg_val) + elif len(arg_val) == 2: + ret = (1, arg_val[0], arg_val[1], 1) if is_argmax else (1, 1, arg_val[0], arg_val[1]) + elif len(arg_val) == 4: + ret = arg_val else: - self.ksize = ksize + raise error_msg + # whether all elements of tuple are positive integers + for item in ret: + if not isinstance(item, int) or item <= 0: + raise error_msg + return ret + + self.ksize = _grad_check_int_or_tuple("ksize", ksize, self.is_maxpoolgradwithargmax) self.add_prim_attr("ksize", self.ksize) - if isinstance(strides, int): - validator.check_integer("strides", strides, 1, Rel.GE) - if self.is_maxpoolgradwithargmax: - self.strides = (1, strides, strides, 1) - else: - self.strides = (1, 1, strides, strides) - else: - strides_error = ValueError(f"The 'strides' passed to operator {self.name} should be an positive int number" - f"or a tuple of two or four positive int numbers, but got {strides}") - if len(strides) != 2 and len(strides) != 4: - raise strides_error - for strides_val in strides: - if not isinstance(strides_val, int) or (strides_val <= 0): - raise strides_error - if len(strides) == 2 and self.is_maxpoolgradwithargmax: - self.strides = (1, strides[0], strides[1], 1) - elif len(strides) == 2 and not self.is_maxpoolgradwithargmax: - self.strides = (1, 1, strides[0], strides[1]) - else: - self.strides = strides + self.strides = _grad_check_int_or_tuple("strides", strides, self.is_maxpoolgradwithargmax) self.add_prim_attr("strides", self.strides) @@ -529,17 +508,17 @@ class L2NormalizeGrad(PrimitiveWithInfer): @prim_attr_register def __init__(self, axis=0, epsilon=1e-4): - validator.check_type('axis', axis, [int]) - validator.check_type('epsilon', epsilon, [int, float]) + validator.check_value_type('axis', axis, [int], self.name) + validator.check_value_type('epsilon', epsilon, [int, float], self.name) def infer_shape(self, input_x, out, dout): - validator.check_param_equal('input_x', input_x, 'out', out) - validator.check_param_equal('input_x', input_x, 'dout', dout) + validator.check('input_x shape', input_x, 'out shape', out, Rel.EQ, self.name) + validator.check('input_x shape', input_x, 'dout shape', dout, Rel.EQ, self.name) return input_x def infer_dtype(self, input_x, out, dout): args = {'input_x': input_x, 'out': out, 'dout': dout} - validator.check_type_same(args, mstype.number_type) + validator.check_tensor_type_same(args, mstype.number_type, self.name) return input_x @@ -560,8 +539,8 @@ class LayerNormGrad(Primitive): @prim_attr_register def __init__(self, begin_norm_axis=1, begin_params_axis=1): """init""" - self.begin_norm_axis = validator.check_type('begin_norm_axis', begin_norm_axis, [int]) - self.begin_params_axis = validator.check_type('begin_params_axis', begin_params_axis, [int]) + self.begin_norm_axis = validator.check_value_type('begin_norm_axis', begin_norm_axis, [int], self.name) + self.begin_params_axis = validator.check_value_type('begin_params_axis', begin_params_axis, [int], self.name) def __call__(self, x, dy, variance, mean, gamma): raise NotImplementedError @@ -573,15 +552,15 @@ class LogSoftmaxGrad(PrimitiveWithInfer): @prim_attr_register def __init__(self, axis=-1): """init LogSoftmaxGrad""" - validator.check_type("axis", axis, [int]) + validator.check_value_type("axis", axis, [int], self.name) def infer_shape(self, dout, logits): rank = len(logits) - validator.check_int_range('axis', self.axis, -rank - 1, rank, Rel.INC_BOTH) + validator.check_int_range('axis', self.axis, -rank - 1, rank, Rel.INC_BOTH, self.name) return logits def infer_dtype(self, dout, logits): - validator.check_subclass("logits", logits, mstype.tensor) + validator.check_subclass("logits", logits, mstype.tensor, self.name) return logits @@ -590,13 +569,13 @@ class LSTMGradData(PrimitiveWithInfer): @prim_attr_register def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout): - self.input_size = check_int_positive(input_size) - self.hidden_size = check_int_positive(hidden_size) - self.num_layers = check_int_positive(num_layers) - self.has_bias = check_bool(has_bias) - self.bidirectional = check_bool(bidirectional) - self.dropout = validator.check_type("dropout", dropout, [float]) - self.dropout = validator.check_number_range('dropout', dropout, 0, 1, Rel.INC_BOTH) + self.input_size = validator.check_integer('input_size', input_size, 0, Rel.GT, self.name) + self.hidden_size = validator.check_integer('hidden_size', hidden_size, 0, Rel.GT, self.name) + self.num_layers = validator.check_integer('num_layers', num_layers, 0, Rel.GT, self.name) + self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name) + self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name) + self.dropout = validator.check_value_type("dropout", dropout, [float], self.name) + self.dropout = validator.check_number_range('dropout', dropout, 0, 1, Rel.INC_BOTH, self.name) if bidirectional: self.num_directions = 2 @@ -606,19 +585,19 @@ class LSTMGradData(PrimitiveWithInfer): def infer_shape(self, y_shape, dy_shape, dhy_shape, dcy_shape, w_shape, hx_shape, cx_shape, reserve_shape, state_shape): # dhy and dcy should be same shape - validator.check_integer("h_shape", len(dhy_shape), 3, Rel.EQ) - validator.check_integer("h_shape", len(dhy_shape), len(dcy_shape), Rel.EQ) - validator.check_integer("h_shape[0]", dhy_shape[0], dcy_shape[0], Rel.EQ) - validator.check_integer("h_shape[1]", dhy_shape[1], dcy_shape[1], Rel.EQ) - validator.check_integer("h_shape[2]", dhy_shape[2], dcy_shape[2], Rel.EQ) + validator.check_integer("h_shape", len(dhy_shape), 3, Rel.EQ, self.name) + validator.check_integer("h_shape", len(dhy_shape), len(dcy_shape), Rel.EQ, self.name) + validator.check_integer("h_shape[0]", dhy_shape[0], dcy_shape[0], Rel.EQ, self.name) + validator.check_integer("h_shape[1]", dhy_shape[1], dcy_shape[1], Rel.EQ, self.name) + validator.check_integer("h_shape[2]", dhy_shape[2], dcy_shape[2], Rel.EQ, self.name) - validator.check_integer("h_shape[0]", dhy_shape[0], self.num_layers * self.num_directions, Rel.EQ) - validator.check_integer("h_shape[2]", dhy_shape[2], self.hidden_size, Rel.EQ) + validator.check_integer("h_shape[0]", dhy_shape[0], self.num_layers * self.num_directions, Rel.EQ, self.name) + validator.check_integer("h_shape[2]", dhy_shape[2], self.hidden_size, Rel.EQ, self.name) # dy: (seq_len, batch_size, hidden_size * num_directions) - validator.check_integer("dy_shape", len(dy_shape), 3, Rel.EQ) - validator.check_integer("dy[1]", dy_shape[1], dhy_shape[1], Rel.EQ) - validator.check_integer("dy[2]", dy_shape[2], self.hidden_size * self.num_directions, Rel.EQ) + validator.check_integer("dy_shape", len(dy_shape), 3, Rel.EQ, self.name) + validator.check_integer("dy[1]", dy_shape[1], dhy_shape[1], Rel.EQ, self.name) + validator.check_integer("dy[2]", dy_shape[2], self.hidden_size * self.num_directions, Rel.EQ, self.name) # (seq_len, batch_size, input_size) dx_shape = (y_shape[0], y_shape[1], self.input_size) @@ -629,11 +608,8 @@ class LSTMGradData(PrimitiveWithInfer): def infer_dtype(self, y_dtype, dy_dtype, dhy_dtype, dcy_dtype, w_dtype, hx_dtype, cx_dtype, reserve_dtype, state_dtype): - validator.check_typename("dy_dtype", dy_dtype, (mstype.float32, mstype.float16)) - validator.check_typename("dhy_dtype", dhy_dtype, (mstype.float32, mstype.float16)) - validator.check_typename("dcy_dtype", dcy_dtype, (mstype.float32, mstype.float16)) - validator.check_typename("datatype", dy_dtype, (dhy_dtype.element_type(),)) - validator.check_typename("datatype", dy_dtype, (dcy_dtype.element_type(),)) + args = {"dy": dy_dtype, "dhy": dhy_dtype, "dcy": dcy_dtype} + validator.check_tensor_type_same(args, (mstype.float32, mstype.float16), self.name) return (dy_dtype, dy_dtype, dy_dtype) @@ -642,13 +618,13 @@ class LSTMGradWeight(PrimitiveWithInfer): @prim_attr_register def __init__(self, input_size, hidden_size, num_layers, has_bias, bidirectional, dropout): - self.input_size = check_int_positive(input_size) - self.hidden_size = check_int_positive(hidden_size) - self.num_layers = check_int_positive(num_layers) - self.has_bias = check_bool(has_bias) - self.bidirectional = check_bool(bidirectional) - self.dropout = validator.check_type("dropout", dropout, [float]) - self.dropout = validator.check_number_range('dropout', dropout, 0, 1, Rel.INC_BOTH) + self.input_size = validator.check_integer('input_size', input_size, 0, Rel.GT, self.name) + self.hidden_size = validator.check_integer('hidden_size', hidden_size, 0, Rel.GT, self.name) + self.num_layers = validator.check_integer('num_layers', num_layers, 0, Rel.GT, self.name) + self.has_bias = validator.check_value_type('has_bias', has_bias, (bool,), self.name) + self.bidirectional = validator.check_value_type('bidirectional', bidirectional, (bool,), self.name) + self.dropout = validator.check_value_type("dropout", dropout, [float], self.name) + self.dropout = validator.check_number_range('dropout', dropout, 0, 1, Rel.INC_BOTH, self.name) if bidirectional: self.num_directions = 2 @@ -693,9 +669,10 @@ class PReLUGrad(PrimitiveWithInfer): return y_backprop_shape, w_shape def infer_dtype(self, y_backprop_dtype, A_dtype, w_dtype): - validator.check_typename("y_backprop_dtype", y_backprop_dtype, (mstype.float16, mstype.float32)) - validator.check_typename("A_dtype", A_dtype, (mstype.float16, mstype.float32)) - validator.check_typename("w_dtype", w_dtype, (mstype.float16, mstype.float32)) + valid_types = (mstype.float16, mstype.float32) + validator.check_tensor_type_same({"y_backprop": y_backprop_dtype}, valid_types, self.name) + validator.check_tensor_type_same({"A_dtype": A_dtype}, valid_types, self.name) + validator.check_tensor_type_same({"w_dtype": w_dtype}, valid_types, self.name) return y_backprop_dtype, w_dtype @@ -725,8 +702,8 @@ class ReLU6Grad(PrimitiveWithInfer): return x_shape def infer_dtype(self, y_grad_dtype, x_dtype): - validator.check_typename("y_grad_dtype", y_grad_dtype, (mstype.float16, mstype.float32)) - validator.check_typename("x_dtype", x_dtype, (mstype.float16, mstype.float32)) + validator.check_tensor_type_same({"y_grad": y_grad_dtype}, (mstype.float16, mstype.float32), self.name) + validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name) return x_dtype @@ -744,10 +721,8 @@ class ReluGradV2(PrimitiveWithInfer): return gradients_shape def infer_dtype(self, gradients_dtype, mask_dtype): - args_type = {'gradients': gradients_dtype, 'mask': mask_dtype} - validator.check_args_tensor(args_type) - validator.check_typename("gradients_dtype", gradients_dtype, mstype.number_type) - validator.check_typename("mask_dtype", mask_dtype, (mstype.uint8,)) + validator.check_tensor_type_same({'gradients': gradients_dtype}, mstype.number_type, self.name) + validator.check_tensor_type_same({'mask': mask_dtype}, (mstype.uint8,), self.name) return gradients_dtype @@ -762,10 +737,8 @@ class EluGrad(PrimitiveWithInfer): return x_shape def infer_dtype(self, y_grad_dtype, x_dtype): - args_type = {'y_grad': y_grad_dtype, 'x': x_dtype} - validator.check_args_tensor(args_type) - args_dtype = {'y_grad_dtype': y_grad_dtype, 'x_dtype': x_dtype} - validator.check_type_same(args_dtype, mstype.float_type) + args = {'y_grad': y_grad_dtype, 'x': x_dtype} + validator.check_tensor_type_same(args, mstype.float_type, self.name) return x_dtype @@ -821,11 +794,11 @@ class ROIAlignGrad(PrimitiveWithInfer): @prim_attr_register def __init__(self, xdiff_shape, pooled_height, pooled_width, spatial_scale, sample_num=2): """init ROIAlignGrad""" - validator.check_type("pooled_height", pooled_height, [int]) - validator.check_type("pooled_width", pooled_width, [int]) - validator.check_type("spatial_scale", spatial_scale, [float]) - validator.check_type("sample_num", sample_num, [int]) - validator.check_type("xdiff_shape", xdiff_shape, [tuple]) + validator.check_value_type("pooled_height", pooled_height, [int], self.name) + validator.check_value_type("pooled_width", pooled_width, [int], self.name) + validator.check_value_type("spatial_scale", spatial_scale, [float], self.name) + validator.check_value_type("sample_num", sample_num, [int], self.name) + validator.check_value_type("xdiff_shape", xdiff_shape, [tuple], self.name) self.xdiff_shape = xdiff_shape self.pooled_height = pooled_height self.pooled_width = pooled_width @@ -850,10 +823,8 @@ class SigmoidGrad(PrimitiveWithInfer): return out def infer_dtype(self, out, dout): - validator.check_typename("dout dtype", dout, (mstype.float16, mstype.float32)) - validator.check_typename("out dtype", out, (mstype.float16, mstype.float32)) - args = {"out type": out, "dout type": dout} - validator.check_type_same(args, mstype.number_type) + args = {'out': out, 'dout': dout} + validator.check_tensor_type_same(args, mstype.number_type, self.name) return out @@ -868,8 +839,8 @@ class HSigmoidGrad(PrimitiveWithInfer): return x_shape def infer_dtype(self, y_grad_dtype, x_dtype): - validator.check_typename("y_grad dtype", y_grad_dtype, (mstype.float16, mstype.float32)) - validator.check_typename("x dtype", x_dtype, (mstype.float16, mstype.float32)) + validator.check_tensor_type_same({"y_grad": y_grad_dtype}, (mstype.float16, mstype.float32), self.name) + validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name) return x_dtype @@ -884,8 +855,8 @@ class HSwishGrad(PrimitiveWithInfer): return x_shape def infer_dtype(self, y_grad_dtype, x_dtype): - validator.check_typename("y_grad dtype", y_grad_dtype, (mstype.float16, mstype.float32)) - validator.check_typename("x_ dtype", x_dtype, (mstype.float16, mstype.float32)) + validator.check_tensor_type_same({"y_grad": y_grad_dtype}, (mstype.float16, mstype.float32), self.name) + validator.check_tensor_type_same({"x": x_dtype}, (mstype.float16, mstype.float32), self.name) return x_dtype @@ -898,13 +869,13 @@ class SigmoidCrossEntropyWithLogitsGrad(PrimitiveWithInfer): self.init_prim_io_names(inputs=['x', 'y', 'dout'], outputs=['x_grad']) def infer_shape(self, x_shape, y_shape, dout_shape): - validator.check_param_equal("x_shape", x_shape, "y_shape", y_shape) - validator.check_param_equal("x_shape", x_shape, "dout_shape", dout_shape) + validator.check("x_shape", x_shape, "y_shape", y_shape, Rel.EQ, self.name) + validator.check("x_shape", x_shape, "dout_shape", dout_shape, Rel.EQ, self.name) return x_shape def infer_dtype(self, x_dtype, y_dtype, dout_dtype): args = {"x_dtype": x_dtype, "y_dtype": y_dtype, 'dout_dtype': dout_dtype} - validator.check_type_same(args, mstype.number_type) + validator.check_tensor_type_same(args, mstype.number_type, self.name) return dout_dtype @@ -920,8 +891,8 @@ class SliceGrad(PrimitiveWithInfer): dy_shape, x_shape, size_value = dy['shape'], x['shape'], size['value'] dy_shape_len = len(dy_shape) for i in range(dy_shape_len): - validator.check(f'dy_shape[{i}]', dy_shape[i], f'x_shape[{i}]', x_shape[i], Rel.LE) - validator.check(f'dy_shape[{i}]', dy_shape[i], f'size_shape[{i}]', size_value[i], Rel.EQ) + validator.check(f'dy_shape[{i}]', dy_shape[i], f'x_shape[{i}]', x_shape[i], Rel.LE, self.name) + validator.check(f'dy_shape[{i}]', dy_shape[i], f'size_shape[{i}]', size_value[i], Rel.EQ, self.name) return {'shape': x_shape, 'dtype': x['dtype'], 'value': None} @@ -935,13 +906,13 @@ class SmoothL1LossGrad(PrimitiveWithInfer): pass def infer_shape(self, prediction, target, dloss): - validator.check_param_equal('prediction', prediction, 'target', target) - validator.check_param_equal('prediction', prediction, 'dloss', dloss) + validator.check('prediction shape', prediction, 'target shape', target, Rel.EQ, self.name) + validator.check('prediction shape', prediction, 'dloss shape', dloss, Rel.EQ, self.name) return prediction def infer_dtype(self, prediction, target, dloss): args = {"prediction": prediction, "target": target, 'dloss': dloss} - validator.check_type_same(args, mstype.number_type) + validator.check_tensor_type_same(args, mstype.number_type, self.name) return dloss @@ -968,11 +939,11 @@ class StridedSliceGrad(PrimitiveWithInfer): new_axis_mask=0, shrink_axis_mask=0): """init StrideSliceGrad""" - validator.check_type('begin_mask', begin_mask, [int]) - validator.check_type('end_mask', end_mask, [int]) - validator.check_type('ellipsis_mask', ellipsis_mask, [int]) - validator.check_type('new_axis_mask', new_axis_mask, [int]) - validator.check_type('shrink_axis_mask', shrink_axis_mask, [int]) + validator.check_value_type('begin_mask', begin_mask, [int], self.name) + validator.check_value_type('end_mask', end_mask, [int], self.name) + validator.check_value_type('ellipsis_mask', ellipsis_mask, [int], self.name) + validator.check_value_type('new_axis_mask', new_axis_mask, [int], self.name) + validator.check_value_type('shrink_axis_mask', shrink_axis_mask, [int], self.name) self.init_prim_io_names(inputs=['dy', 'shapex', 'begin', 'end', 'strides'], outputs=['output']) def __infer__(self, dy, shapex, begin, end, strides): @@ -992,10 +963,8 @@ class TanhGrad(PrimitiveWithInfer): return out def infer_dtype(self, out, dout): - validator.check_subclass("out", out, mstype.tensor) - validator.check_subclass("dout", dout, mstype.tensor) - args = {"out type": out, "dout type": dout} - validator.check_type_same(args, mstype.number_type) + args = {"out": out, "dout": dout} + validator.check_tensor_type_same(args, mstype.number_type, self.name) return out @@ -1005,13 +974,13 @@ class MirrorPadGrad(PrimitiveWithInfer): @prim_attr_register def __init__(self, mode="REFLECT"): """init MirrorPad""" - validator.check_string('mode', mode, ['REFLECT', 'SYMMETRIC']) + validator.check_string('mode', mode, ['REFLECT', 'SYMMETRIC'], self.name) self.mode = mode def __infer__(self, dout, paddings, x): - validator.check_subclass("dout", dout['dtype'], mstype.tensor) - validator.check_subclass("paddings", paddings['dtype'], mstype.tensor) - validator.check_subclass("input_x", x['dtype'], mstype.tensor) + validator.check_subclass("dout", dout['dtype'], mstype.tensor, self.name) + validator.check_subclass("paddings", paddings['dtype'], mstype.tensor, self.name) + validator.check_subclass("input_x", x['dtype'], mstype.tensor, self.name) return {'shape': x['shape'], 'dtype': dout['dtype'], 'value': None} From 2f89b75b2d5bc0485910ffdd19ef7bc613172b88 Mon Sep 17 00:00:00 2001 From: zhaozhenlong Date: Fri, 24 Apr 2020 15:02:24 +0800 Subject: [PATCH 133/142] remove batchnorm grad 6 output to 5 output --- mindspore/ops/operations/_grad_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py index 07857ca27b..782784ca00 100644 --- a/mindspore/ops/operations/_grad_ops.py +++ b/mindspore/ops/operations/_grad_ops.py @@ -68,11 +68,11 @@ class BatchNormGrad(PrimitiveWithInfer): self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, self.name) self.add_prim_attr('data_format', "NCHW") - def infer_shape(self, y_backprop_shape, x_shape, scale_shape, reserve_1_shape, reserve_2_shape, reserve_3_shape): + def infer_shape(self, y_backprop_shape, x_shape, scale_shape, reserve_1_shape, reserve_2_shape): validator.check("BatchNorm y_backprop_shape", y_backprop_shape, "BatchNorm x_shape", x_shape) return (x_shape, scale_shape, scale_shape, reserve_1_shape, reserve_2_shape) - def infer_dtype(self, y_backprop_type, x_type, scale_type, reserve_1_type, reserve_2_type, reserve_3_type): + def infer_dtype(self, y_backprop_type, x_type, scale_type, reserve_1_type, reserve_2_type): return (x_type, scale_type, scale_type, reserve_1_type, reserve_2_type) From 4c37420890e3a1efca34838a8c8fb15fca738ffa Mon Sep 17 00:00:00 2001 From: zhaozhenlong Date: Fri, 24 Apr 2020 15:13:21 +0800 Subject: [PATCH 134/142] make conv2d bp filter stride attr len 4 --- mindspore/ops/operations/_grad_ops.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py index 782784ca00..c821063da8 100644 --- a/mindspore/ops/operations/_grad_ops.py +++ b/mindspore/ops/operations/_grad_ops.py @@ -174,9 +174,9 @@ class Conv2DBackpropFilter(PrimitiveWithInfer): pad_mode = pad_mode.upper() self.add_prim_attr('pad_mode', pad_mode) self.pad = pad - if isinstance(stride, tuple) and len(stride) == 4: - self.stride = (stride[2], stride[3]) - self.add_prim_attr('stride', self.stride) + if isinstance(stride, tuple) and len(stride) == 2: + self.stride = stride + self.add_prim_attr('stride', (1, 1, self.stride[0], self.stride[1])) self.dilation = dilation self.group = group self.add_prim_attr('data_format', "NCHW") From c547f434a1be28c8487b951b0da6b5c65fc9576b Mon Sep 17 00:00:00 2001 From: zhaozhenlong Date: Fri, 24 Apr 2020 15:22:58 +0800 Subject: [PATCH 135/142] conv2d bp input stride len 2 to len 4 --- mindspore/ops/operations/nn_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 9750549dc5..78a512fab7 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -1084,7 +1084,7 @@ class Conv2DBackpropInput(PrimitiveWithInfer): self.init_prim_io_names(inputs=['out_backprop', 'filter', 'input_sizes'], outputs=['output']) self.out_channel = validator.check_integer('out_channel', out_channel, 0, Rel.GT, self.name) self.kernel_size = _check_positive_int_or_tuple('kernel_size', kernel_size, self.name) - self.stride = _check_positive_int_or_tuple('stride', stride, self.name, allow_four=True, ret_four=False) + self.stride = _check_positive_int_or_tuple('stride', stride, self.name, allow_four=True, ret_four=True) self.add_prim_attr('stride', self.stride) self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name, allow_four=True, ret_four=True) self.add_prim_attr('dilation', self.dilation) From 0e98430f4d5b4b06cd19b73fa9a1abfc0253ede2 Mon Sep 17 00:00:00 2001 From: zhaozhenlong Date: Fri, 24 Apr 2020 15:55:58 +0800 Subject: [PATCH 136/142] get concat offset refactor --- mindspore/ops/_utils/utils.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/mindspore/ops/_utils/utils.py b/mindspore/ops/_utils/utils.py index fbd81c4f0d..90496afc9b 100644 --- a/mindspore/ops/_utils/utils.py +++ b/mindspore/ops/_utils/utils.py @@ -15,7 +15,7 @@ """utils for operator""" -from ..._checkparam import ParamValidator as validator +from ..._checkparam import Validator as validator from ..._checkparam import Rel from ...common import dtype as mstype @@ -62,25 +62,25 @@ def _get_broadcast_shape(x_shape, y_shape, prim_name): return broadcast_shape -def _get_concat_offset(x_shp, x_type, axis): +def _get_concat_offset(x_shp, x_type, axis, prim_name): """for concat and concatoffset check args and compute offset""" - validator.check_type("shape", x_shp, [tuple]) - validator.check_integer("len of input_x shape", len(x_shp), 0, Rel.GT) - validator.check_subclass("shape0", x_type[0], mstype.tensor) - validator.check_integer("len of input_x0 shape", len(x_shp[0]), 0, Rel.GT) + validator.check_value_type("shape", x_shp, [tuple], prim_name) + validator.check_integer("input_x rank", len(x_shp), 0, Rel.GT, prim_name) + validator.check_subclass("shape0", x_type[0], mstype.tensor, prim_name) + validator.check_integer("len of x_shp[0]", len(x_shp[0]), 0, Rel.GT, prim_name) rank_base = len(x_shp[0]) - validator.check_int_range('axis', axis, -rank_base - 1, rank_base, Rel.INC_BOTH) + validator.check_int_range('axis', axis, -rank_base - 1, rank_base, Rel.INC_BOTH, prim_name) if axis < 0: axis = axis + rank_base all_shp = x_shp[0][axis] offset = [0,] for i in range(1, len(x_shp)): v = x_shp[i] - validator.check('len of x_shp[%d]' % i, len(v), 'len of base', len(x_shp[0])) - validator.check('x_type[%d]' % i, x_type[i], 'base', x_type[0]) + validator.check('len of x_shp[%d]' % i, len(v), 'len of x_shp[0]', len(x_shp[0]), Rel.EQ, prim_name) + validator.check('x_type[%d]' % i, x_type[i], 'x_type[0]', x_type[0], Rel.EQ, prim_name) for j in range(rank_base): if j != axis and v[j] != x_shp[0][j]: - raise ValueError("Concat evaluator element %d shape in input can not concat with first element" % i) + raise ValueError(f"For \'{prim_name}\' element {i} shape in input can not concat with first element") offset.append(all_shp) all_shp += v[axis] return offset, all_shp, axis From cbb4136b626ef048c6c17e38890d7018896e4ed2 Mon Sep 17 00:00:00 2001 From: zhaozhenlong Date: Fri, 24 Apr 2020 16:02:38 +0800 Subject: [PATCH 137/142] concat offset update --- mindspore/ops/operations/array_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index 46239855f2..21dbf81730 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -1316,7 +1316,7 @@ class Concat(PrimitiveWithInfer): axis = self.axis x_shp = input_x['shape'] x_type = input_x['dtype'] - _, all_shp, _ = _get_concat_offset(x_shp, x_type, axis) + _, all_shp, _ = _get_concat_offset(x_shp, x_type, axis, self.name) self.add_prim_attr('T', x_type[0].element_type()) self.add_prim_attr('inputNums', len(x_shp)) ret_shp = x_shp[0].copy() From 3209687b9a1c569177027ce4db9384d5249dc5f7 Mon Sep 17 00:00:00 2001 From: zhaozhenlong Date: Fri, 24 Apr 2020 17:07:29 +0800 Subject: [PATCH 138/142] conv2d bp filter input primitive stride 2 and adapt to 4 --- mindspore/ccsrc/transform/op_adapter.h | 2 +- mindspore/ccsrc/transform/op_adapter_util.cc | 14 ++++++++++---- mindspore/ccsrc/transform/op_declare.cc | 4 ++-- mindspore/ops/operations/_grad_ops.py | 6 +++--- mindspore/ops/operations/nn_ops.py | 2 +- 5 files changed, 17 insertions(+), 11 deletions(-) diff --git a/mindspore/ccsrc/transform/op_adapter.h b/mindspore/ccsrc/transform/op_adapter.h index 2039dfa7d6..ae678606a4 100644 --- a/mindspore/ccsrc/transform/op_adapter.h +++ b/mindspore/ccsrc/transform/op_adapter.h @@ -736,7 +736,7 @@ class OpAdapter : public BaseOpAdapter { return static_cast(GetValue(value)); } - // specialization for int to Vector + // specialization for int or tuple broadcast to Vector static std::vector ConvertAny(const ValuePtr &value, const std::string &name, const AnyTraits> anyTraitsInt) { return ConvertAnyUtil(value, name, anyTraitsInt); diff --git a/mindspore/ccsrc/transform/op_adapter_util.cc b/mindspore/ccsrc/transform/op_adapter_util.cc index 0163b80f08..0d9e56e510 100644 --- a/mindspore/ccsrc/transform/op_adapter_util.cc +++ b/mindspore/ccsrc/transform/op_adapter_util.cc @@ -35,14 +35,20 @@ GeTensor ConvertAnyUtil(const ValuePtr &value, const AnyTraits ConvertAnyUtil(const ValuePtr &value, const std::string &name, const AnyTraits>) { + MS_EXCEPTION_IF_NULL(value); int64_t data = GetValue(value); std::vector list; int size = 2; // 2 int in list if (name == "pad") { - size = 4; // 4 int in list - list = TransformUtil::ConvertIntToList(data, size); - list[0] = 1; - list[1] = 1; + if (!value->isa()) { + MS_LOG(EXCEPTION) << "Value should be ValueTuple, but got" << value->type_name(); + } + auto vec = value->cast(); + list.push_back(1); + list.push_back(1); + for (auto &it : vec->value()) { + list.push_back(static_cast(GetValue(it))); + } } else { list = TransformUtil::ConvertIntToList(data, size); } diff --git a/mindspore/ccsrc/transform/op_declare.cc b/mindspore/ccsrc/transform/op_declare.cc index 377403cc89..5ec54b2037 100644 --- a/mindspore/ccsrc/transform/op_declare.cc +++ b/mindspore/ccsrc/transform/op_declare.cc @@ -733,7 +733,7 @@ INPUT_ATTR_MAP(Conv2DBackpropInputD) = { {3, ATTR_DESC(input_size, AnyTraits>(), AnyTraits>())}}; ATTR_MAP(Conv2DBackpropInputD) = { {"pad_list", ATTR_DESC(pads, AnyTraits>(), AnyTraits>())}, - {"stride", ATTR_DESC(strides, AnyTraits>(), AnyTraits>())}, + {"stride", ATTR_DESC(strides, "pad", AnyTraits>())}, {"dilation", ATTR_DESC(dilations, AnyTraits>(), AnyTraits>())}, {"data_format", ATTR_DESC(data_format, AnyTraits())}, {"group", ATTR_DESC(groups, AnyTraits())}, @@ -746,7 +746,7 @@ INPUT_ATTR_MAP(Conv2DBackpropFilterD) = { {3, ATTR_DESC(filter_size, AnyTraits>(), AnyTraits>())}}; ATTR_MAP(Conv2DBackpropFilterD) = { {"pad_list", ATTR_DESC(pads, AnyTraits>(), AnyTraits>())}, - {"stride", ATTR_DESC(strides, AnyTraits>(), AnyTraits>())}, + {"stride", ATTR_DESC(strides, "pad", AnyTraits>())}, {"dilation", ATTR_DESC(dilations, AnyTraits>(), AnyTraits>())}, {"data_format", ATTR_DESC(data_format, AnyTraits())}, {"group", ATTR_DESC(groups, AnyTraits())}, diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py index c821063da8..782784ca00 100644 --- a/mindspore/ops/operations/_grad_ops.py +++ b/mindspore/ops/operations/_grad_ops.py @@ -174,9 +174,9 @@ class Conv2DBackpropFilter(PrimitiveWithInfer): pad_mode = pad_mode.upper() self.add_prim_attr('pad_mode', pad_mode) self.pad = pad - if isinstance(stride, tuple) and len(stride) == 2: - self.stride = stride - self.add_prim_attr('stride', (1, 1, self.stride[0], self.stride[1])) + if isinstance(stride, tuple) and len(stride) == 4: + self.stride = (stride[2], stride[3]) + self.add_prim_attr('stride', self.stride) self.dilation = dilation self.group = group self.add_prim_attr('data_format', "NCHW") diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 78a512fab7..9750549dc5 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -1084,7 +1084,7 @@ class Conv2DBackpropInput(PrimitiveWithInfer): self.init_prim_io_names(inputs=['out_backprop', 'filter', 'input_sizes'], outputs=['output']) self.out_channel = validator.check_integer('out_channel', out_channel, 0, Rel.GT, self.name) self.kernel_size = _check_positive_int_or_tuple('kernel_size', kernel_size, self.name) - self.stride = _check_positive_int_or_tuple('stride', stride, self.name, allow_four=True, ret_four=True) + self.stride = _check_positive_int_or_tuple('stride', stride, self.name, allow_four=True, ret_four=False) self.add_prim_attr('stride', self.stride) self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name, allow_four=True, ret_four=True) self.add_prim_attr('dilation', self.dilation) From ba4a34853103b81bef096a0956f4f74ce3e3429b Mon Sep 17 00:00:00 2001 From: zhaozhenlong Date: Fri, 24 Apr 2020 17:35:17 +0800 Subject: [PATCH 139/142] fix Cppcheck --- mindspore/ccsrc/transform/op_adapter_util.cc | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/mindspore/ccsrc/transform/op_adapter_util.cc b/mindspore/ccsrc/transform/op_adapter_util.cc index 0d9e56e510..07266c9eb2 100644 --- a/mindspore/ccsrc/transform/op_adapter_util.cc +++ b/mindspore/ccsrc/transform/op_adapter_util.cc @@ -38,18 +38,18 @@ std::vector ConvertAnyUtil(const ValuePtr &value, const std::string &na MS_EXCEPTION_IF_NULL(value); int64_t data = GetValue(value); std::vector list; - int size = 2; // 2 int in list if (name == "pad") { if (!value->isa()) { MS_LOG(EXCEPTION) << "Value should be ValueTuple, but got" << value->type_name(); } auto vec = value->cast(); - list.push_back(1); - list.push_back(1); - for (auto &it : vec->value()) { - list.push_back(static_cast(GetValue(it))); - } + list.resize(vec->value().size()+2); + list[0]=1; + list[1]=1; + (void)std::transform(vec->value().begin(), vec->value().end(), list.begin()+2, + [](const ValuePtr &val) { return static_cast(GetValue(val)); }); } else { + int size = 2; // 2 int in list list = TransformUtil::ConvertIntToList(data, size); } From 17de1b789012c83576f33b409aea33a49781d558 Mon Sep 17 00:00:00 2001 From: zhaozhenlong Date: Fri, 24 Apr 2020 17:40:24 +0800 Subject: [PATCH 140/142] fix Cppcheck --- mindspore/ccsrc/transform/op_adapter_util.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mindspore/ccsrc/transform/op_adapter_util.cc b/mindspore/ccsrc/transform/op_adapter_util.cc index 07266c9eb2..2f5f3ddc7e 100644 --- a/mindspore/ccsrc/transform/op_adapter_util.cc +++ b/mindspore/ccsrc/transform/op_adapter_util.cc @@ -44,8 +44,8 @@ std::vector ConvertAnyUtil(const ValuePtr &value, const std::string &na } auto vec = value->cast(); list.resize(vec->value().size()+2); - list[0]=1; - list[1]=1; + list[0] = 1; + list[1] = 1; (void)std::transform(vec->value().begin(), vec->value().end(), list.begin()+2, [](const ValuePtr &val) { return static_cast(GetValue(val)); }); } else { From de6ca6bd2f4c28e5f875f0e72f88afa9c2b2e4ed Mon Sep 17 00:00:00 2001 From: zhaozhenlong Date: Fri, 24 Apr 2020 17:45:40 +0800 Subject: [PATCH 141/142] fix convert value tuple --- mindspore/ccsrc/transform/op_adapter_util.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mindspore/ccsrc/transform/op_adapter_util.cc b/mindspore/ccsrc/transform/op_adapter_util.cc index 2f5f3ddc7e..203acac10f 100644 --- a/mindspore/ccsrc/transform/op_adapter_util.cc +++ b/mindspore/ccsrc/transform/op_adapter_util.cc @@ -36,7 +36,6 @@ GeTensor ConvertAnyUtil(const ValuePtr &value, const AnyTraits ConvertAnyUtil(const ValuePtr &value, const std::string &name, const AnyTraits>) { MS_EXCEPTION_IF_NULL(value); - int64_t data = GetValue(value); std::vector list; if (name == "pad") { if (!value->isa()) { @@ -49,6 +48,7 @@ std::vector ConvertAnyUtil(const ValuePtr &value, const std::string &na (void)std::transform(vec->value().begin(), vec->value().end(), list.begin()+2, [](const ValuePtr &val) { return static_cast(GetValue(val)); }); } else { + int64_t data = GetValue(value); int size = 2; // 2 int in list list = TransformUtil::ConvertIntToList(data, size); } From 89f82e1a3799b9474f1901f99ccd0affeb1841aa Mon Sep 17 00:00:00 2001 From: zhaozhenlong Date: Fri, 24 Apr 2020 19:04:40 +0800 Subject: [PATCH 142/142] fix cpp ut test TestConvertConvBackpropFilter --- tests/ut/cpp/transform/convert_test.cc | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/ut/cpp/transform/convert_test.cc b/tests/ut/cpp/transform/convert_test.cc index 4388312592..277aaa15c3 100644 --- a/tests/ut/cpp/transform/convert_test.cc +++ b/tests/ut/cpp/transform/convert_test.cc @@ -189,7 +189,8 @@ TEST_F(TestConvert, TestConvertBatchNorm) { TEST_F(TestConvert, TestConvertConvBackpropInput) { auto prim = prim::kPrimConv2DBackpropInput; - prim->AddAttr("stride", MakeValue(1)); + const std::vector list{1,1}; + prim->AddAttr("stride", MakeValue(list)); prim->AddAttr("pad", MakeValue(0)); prim->AddAttr("pad_mode", MakeValue(std::string("pad"))); prim->AddAttr("dilation", MakeValue(1)); @@ -218,7 +219,8 @@ TEST_F(TestConvert, TestConvertConvBackpropInput) { TEST_F(TestConvert, TestConvertConvBackpropFilter) { auto prim = prim::kPrimConv2DBackpropFilter; - prim->AddAttr("stride", MakeValue(1)); + const std::vector list{1,1}; + prim->AddAttr("stride", MakeValue(list)); prim->AddAttr("pad", MakeValue(0)); prim->AddAttr("pad_mode", MakeValue(std::string("pad"))); prim->AddAttr("dilation", MakeValue(1));