Merge pull request !1402 from zongha/master
Tag: v0.3.0-alpha
@@ -1,126 +0,0 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-"""learning rate generator"""
-import math
-import numpy as np
-
-
-def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr):
-    """linear_warmup_lr"""
-    lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
-    lr = float(init_lr) + lr_inc * current_step
-    return lr
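Reviewer note: for intuition on the warmup ramp above, a quick numeric check (editor's sketch with illustrative values, not part of the patch):

# With base_lr=0.05, init_lr=0 and warmup_steps=100, the ramp is linear:
assert abs(linear_warmup_lr(50, 100, 0.05, 0.0) - 0.025) < 1e-9   # halfway point
assert abs(linear_warmup_lr(100, 100, 0.05, 0.0) - 0.05) < 1e-9   # reaches base_lr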
-def cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0, num_periods=0.5):
-    """cosine_annealing_lr"""
-    base_lr = lr
-    warmup_init_lr = 0
-    total_steps = int(max_epoch * steps_per_epoch)
-    warmup_steps = int(warmup_epochs * steps_per_epoch)
-    decay_steps = total_steps - warmup_steps
-    lr_each_step = []
-    for i in range(total_steps):
-        if i < warmup_steps:
-            lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
-        else:
-            # linear_decay = (total_steps - i) / decay_steps
-            cosine_decay = 0.5 * (1 + math.cos(math.pi * i / decay_steps))
-            decayed = cosine_decay
-            lr = base_lr * decayed
-        lr_each_step.append(lr)
-    return np.array(lr_each_step).astype(np.float32)
-def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0, num_periods=0.5):
-    """warmup_cosine_annealing_lr"""
-    base_lr = lr
-    warmup_init_lr = 0
-    total_steps = int(max_epoch * steps_per_epoch * 0.99)
-    warmup_steps = int(warmup_epochs * steps_per_epoch)
-    decay_steps = total_steps - warmup_steps
-    lr_each_step = []
-    for i in range(total_steps):
-        if i < warmup_steps:
-            lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
-        else:
-            linear_decay = (total_steps - i) / decay_steps
-            cosine_decay = 0.5 * (1 + math.cos(math.pi * 2 * num_periods * i / decay_steps))
-            decayed = linear_decay * cosine_decay
-            lr = base_lr * decayed + 0.000005
-        lr_each_step.append(lr)
-    return np.array(lr_each_step).astype(np.float32)
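Reviewer note: a minimal sanity check of the schedule being deleted here (editor's sketch, illustrative numbers; note the int(... * 0.99) truncation of total_steps above):

lrs = warmup_cosine_annealing_lr(lr=0.035, steps_per_epoch=100, warmup_epochs=2,
                                 max_epoch=10, T_max=10)
print(len(lrs))                      # 990 steps, because of the 0.99 truncation
print(lrs[199], lrs[200], lrs[-1])   # warmup peak (0.035), then linear*cosine decay toward the 5e-6 offset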
-def get_lr(global_step, lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch, lr_decay_mode):
-    """
-    generate learning rate array
-
-    Args:
-        global_step(int): current step of the training
-        lr_init(float): init learning rate
-        lr_end(float): end learning rate
-        lr_max(float): max learning rate
-        warmup_epochs(int): number of warmup epochs
-        total_epochs(int): total epochs of training
-        steps_per_epoch(int): steps of one epoch
-        lr_decay_mode(string): learning rate decay mode, including steps, poly or default
-
-    Returns:
-        np.array, learning rate array
-    """
-    lr_each_step = []
-    total_steps = steps_per_epoch * total_epochs
-    warmup_steps = steps_per_epoch * warmup_epochs
-    if lr_decay_mode == 'steps':
-        decay_epoch_index = [0.3 * total_steps, 0.6 * total_steps, 0.8 * total_steps]
-        for i in range(total_steps):
-            if i < decay_epoch_index[0]:
-                lr = lr_max
-            elif i < decay_epoch_index[1]:
-                lr = lr_max * 0.1
-            elif i < decay_epoch_index[2]:
-                lr = lr_max * 0.01
-            else:
-                lr = lr_max * 0.001
-            lr_each_step.append(lr)
-    elif lr_decay_mode == 'poly':
-        if warmup_steps != 0:
-            inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps)
-        else:
-            inc_each_step = 0
-        for i in range(total_steps):
-            if i < warmup_steps:
-                lr = float(lr_init) + inc_each_step * float(i)
-            else:
-                base = (1.0 - (float(i) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps)))
-                lr = float(lr_max) * base * base
-                if lr < 0.0:
-                    lr = 0.0
-            lr_each_step.append(lr)
-    else:
-        for i in range(total_steps):
-            if i < warmup_steps:
-                lr = lr_init + (lr_max - lr_init) * i / warmup_steps
-            else:
-                lr = lr_max - (lr_max - lr_end) * (i - warmup_steps) / (total_steps - warmup_steps)
-            lr_each_step.append(lr)
-
-    current_step = global_step
-    lr_each_step = np.array(lr_each_step).astype(np.float32)
-    learning_rate = lr_each_step[current_step:]
-
-    return learning_rate
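Reviewer note: usage sketch for the removed get_lr (editor's addition; the argument values are illustrative):

lr_array = get_lr(global_step=0, lr_init=0.01, lr_end=0.0, lr_max=0.1,
                  warmup_epochs=5, total_epochs=90, steps_per_epoch=100,
                  lr_decay_mode='poly')
print(lr_array.shape)            # (9000,)
print(lr_array[0], lr_array[-1]) # warmup start (0.01), end of the quadratic-poly decay (~0.0)

Resuming from a checkpoint just passes a non-zero global_step, which slices the precomputed array.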
@@ -13,12 +13,10 @@
 # limitations under the License.
 # ============================================================================
 """Dataset help for minddata dataset"""
-from mindspore import context
 from mindspore._checkparam import check_bool
-from mindspore.nn.wrap import GetNextSingleOp
-from mindspore.parallel._utils import _get_device_num, _get_global_rank, _get_parallel_mode
-from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes, _to_tensor, \
-    _construct_tensor_list, _to_full_shapes, _to_full_tensor
+from mindspore.parallel._utils import _get_device_num, _get_parallel_mode
+from mindspore.train._utils import _exec_datagraph, _get_types_and_shapes, \
+    _to_full_shapes
 from mindspore.train.parallel_utils import ParallelMode
@@ -42,19 +40,9 @@ class DatasetHelper:
        >>> outputs = network(*inputs)
    """
-    def __init__(self, dataset, first_order_iter=0, dataset_sink_mode=True):
+    def __init__(self, dataset, dataset_sink_mode=True, iter_first_order=0):
        check_bool(dataset_sink_mode)
-        iterclass = _DatasetIterGE
-        if not dataset_sink_mode:
-            iterclass = _DatasetIterFeed
-        elif not context.get_context("enable_ge"):
-            if context.get_context("enable_loop_sink"):
-                iterclass = _DatasetIterMSLoopSink
-            else:
-                iterclass = _DatasetIterMS
-        self.iter = iterclass(dataset, first_order_iter)
+        self.iter = _DatasetIterMSLoopSink(dataset, iter_first_order)

    def __iter__(self):
        return self.iter.__iter__()
@@ -85,12 +73,6 @@ class _DatasetIter:
        self.dataset = dataset
        dataset_types, dataset_shapes = _get_types_and_shapes(dataset)
        self.dataset_types, self.dataset_shapes = dataset_types, dataset_shapes
-        # for self._parallel_mode equal to semi_auto_parallel or auto_parallel, use a complete tensor to
-        # compile, and slice tensor to run. The batch dimension of tensors for compile is device_number
-        # times the batch dimension of tensors for run
-        if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
-            device_num = _get_device_num()
-            self.dataset_shapes = _to_full_shapes(dataset_shapes, device_num)

    def __iter__(self):
        self.ind = 0
@@ -109,83 +91,28 @@ class _DatasetIter:
        loop_count = 1
        if hasattr(dataset, '__loop_size__'):
            loop_size = dataset.__loop_size__
-            if dataset.get_dataset_size() % loop_size != 0:
-                raise ValueError(f'Dataset size {dataset.get_dataset_size()} and '
-                                 f'loop_size {loop_size} are not matched.')
            loop_count = int(dataset.get_dataset_size() / loop_size)
        return loop_count


 class _DatasetIterMSLoopSink(_DatasetIter):
-    """Iter for context (enable_loop_sink=True)"""
-    def __init__(self, dataset, first_order_iter):
+    """Iter for context (device_target=Ascend)"""
+    def __init__(self, dataset, iter_first_order):
        super(_DatasetIterMSLoopSink, self).__init__(dataset)
-        # self.loop_count = self.get_loop_count(dataset)
-        loop_size = dataset.__loop_size__ + first_order_iter
+        loop_size = dataset.__loop_size__ + iter_first_order
        self.loop_count = int(dataset.get_dataset_size() / loop_size) * 2
+        # for self._parallel_mode equal to semi_auto_parallel or auto_parallel, use a complete tensor to
+        # compile, and slice tensor to run. The batch dimension of tensors for compile is device_number
+        # times the batch dimension of tensors for run. Now only support LoopSink.
+        if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
+            device_num = _get_device_num()
+            self.dataset_shapes = _to_full_shapes(self.dataset_shapes, device_num)

        def op():
            return tuple()

        self.op = op
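Reviewer note: the loop arithmetic above, with the numbers this example actually uses (5004 steps per epoch, THOR frequency 278), works out as follows (editor's sketch, not part of the patch):

dataset_size = 5004
loop_size = 1 + 277                  # __loop_size__ (one second-order step) + iter_first_order
loop_count = int(dataset_size / loop_size) * 2
print(loop_count)                    # 36: 18 second-order sinks interleaved with 18 first-order sinks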
-class _DatasetIterMS(_DatasetIter):
-    """Iter for context (enable_loop_sink=False)"""
-    def __init__(self, dataset, first_order_order):
-        super(_DatasetIterMS, self).__init__(dataset)
-        self.loop_count = dataset.get_dataset_size()
-        self.loop_size = 1
-        queue_name = dataset.__ME_INITED__
-        self.op = GetNextSingleOp(self.dataset_types, self.dataset_shapes, queue_name)
-
-
-class _DatasetIterGE(_DatasetIter):
-    """Iter for ge"""
-    def __init__(self, dataset):
-        super(_DatasetIterGE, self).__init__(dataset)
-        self.loop_count = self.get_loop_count(dataset)
-        parallel_mode = _get_parallel_mode()
-        self.need_to_full = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
-        batch_expand_num = 1
-        if self.need_to_full:
-            batch_expand_num = _get_device_num()
-        tensor_list_run = _construct_tensor_list(self.dataset_types, self.dataset_shapes, batch_expand_num)
-
-        def op():
-            return tensor_list_run
-
-        self.op = op
-
-
-class _DatasetIterFeed:
-    """Iter for feed data"""
-    def __init__(self, dataset, first_order_order):
-        self.dataset = dataset
-        self.device_num = _get_device_num()
-        self.global_rank = _get_global_rank()
-        self.repeat_count = dataset.get_repeat_count()
-        self.repeat_ind = 0
-        self.loop_count = dataset.get_dataset_size()
-        self.ind = 0
-
-        parallel_mode = context.get_auto_parallel_context("parallel_mode")
-        self.need_to_full = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
-
-    def __iter__(self):
-        if self.repeat_ind % self.repeat_count == 0:
-            self.iter = self.dataset.__iter__()
-        self.repeat_ind += 1
-        self.ind = 0
-        return self
-
-    def __next__(self):
-        if self.ind >= self.loop_count:
-            raise StopIteration()
-        self.ind += 1
-        data = self.iter.__next__()
-        if self.need_to_full:
-            return _to_full_tensor(data, self.device_num, self.global_rank)
-        return _to_tensor(data)
@@ -13,8 +13,11 @@
 # limitations under the License.
 # ============================================================================
 """Model."""
+import numpy as np
+
 from mindspore import context
 from mindspore import log as logger
+from mindspore import nn
 from mindspore._c_expression import init_exec_dataset
 from mindspore._checkparam import check_input_data, check_output_data, check_int_positive, check_bool
 from mindspore.common import dtype as mstype
@@ -28,9 +31,9 @@ from mindspore.parallel._utils import _get_parallel_mode, _get_device_num, _get_
 from mindspore.train import amp
 from mindspore.train.callback import _InternalCallbackParam, RunContext, _build_callbacks
 from mindspore.train.parallel_utils import ParallelMode
-import mindspore.nn as nn
-from second_order.dataset_helper import DatasetHelper
-import numpy as np
+from model.dataset_helper import DatasetHelper


 def _convert_type(types):
    """
@@ -69,7 +72,8 @@ def _exec_datagraph(exec_dataset, dataset_size, phase='dataset'):
                      dataset_types,
                      dataset_shapes,
                      input_indexs,
-                      phase=phase)
+                      phase=phase,
+                      need_run=False)


 class Model:
@@ -123,7 +127,7 @@ class Model:
        >>>         return out
        >>>
        >>> net = Net()
-        >>> loss = nn.SoftmaxCrossEntropyWithLogits()
+        >>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
        >>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
        >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None)
        >>> dataset = get_dataset()
@@ -131,30 +135,36 @@ class Model:
    """

    def __init__(self, network, loss_fn=None, optimizer=None, metrics=None, eval_network=None,
-                 eval_indexes=None, amp_level="O0", frequency=278, **kwargs):
+                 eval_indexes=None, amp_level="O0", frequency=278, stop_epoch=100, **kwargs):
        self._network = network
        self._loss_fn = loss_fn
        self._optimizer = optimizer
        self._loss_scale_manager = None
        self._loss_scale_manager_set = False
        self._keep_bn_fp32 = True
-        self._frequency = frequency
        self._check_kwargs(kwargs)
-        if 'keep_batchnorm_fp32' in kwargs:
-            self._keep_bn_fp32 = kwargs['keep_batchnorm_fp32']
-        if 'loss_scale_manager' in kwargs:
-            self._loss_scale_manager = kwargs['loss_scale_manager']
-            self._loss_scale_manager_set = True
        self._amp_level = amp_level
+        self._process_amp_args(kwargs)
        self._parallel_mode = _get_parallel_mode()
        self._device_number = _get_device_num()
        self._global_rank = _get_global_rank()
        self._parameter_broadcast = _get_parameter_broadcast()
+        self._frequency = frequency
+        self._stop_epoch = stop_epoch

        self._train_network = self._build_train_network()
        self._build_eval_network(metrics, eval_network, eval_indexes)
        self._build_predict_network()

+    def _process_amp_args(self, kwargs):
+        if self._amp_level == "O0":
+            self._keep_bn_fp32 = False
+        if 'keep_batchnorm_fp32' in kwargs:
+            self._keep_bn_fp32 = kwargs['keep_batchnorm_fp32']
+        if 'loss_scale_manager' in kwargs:
+            self._loss_scale_manager = kwargs['loss_scale_manager']
+            self._loss_scale_manager_set = True
+
    def _check_kwargs(self, kwargs):
        for arg in kwargs:
            if arg not in ['loss_scale_manager', 'keep_batchnorm_fp32']:
@@ -180,6 +190,9 @@ class Model:
        elif self._loss_fn:
            network = nn.WithLossCell(network, self._loss_fn)
        # If need to check if loss_fn is not None, but optimizer is None
+        if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
+            network.set_auto_parallel()
        return network

    def _build_eval_network(self, metrics, eval_network, eval_indexes):
@@ -198,14 +211,18 @@ class Model:
        else:
            if self._loss_fn is None:
                raise ValueError("loss_fn can not be None.")
-            self._eval_network = nn.WithEvalCell(self._network, self._loss_fn)
+            self._eval_network = nn.WithEvalCell(self._network, self._loss_fn, self._amp_level == "O2")
            self._eval_indexes = [0, 1, 2]
+        if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
+            self._eval_network.set_auto_parallel()

    def _build_predict_network(self):
        """Build the network for prediction."""
        self._predict_network = self._network
        if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
            self._predict_network = _VirtualDatasetCell(self._network)
+            self._predict_network.set_auto_parallel()

    def _clear_metrics(self):
        """Clear metrics local values."""
@@ -246,6 +263,94 @@ class Model:
                scaling_sens /= self._device_number
        return scaling_sens

+    def _exec_preprocess(self, network, is_train, phase, dataset, dataset_sink_mode, iter_first_order=0):
+        """Initializes dataset."""
+        need_wrap = False
+        if dataset_sink_mode:
+            # remove later to deal with loop sink
+            if not hasattr(dataset, '__ME_INITED__') and context.get_context("device_target") == "Ascend" \
+                    and not context.get_context("enable_ge"):
+                need_wrap = True
+
+            if not is_train:
+                dataset.__loop_size__ = 1
+
+        dataset_helper = DatasetHelper(dataset, dataset_sink_mode, iter_first_order)
+
+        # remove later to deal with loop sink
+        if need_wrap:
+            network = nn.DataWrapper(network, *(dataset_helper.types_shapes()), dataset.__ME_INITED__)
+        network.set_train(is_train)
+        network.phase = phase
+
+        return dataset_helper, network
+    def init(self, train_dataset=None, valid_dataset=None):
+        """
+        Initializes compute graphs and data graphs with sink mode.
+
+        Note:
+            Pre-init process only supports `GRAPH_MODE` and `Ascend` target currently.
+
+        Args:
+            train_dataset (Dataset): A training dataset iterator. If defined, training graphs will be
+                initialized. Default: None.
+            valid_dataset (Dataset): An evaluating dataset iterator. If defined, evaluation graphs will
+                be initialized, and `metrics` in `Model` can not be None. Default: None.
+
+        Examples:
+            >>> train_dataset = get_train_dataset()
+            >>> valid_dataset = get_valid_dataset()
+            >>> net = Net()
+            >>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
+            >>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
+            >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics={'acc'})
+            >>> model.init(train_dataset, valid_dataset)
+            >>> model.train(2, train_dataset)
+            >>> model.eval(valid_dataset)
+        """
+        if context.get_context("mode") != context.GRAPH_MODE or context.get_context("device_target") != "Ascend":
+            raise RuntimeError('Pre-init process only supports GRAPH MODE and Ascend target currently.')
+        if not train_dataset and not valid_dataset:
+            raise ValueError('Both train_dataset and valid_dataset can not be None or empty.')
+
+        _device_number_check(self._parallel_mode, self._device_number)
+
+        if train_dataset:
+            _parameter_broadcast_check(self._parallel_mode, self._parameter_broadcast)
+            self._train_network.set_train()
+            self._train_network.phase = 'train'
+
+            if self._parameter_broadcast:
+                self._train_network.set_broadcast_flag()
+
+            train_dataset_helper, train_network = self._exec_preprocess(self._train_network,
+                                                                        is_train=True,
+                                                                        phase='train',
+                                                                        dataset=train_dataset,
+                                                                        dataset_sink_mode=True)
+            self._train_network = train_network
+            for inputs in train_dataset_helper:
+                self._train_network.compile(*inputs)
+                break
+
+        if valid_dataset:
+            if not self._metric_fns:
+                raise RuntimeError('If define `valid_dataset`, metric fn can not be None or empty.')
+
+            self._eval_network.set_train(False)
+            self._eval_network.phase = 'eval'
+            valid_dataset_helper, eval_network = self._exec_preprocess(self._eval_network,
+                                                                       is_train=False,
+                                                                       phase='eval',
+                                                                       dataset=valid_dataset,
+                                                                       dataset_sink_mode=True)
+            self._eval_network = eval_network
+            for inputs in valid_dataset_helper:
+                self._eval_network.compile(*inputs)
+                break

    def _train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True):
        """
        Training.
@@ -306,32 +411,27 @@ class Model:
            list_callback (_ListCallback): Executor of callback list. Default: None.
            cb_params (_InternalCallbackParam): Callback parameters. Default: None.
        """
-        # remove later to deal with loop sink
-        iter_first_order = 277
+        iter_first_order = self._frequency - 1
        iter_second_order = 1
        train_dataset.__loop_size__ = iter_second_order
-        need_wrap = False
-        if not hasattr(train_dataset, '__ME_INITED__') and context.get_context("enable_loop_sink") \
-                and not context.get_context("enable_ge"):
-            need_wrap = True
-        dataset_helper = DatasetHelper(train_dataset, iter_first_order)
-        # remove later to deal with loop sink
-        if need_wrap:
-            self._train_network = nn.DataWrapper(self._train_network, *(dataset_helper.types_shapes()),
-                                                 train_dataset.__ME_INITED__)
-            cb_params.train_network = self._train_network
-        self._train_network.set_train()
+        dataset_helper, train_network = self._exec_preprocess(self._train_network,
+                                                              is_train=True,
+                                                              phase='train',
+                                                              dataset=train_dataset,
+                                                              dataset_sink_mode=True,
+                                                              iter_first_order=iter_first_order)
+        self._train_network = train_network
+        cb_params.train_network = self._train_network
        cb_params.cur_step_num = 0
        loop_size = dataset_helper.loop_size()
        run_context = RunContext(cb_params)
        list_callback.begin(run_context)

        # used to stop training for early stop, such as stopAtTime or stopAtStep
        should_stop = False
-        has_do_train1_dataset = False
-        checkpoint_branch_one = True
+        has_do_dataset_init = False
+        switch_branch_one = True
        for i in range(epoch):
            cb_params.cur_epoch_num = i + 1
            list_callback.epoch_begin(run_context)
@@ -339,18 +439,18 @@ class Model:
            # for data sink, dataset_helper only iterates once; otherwise it iterates epoch_size times.
            for inputs in dataset_helper:
                list_callback.step_begin(run_context)
-                if checkpoint_branch_one:
+                if switch_branch_one:
                    cb_params.cur_step_num += loop_size
-                    self._train_network.set_second_order(True)
+                    self._train_network.add_flags_recursive(thor=True)
                    self._train_network.phase = 'train0'
                else:
                    cb_params.cur_step_num += iter_first_order
-                    self._train_network.set_second_order(False)
+                    self._train_network.add_flags_recursive(thor=False)
                    self._train_network.phase = 'train1'
-                    if not has_do_train1_dataset:
+                    if not has_do_dataset_init:
                        _exec_datagraph(train_dataset, iter_first_order, phase='train1_dataset')
-                        has_do_train1_dataset = True
-                checkpoint_branch_one = not checkpoint_branch_one
+                        has_do_dataset_init = True
+                switch_branch_one = not switch_branch_one
                outputs = self._train_network(*inputs)
                cb_params.net_outputs = outputs
                list_callback.step_end(run_context)
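Reviewer note: the switch_branch_one flag alternates between two compiled graphs: phase 'train0' executes one second-order (thor=True) step per sink, and phase 'train1' executes iter_first_order ordinary (thor=False) steps. A standalone sketch of the cadence (editor's addition, illustrative, not part of the patch):

# With frequency = 278, each pair of sink iterations runs 1 'train0' step
# followed by 277 'train1' steps.
frequency = 278
steps = []
switch_branch_one = True
for _ in range(4):                       # four sink iterations
    steps.append(('train0', 1) if switch_branch_one else ('train1', frequency - 1))
    switch_branch_one = not switch_branch_one
print(steps)  # [('train0', 1), ('train1', 277), ('train0', 1), ('train1', 277)]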
@@ -376,17 +476,21 @@ class Model:
            list_callback (_ListCallback): Executor of callback list. Default: None.
            cb_params (_InternalCallbackParam): Callback parameters. Default: None.
        """
-        dataset_helper = DatasetHelper(train_dataset, dataset_sink_mode=False)
+        dataset_helper, _ = self._exec_preprocess(self._train_network,
+                                                  is_train=True,
+                                                  phase='train',
+                                                  dataset=train_dataset,
+                                                  dataset_sink_mode=False)
        cb_params.cur_step_num = 0
        run_context = RunContext(cb_params)
-        _callback_wrapper(list_callback, run_context, "begin")
+        list_callback.begin(run_context)
        # used to stop training for early stop, such as stopAtTime or stopAtStep
        should_stop = False

        for i in range(epoch):
            cb_params.cur_epoch_num = i + 1
-            _callback_wrapper(list_callback, run_context, "epoch_begin")
+            list_callback.epoch_begin(run_context)

            for next_element in dataset_helper:
                len_element = len(next_element)
@@ -394,7 +498,7 @@ class Model:
                    raise ValueError("when loss_fn is not None, train_dataset should "
                                     "return two elements, but got {}".format(len_element))
                cb_params.cur_step_num += 1
-                _callback_wrapper(list_callback, run_context, "step_begin")
+                list_callback.step_begin(run_context)

                overflow = False
                if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update():
@@ -408,19 +512,19 @@ class Model:
                    overflow = np.all(overflow.asnumpy())
                    self._loss_scale_manager.update_loss_scale(overflow)

-                _callback_wrapper(list_callback, run_context, "step_end")
+                list_callback.step_end(run_context)
                should_stop = should_stop or run_context.get_stop_requested()
                if should_stop:
                    break

            train_dataset.reset()

-            _callback_wrapper(list_callback, run_context, "epoch_end")
+            list_callback.epoch_end(run_context)
            should_stop = should_stop or run_context.get_stop_requested()
            if should_stop:
                break

-        _callback_wrapper(list_callback, run_context, "end")
+        list_callback.end(run_context)
    def train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True):
        """
@@ -452,7 +556,7 @@ class Model:
        Examples:
            >>> dataset = get_dataset()
            >>> net = Net()
-            >>> loss = nn.SoftmaxCrossEntropyWithLogits()
+            >>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
            >>> loss_scale_manager = FixedLossScaleManager()
            >>> optim = Momentum(params=net.trainable_params(), learning_rate=0.1, momentum=0.9)
            >>> model = Model(net, loss_fn=loss, optimizer=optim, metrics=None, loss_scale_manager=loss_scale_manager)
@@ -465,9 +569,6 @@ class Model:
        _device_number_check(self._parallel_mode, self._device_number)
        _parameter_broadcast_check(self._parallel_mode, self._parameter_broadcast)

-        if context.get_context("device_target") in ["CPU", "GPU"] and context.get_context("enable_loop_sink"):
-            raise ValueError("CPU and GPU can't support loop sink, please set enable_loop_sink=False.")
-
        self._train(epoch,
                    train_dataset,
                    callbacks=callbacks,
@@ -485,25 +586,15 @@ class Model:
        Returns:
            Dict, returns the loss value & metrics values for the model in test mode.
        """
-        _device_number_check(self._parallel_mode, self._device_number)
-
        run_context = RunContext(cb_params)

-        # remove later to deal with loop sink
-        need_wrap = False
-        if not hasattr(valid_dataset, '__ME_INITED__') and context.get_context("enable_loop_sink") \
-                and not context.get_context("enable_ge"):
-            need_wrap = True
-            valid_dataset.__loop_size__ = 1
-
-        dataset_helper = DatasetHelper(valid_dataset)
-
-        # remove later to deal with loop sink
-        if need_wrap:
-            self._eval_network = nn.DataWrapper(self._eval_network, *(dataset_helper.types_shapes()),
-                                                valid_dataset.__ME_INITED__)
-        self._eval_network.set_train(mode=False)
-        self._eval_network.phase = 'eval'
+        dataset_helper, eval_network = self._exec_preprocess(self._eval_network,
+                                                             is_train=False,
+                                                             phase='eval',
+                                                             dataset=valid_dataset,
+                                                             dataset_sink_mode=True)
+        self._eval_network = eval_network
+        cb_params.eval_network = self._eval_network
        list_callback.begin(run_context)

        for inputs in dataset_helper:
@@ -537,7 +628,11 @@ class Model:
        run_context = RunContext(cb_params)
        list_callback.begin(run_context)

-        dataset_helper = DatasetHelper(valid_dataset, dataset_sink_mode=False)
+        dataset_helper, _ = self._exec_preprocess(self._eval_network,
+                                                  is_train=False,
+                                                  phase='eval',
+                                                  dataset=valid_dataset,
+                                                  dataset_sink_mode=False)
        for next_element in dataset_helper:
            cb_params.cur_step_num += 1
            list_callback.step_begin(run_context)
@@ -574,11 +669,12 @@ class Model:
        Examples:
            >>> dataset = get_dataset()
            >>> net = Net()
-            >>> loss = nn.SoftmaxCrossEntropyWithLogits()
+            >>> loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
            >>> model = Model(net, loss_fn=loss, optimizer=None, metrics={'acc'})
            >>> model.eval(dataset)
        """
        check_bool(dataset_sink_mode)
+        _device_number_check(self._parallel_mode, self._device_number)
        if not self._metric_fns:
            raise ValueError("metric fn can not be None or empty.")
@@ -14,22 +14,24 @@
 # ============================================================================
 """ResNet."""
 import math
-import mindspore.nn as nn
 import numpy as np
+import mindspore.nn as nn
 from mindspore.common.tensor import Tensor
 from mindspore.ops import operations as P
-from second_order.thor_layer import Conv2d_Thor, Dense_Thor
+from model.thor_layer import Conv2d_Thor, Dense_Thor


 def calculate_gain(nonlinearity, param=None):
+    """calculate_gain"""
    linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
+    res = 0
    if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
-        return 1
+        res = 1
    elif nonlinearity == 'tanh':
-        return 5.0 / 3
+        res = 5.0 / 3
    elif nonlinearity == 'relu':
-        return math.sqrt(2.0)
+        res = math.sqrt(2.0)
    elif nonlinearity == 'leaky_relu':
        if param is None:
            negative_slope = 0.01
@@ -38,16 +40,17 @@ def calculate_gain(nonlinearity, param=None):
            negative_slope = param
        else:
            raise ValueError("negative_slope {} not a valid number".format(param))
-        return math.sqrt(2.0 / (1 + negative_slope ** 2))
+        res = math.sqrt(2.0 / (1 + negative_slope ** 2))
    else:
        raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
+    return res
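Reviewer note: quick numeric check of the single-return refactor above (editor's sketch, values follow the gains defined in the function):

assert calculate_gain('sigmoid') == 1
assert abs(calculate_gain('tanh') - 5.0 / 3) < 1e-12
assert abs(calculate_gain('relu') - 2.0 ** 0.5) < 1e-12
assert abs(calculate_gain('leaky_relu') - (2.0 / 1.0001) ** 0.5) < 1e-12   # default slope 0.01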
 def _calculate_fan_in_and_fan_out(tensor):
+    """_calculate_fan_in_and_fan_out"""
    dimensions = len(tensor)
    if dimensions < 2:
        raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")
    if dimensions == 2:  # Linear
        fan_in = tensor[1]
        fan_out = tensor[0]
@@ -67,7 +70,6 @@ def _calculate_correct_fan(tensor, mode):
    valid_modes = ['fan_in', 'fan_out']
    if mode not in valid_modes:
        raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    return fan_in if mode == 'fan_in' else fan_out
@@ -93,8 +95,6 @@ def _conv3x3(in_channel, out_channel, stride=1, damping=0.03, loss_scale=1, freq
    return Conv2d_Thor(in_channel, out_channel,
                       kernel_size=3, stride=stride, padding=0, pad_mode='same', weight_init=weight,
                       damping=damping, loss_scale=loss_scale, frequency=frequency)
-    # return nn.Conv2d(in_channel, out_channel,
-    #                  kernel_size=3, stride=stride, padding=0, pad_mode='same', weight_init=weight)


 def _conv1x1(in_channel, out_channel, stride=1, damping=0.03, loss_scale=1, frequency=278):
@@ -125,7 +125,7 @@ def _bn_last(channel):

 def _fc(in_channel, out_channel, damping, loss_scale, frequency):
    weight_shape = (out_channel, in_channel)
-    weight = Tensor(kaiming_uniform(weight_shape, a=math.sqrt(5))
+    weight = Tensor(kaiming_uniform(weight_shape, a=math.sqrt(5)))
    return Dense_Thor(in_channel, out_channel, has_bias=False, weight_init=weight,
                      bias_init=0, damping=damping, loss_scale=loss_scale, frequency=frequency)
@@ -133,15 +133,15 @@ def _fc(in_channel, out_channel, damping, loss_scale, frequency):
 class ResidualBlock(nn.Cell):
    """
    ResNet V1 residual block definition.

    Args:
        in_channel (int): Input channel.
        out_channel (int): Output channel.
        stride (int): Stride size for the first convolutional layer. Default: 1.

    Returns:
        Tensor, output tensor.

    Examples:
        >>> ResidualBlock(3, 256, stride=2)
    """
@@ -210,7 +210,7 @@ class ResidualBlock(nn.Cell):
 class ResNet(nn.Cell):
    """
    ResNet architecture.

    Args:
        block (Cell): Block for network.
        layer_nums (list): Numbers of block in different layers.
@@ -220,7 +220,7 @@ class ResNet(nn.Cell):
        num_classes (int): The number of classes that the training images belong to.
    Returns:
        Tensor, output tensor.

    Examples:
        >>> ResNet(ResidualBlock,
        >>>        [3, 4, 6, 3],
@@ -290,17 +290,17 @@ class ResNet(nn.Cell):
                    damping, loss_scale, frequency):
        """
        Make stage network of ResNet.

        Args:
            block (Cell): Resnet block.
            layer_num (int): Layer number.
            in_channel (int): Input channel.
            out_channel (int): Output channel.
            stride (int): Stride size for the first convolutional layer.

        Returns:
            SequentialCell, the output layer.

        Examples:
            >>> _make_layer(ResidualBlock, 3, 128, 256, 2)
        """
@@ -321,7 +321,7 @@ class ResNet(nn.Cell):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
-        c1, argmax = self.maxpool(x)
+        c1, _ = self.maxpool(x)

        c2 = self.layer1(c1)
        c3 = self.layer2(c2)
@@ -338,13 +338,13 @@ class ResNet(nn.Cell):
 def resnet50(class_num=10, damping=0.03, loss_scale=1, frequency=278):
    """
    Get ResNet50 neural network.

    Args:
        class_num (int): Class number.

    Returns:
        Cell, cell instance of ResNet50 neural network.

    Examples:
        >>> net = resnet50(10)
    """
@@ -51,6 +51,6 @@ do
    echo "start training for rank $RANK_ID, device $DEVICE_ID"
    env > env.log
-    python train_0517_1.py --do_train=True --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$2 > log 2>&1 &
+    python train.py --do_train=True --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$2 > log 2>&1 &
    cd ..
 done
@@ -17,7 +17,6 @@ import argparse
 import os
 import random

-import mindspore.dataset.engine as de
 from mindspore import Tensor
 from mindspore import context
 from mindspore.communication.management import init
@@ -25,19 +24,17 @@ from mindspore.parallel._auto_parallel_context import auto_parallel_context
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
 from mindspore.train.loss_scale_manager import FixedLossScaleManager
 from mindspore.train.model import ParallelMode
-from second_order.model_second_order import Model
-from second_order.resnet import resnet50
-from second_order.thor import THOR
+from model.model_thor import Model
+from model.resnet import resnet50
+from model.thor import THOR
 import numpy as np

-from config_imagenet import config
+from config import config
 from crossentropy import CrossEntropy
 from dataset_imagenet import create_dataset
-from lr_generator import warmup_cosine_annealing_lr

 random.seed(1)
 np.random.seed(1)
-de.config.set_seed(1)

 parser = argparse.ArgumentParser(description='Image classification')
 parser.add_argument('--run_distribute', type=bool, default=False, help='Run distribute')
@@ -50,29 +47,29 @@ args_opt = parser.parse_args()
 device_id = int(os.getenv('DEVICE_ID'))

 context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True, device_id=device_id)
-context.set_context(enable_task_sink=True)
-context.set_context(enable_loop_sink=True)
-context.set_context(enable_mem_reuse=True)
-def get_second_order_lr(global_step, lr_init, decay, total_epochs, steps_per_epoch):
-    """get_second_order_lr"""
+def get_model_lr(global_step, lr_init, decay, total_epochs, steps_per_epoch):
+    """get_model_lr"""
    lr_each_step = []
    total_steps = steps_per_epoch * total_epochs
    for i in range(total_steps):
        epoch = (i + 1) / steps_per_epoch
        base = (1.0 - float(epoch) / total_epochs) ** decay
        lr_local = lr_init * base
-        if epoch >= 39:
-            lr_local = lr_local * 0.5
+        if epoch >= 40:
+            lr_local = lr_local * 0.5
        lr_each_step.append(lr_local)
    current_step = global_step
    lr_each_step = np.array(lr_each_step).astype(np.float32)
-    print("learning_rate_is=====", lr_each_step)
    learning_rate = lr_each_step[current_step:]
    return learning_rate
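Reviewer note: with the values used further down (lr_init=0.05, decay=6, 70 epochs, 5004 steps per epoch), this is a degree-6 polynomial decay with an extra halving once epoch 40 is reached (editor's sketch, not part of the patch):

lrs = get_model_lr(global_step=0, lr_init=0.05, decay=6, total_epochs=70,
                   steps_per_epoch=5004)
print(lrs[0])               # ~0.05 at the first step
print(lrs[40 * 5004 - 1])   # epoch hits 40: 0.05 * (1 - 40/70)**6 * 0.5 ~= 1.5e-4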
-def get_second_order_damping(global_step, damping_init, decay_rate, total_epochs, steps_per_epoch):
-    """get_second_order_damping"""
+def get_model_damping(global_step, damping_init, decay_rate, total_epochs, steps_per_epoch):
+    """get_model_damping"""
    damping_each_step = []
    total_steps = steps_per_epoch * total_epochs
    for step in range(total_steps):
@@ -83,26 +80,23 @@ def get_second_order_damping(global_step, damping_init, decay_rate, total_epochs
    current_step = global_step
    damping_each_step = np.array(damping_each_step).astype(np.float32)
    damping_now = damping_each_step[current_step:]
-    print("damping_is=========", damping_now)
    return damping_now
 if __name__ == '__main__':
-    if args_opt.do_eval:
-        print("eval")
-    else:
-        if args_opt.run_distribute:
-            context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
-                                              mirror_mean=True, parameter_broadcast=True)
-            auto_parallel_context().set_all_reduce_fusion_split_indices([80], "hccl_world_groupsum1")
-            auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum3")
-            auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum4")
-            init()
-        else:
-            print(" ")
+    if not args_opt.do_eval and args_opt.run_distribute:
+        context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
+                                          mirror_mean=True, parameter_broadcast=True)
+        auto_parallel_context().set_all_reduce_fusion_split_indices([107], "hccl_world_groupsum1")
+        auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum2")
+        auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum3")
+        auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum4")
+        auto_parallel_context().set_all_reduce_fusion_split_indices([27], "hccl_world_groupsum5")
+        init()

    epoch_size = config.epoch_size
-    damping = get_second_order_damping(0, 0.03, 0.87, 50, 5004)
+    damping = get_model_damping(0, 0.03, 0.87, 50, 5004)
    net = resnet50(class_num=config.class_num, damping=damping, loss_scale=config.loss_scale,
                   frequency=config.frequency)
@@ -115,17 +109,12 @@ if __name__ == '__main__':
    step_size = dataset.get_dataset_size()

    loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
-    lr = Tensor(warmup_cosine_annealing_lr(0.035,
-                                           step_size,
-                                           config.warmup_epochs,
-                                           50,
-                                           config.T_max,
-                                           config.eta_min))
-    opt = THOR(filter(lambda x: x.requires_grad, net.get_parameters()), lr,
-               config.momentum, damping, config.frequency,
+    lr = Tensor(get_model_lr(0, 0.05, 6, 70, 5004))
+    opt = THOR(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum,
               filter(lambda x: 'matrix_A' in x.name, net.get_parameters()),
               filter(lambda x: 'matrix_G' in x.name, net.get_parameters()),
-               filter(lambda x: 'spatial_norm' in x.name, net.get_parameters()),
+               filter(lambda x: 'A_inv_max' in x.name, net.get_parameters()),
+               filter(lambda x: 'G_inv_max' in x.name, net.get_parameters()),
               config.weight_decay, config.loss_scale)

    model = Model(net, loss_fn=loss, optimizer=opt, amp_level='O2', loss_scale_manager=loss_scale,
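Reviewer note: the THOR constructor above receives several parameter groups selected by substring match on parameter names. A minimal standalone illustration of that filtering idiom (editor's sketch with made-up parameter names, not MindSpore code):

class FakeParam:
    """Stand-in for a framework parameter with a .name attribute."""
    def __init__(self, name):
        self.name = name
        self.requires_grad = True

params = [FakeParam('conv1.matrix_A'), FakeParam('conv1.matrix_G'),
          FakeParam('conv1.weight'), FakeParam('fc.A_inv_max')]
matrix_a = list(filter(lambda x: 'matrix_A' in x.name, params))
print([p.name for p in matrix_a])   # ['conv1.matrix_A']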
@@ -0,0 +1,76 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""batch_matmul_impl"""
from mindspore.ops.op_info_register import op_info_register


@op_info_register("""{
    "op_name": "CusBatchMatMul",
    "imply_type": "TBE",
    "fusion_type": "OPAQUE",
    "async_flag": false,
    "binfile_name": "batchmatmul.so",
    "compute_cost": 10,
    "kernel_name": "CusBatchMatMul",
    "partial_flag": true,
    "attr": [
    ],
    "inputs": [
        {
            "index": 0,
            "dtype": [
                "float32"
            ],
            "format": [
                "DefaultFormat"
            ],
            "name": "x1",
            "need_compile": false,
            "param_type": "required",
            "shape": "all"
        },
        {
            "index": 1,
            "dtype": [
                "float32"
            ],
            "format": [
                "DefaultFormat"
            ],
            "name": "x2",
            "need_compile": false,
            "param_type": "required",
            "shape": "all"
        }
    ],
    "outputs": [
        {
            "index": 0,
            "dtype": [
                "float32"
            ],
            "format": [
                "DefaultFormat"
            ],
            "name": "y",
            "need_compile": false,
            "param_type": "required",
            "shape": "all"
        }
    ]
}""")
def CusBatchMatMul(input_x1, input_x2, output, transpose_a=False, transpose_b=True, kernel_name="batchmatmul"):
    """CusBatchMatMul"""
    return
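Note that the empty function body above is deliberate: for these custom THOR ops the computation ships as a prebuilt TBE binary ("binfile_name": "batchmatmul.so"), and the decorated Python function only registers the kernel metadata. As a hedged sketch of how the op is reached from graph code, via the CusBatchMatMul primitive defined in thor_ops further below (the import path is illustrative, and actually running it requires an Ascend backend with the .so deployed):

import numpy as np
from mindspore import Tensor
from thor_ops import CusBatchMatMul  # illustrative module path

batch_matmul = CusBatchMatMul()
x1 = Tensor(np.random.randn(4, 128, 128).astype(np.float32))
x2 = Tensor(np.random.randn(4, 128, 128).astype(np.float32))
y = batch_matmul(x1, x2)  # per infer_shape/infer_dtype: y has x1's shape and dtype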
@@ -0,0 +1,64 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""CusCholeskyTrsm"""
from mindspore.ops.op_info_register import op_info_register


@op_info_register("""{
    "op_name": "CusCholeskyTrsm",
    "imply_type": "TBE",
    "fusion_type": "OPAQUE",
    "async_flag": false,
    "binfile_name": "choleskytrsm.so",
    "compute_cost": 10,
    "kernel_name": "CusCholeskyTrsm",
    "partial_flag": true,
    "attr": [
    ],
    "inputs": [
        {
            "index": 0,
            "dtype": [
                "float32"
            ],
            "format": [
                "DefaultFormat"
            ],
            "name": "x1",
            "need_compile": false,
            "param_type": "required",
            "shape": "all"
        }
    ],
    "outputs": [
        {
            "index": 0,
            "dtype": [
                "float32"
            ],
            "format": [
                "DefaultFormat"
            ],
            "name": "y",
            "need_compile": false,
            "param_type": "required",
            "shape": "all"
        }
    ]
}""")
def CusCholeskyTrsm(input_x, output, kernel_name):
    """CusCholeskyTrsm"""
    return
@@ -0,0 +1,69 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""CusFusedAbsMax1"""
from mindspore.ops.op_info_register import op_info_register


@op_info_register("""{
    "op_name": "CusFusedAbsMax1",
    "imply_type": "TBE",
    "fusion_type": "OPAQUE",
    "async_flag": false,
    "binfile_name": "fusedabsmax1.so",
    "compute_cost": 10,
    "kernel_name": "CusFusedAbsMax1",
    "partial_flag": true,
    "attr": [
        {
            "name": "origin_shape",
            "param_type": "required",
            "type": "listInt",
            "value": "all"
        }
    ],
    "inputs": [
        {
            "index": 0,
            "dtype": [
                "float32"
            ],
            "format": [
                "DefaultFormat"
            ],
            "name": "x1",
            "need_compile": false,
            "param_type": "required",
            "shape": "all"
        }
    ],
    "outputs": [
        {
            "index": 0,
            "dtype": [
                "float32"
            ],
            "format": [
                "DefaultFormat"
            ],
            "name": "y",
            "need_compile": false,
            "param_type": "required",
            "shape": "all"
        }
    ]
}""")
def CusFusedAbsMax1(input_x, output, origin_shape=None, kernel_name="fused_abs_max1"):
    """CusFusedAbsMax1"""
    return
@@ -0,0 +1,87 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""CusImg2ColNC1HWC0"""
from mindspore.ops.op_info_register import op_info_register


@op_info_register("""{
    "op_name": "CusImg2ColNC1HWC0",
    "imply_type": "TBE",
    "fusion_type": "OPAQUE",
    "async_flag": false,
    "binfile_name": "img2colnc1hwc0.so",
    "compute_cost": 10,
    "kernel_name": "CusImg2ColNC1HWC0",
    "partial_flag": true,
    "attr": [
        {
            "name": "ksizes",
            "param_type": "required",
            "type": "listInt",
            "value": "all"
        },
        {
            "name": "strides",
            "param_type": "required",
            "type": "listInt",
            "value": "all"
        },
        {
            "name": "dilates",
            "param_type": "required",
            "type": "listInt",
            "value": "all"
        },
        {
            "name": "padding",
            "param_type": "required",
            "type": "str",
            "value": "all"
        }
    ],
    "inputs": [
        {
            "index": 0,
            "dtype": [
                "float16"
            ],
            "format": [
                "NC1HWC0"
            ],
            "name": "x1",
            "need_compile": false,
            "param_type": "required",
            "shape": "all"
        }
    ],
    "outputs": [
        {
            "index": 0,
            "dtype": [
                "float16"
            ],
            "format": [
                "FRACTAL_NZ"
            ],
            "name": "y",
            "need_compile": false,
            "param_type": "required",
            "shape": "all"
        }
    ]
}""")
def CusImg2ColNC1HWC0(input_x, output, ksizes, strides, dilates, padding, kernel_name="img2col"):
    """CusImg2ColNC1HWC0"""
    return
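For orientation, here is a plain-NumPy sketch of the textbook im2col layout this op family implements. It is a reference under stated assumptions, not the kernel: the real op consumes float16 NC1HWC0 tiles, emits FRACTAL_NZ, and lives in img2colnc1hwc0.so, while the sketch uses NCHW layout, VALID padding, and unit dilation.

import numpy as np

def im2col_nchw_ref(x, k_h, k_w, stride_h, stride_w):
    """Reference im2col: one row per output position, one column per (c, kh, kw)."""
    n, c, h, w = x.shape
    out_h = (h - k_h) // stride_h + 1
    out_w = (w - k_w) // stride_w + 1
    rows = []
    for b in range(n):
        for i in range(0, out_h * stride_h, stride_h):
            for j in range(0, out_w * stride_w, stride_w):
                rows.append(x[b, :, i:i + k_h, j:j + k_w].reshape(-1))
    return np.stack(rows)  # shape: (n * out_h * out_w, c * k_h * k_w)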
@@ -0,0 +1,101 @@
# -*- coding:utf-8 -*-
"""
Copyright 2020 Huawei Technologies Co., Ltd
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
matmul
"""
from __future__ import absolute_import

from mindspore.ops.op_info_register import op_info_register
from topi.cce import util

# General limitation of the size for input shape: 2**31
SHAPE_SIZE_LIMIT = 2147483648
NoneType = type(None)


@op_info_register("""{
    "op_name": "CusMatMulCubeDenseLeft",
    "imply_type": "TBE",
    "fusion_type": "OPAQUE",
    "async_flag": false,
    "binfile_name": "matmulcubedenseleft.so",
    "compute_cost": 10,
    "kernel_name": "CusMatMulCubeDenseLeft",
    "partial_flag": true,
    "attr": [
    ],
    "inputs": [
        {
            "index": 0,
            "dtype": [
                "float16"
            ],
            "format": [
                "DefaultFormat"
            ],
            "name": "x1",
            "need_compile": false,
            "param_type": "required",
            "shape": "all"
        },
        {
            "index": 1,
            "dtype": [
                "float16"
            ],
            "format": [
                "FRACTAL_NZ"
            ],
            "name": "x2",
            "need_compile": false,
            "param_type": "required",
            "shape": "all"
        },
        {
            "index": 2,
            "dtype": [
                "float16"
            ],
            "format": [
                "DefaultFormat"
            ],
            "name": "x3",
            "need_compile": false,
            "param_type": "optional",
            "shape": "all"
        }
    ],
    "outputs": [
        {
            "index": 0,
            "dtype": [
                "float16"
            ],
            "format": [
                "FRACTAL_NZ"
            ],
            "name": "y",
            "need_compile": false,
            "param_type": "required",
            "shape": "all"
        }
    ]
}""")
@util.check_input_type(dict, dict, (dict, NoneType), dict, bool, bool, str)
def CusMatMulCubeDenseLeft(input_x1, input_x2, bias=None, output_y={}, trans_a=False, trans_b=False,
                           kernel_name="matmulcube"):
    """CusMatMulCubeDenseLeft"""
    return
@@ -0,0 +1,102 @@
# -*- coding:utf-8 -*-
"""
Copyright 2020 Huawei Technologies Co., Ltd
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
matmul
"""
from __future__ import absolute_import

from mindspore.ops.op_info_register import op_info_register
from topi.cce import util

# General limitation of the size for input shape: 2**31
SHAPE_SIZE_LIMIT = 2147483648
NoneType = type(None)


@op_info_register("""{
    "op_name": "CusMatMulCubeFraczLeftCast",
    "imply_type": "TBE",
    "fusion_type": "OPAQUE",
    "async_flag": false,
    "binfile_name": "matmulcubefraczleftcast.so",
    "compute_cost": 10,
    "kernel_name": "CusMatMulCubeFraczLeftCast",
    "partial_flag": true,
    "attr": [
    ],
    "inputs": [
        {
            "index": 0,
            "dtype": [
                "float16"
            ],
            "format": [
                "DefaultFormat"
            ],
            "name": "x1",
            "need_compile": false,
            "param_type": "required",
            "shape": "all"
        },
        {
            "index": 1,
            "dtype": [
                "float32"
            ],
            "format": [
                "FracZ"
            ],
            "name": "x2",
            "need_compile": false,
            "param_type": "required",
            "shape": "all"
        },
        {
            "index": 2,
            "dtype": [
                "float16"
            ],
            "format": [
                "DefaultFormat"
            ],
            "name": "x3",
            "need_compile": false,
            "param_type": "optional",
            "shape": "all"
        }
    ],
    "outputs": [
        {
            "index": 0,
            "dtype": [
                "float16"
            ],
            "format": [
                "FracZ"
            ],
            "name": "y",
            "need_compile": false,
            "param_type": "required",
            "shape": "all"
        }
    ]
}""")
# pylint: disable=locally-disabled,too-many-arguments, too-many-locals, too-many-statements
@util.check_input_type(dict, dict, (dict, NoneType), dict, bool, bool, str)
def CusMatMulCubeFraczLeftCast(input_x1, input_x2, bias=None, output_y={}, trans_a=False, trans_b=False,
                               kernel_name="CusMatMulCubeFraczLeftCast"):
    """CusMatMulCubeFraczLeftCast"""
    return
@@ -0,0 +1,113 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
"""
Copyright 2020 Huawei Technologies Co., Ltd
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
matmul
"""
from __future__ import absolute_import

from mindspore.ops.op_info_register import op_info_register

# General limitation of the size for input shape: 2**31
SHAPE_SIZE_LIMIT = 2147483648
NoneType = type(None)


@op_info_register("""{
    "op_name": "CusMatMulCubeFraczRightMul",
    "imply_type": "TBE",
    "fusion_type": "OPAQUE",
    "async_flag": false,
    "binfile_name": "matmulcubefraczrightmul.so",
    "compute_cost": 10,
    "kernel_name": "CusMatMulCubeFraczRightMul",
    "partial_flag": true,
    "attr": [
    ],
    "inputs": [
        {
            "index": 0,
            "dtype": [
                "float16"
            ],
            "format": [
                "FracZ"
            ],
            "name": "x1",
            "need_compile": false,
            "param_type": "required",
            "shape": "all"
        },
        {
            "index": 1,
            "dtype": [
                "float16"
            ],
            "format": [
                "DefaultFormat"
            ],
            "name": "x2",
            "need_compile": false,
            "param_type": "required",
            "shape": "all"
        },
        {
            "index": 2,
            "dtype": [
                "float32"
            ],
            "format": [
                "DefaultFormat"
            ],
            "name": "x3",
            "need_compile": false,
            "param_type": "required",
            "shape": "all"
        },
        {
            "index": 3,
            "dtype": [
                "float16"
            ],
            "format": [
                "DefaultFormat"
            ],
            "name": "x4",
            "need_compile": false,
            "param_type": "optional",
            "shape": "all"
        }
    ],
    "outputs": [
        {
            "index": 0,
            "dtype": [
                "float32"
            ],
            "format": [
                "FracZ"
            ],
            "name": "y",
            "need_compile": false,
            "param_type": "required",
            "shape": "all"
        }
    ]
}""")
def CusMatMulCubeFraczRightMul(input_x1, input_x2, input_x3, bias=None, output_y={}, trans_a=False, trans_b=False,
                               kernel_name="matmulcube"):
    """CusMatMulCubeFraczRightMul"""
    return
@@ -0,0 +1,114 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-
"""
Copyright 2020 Huawei Technologies Co., Ltd
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
matmul
"""
from __future__ import absolute_import

from mindspore.ops.op_info_register import op_info_register
from topi.cce import util

# General limitation of the size for input shape: 2**31
SHAPE_SIZE_LIMIT = 2147483648
NoneType = type(None)


@op_info_register("""{
    "op_name": "CusMatMulCube",
    "imply_type": "TBE",
    "fusion_type": "OPAQUE",
    "async_flag": false,
    "binfile_name": "matmulcube.so",
    "compute_cost": 10,
    "kernel_name": "CusMatMulCube",
    "partial_flag": true,
    "attr": [
        {
            "name": "transpose_a",
            "param_type": "required",
            "type": "bool",
            "value": "all"
        },
        {
            "name": "transpose_b",
            "param_type": "required",
            "type": "bool",
            "value": "all"
        }
    ],
    "inputs": [
        {
            "index": 0,
            "dtype": [
                "float16"
            ],
            "format": [
                "FRACTAL_NZ"
            ],
            "name": "x1",
            "need_compile": false,
            "param_type": "required",
            "shape": "all"
        },
        {
            "index": 1,
            "dtype": [
                "float16"
            ],
            "format": [
                "FRACTAL_NZ"
            ],
            "name": "x2",
            "need_compile": false,
            "param_type": "required",
            "shape": "all"
        },
        {
            "index": 2,
            "dtype": [
                "float16"
            ],
            "format": [
                "DefaultFormat"
            ],
            "name": "x3",
            "need_compile": false,
            "param_type": "optional",
            "shape": "all"
        }
    ],
    "outputs": [
        {
            "index": 0,
            "dtype": [
                "float32"
            ],
            "format": [
                "FRACTAL_NZ"
            ],
            "name": "y",
            "need_compile": false,
            "param_type": "required",
            "shape": "all"
        }
    ]
}""")
# pylint: disable=locally-disabled,too-many-arguments, too-many-locals, too-many-statements
@util.check_input_type(dict, dict, (dict, NoneType), dict, bool, bool, str)
def CusMatMulCube(input_x1, input_x2, bias=None, output_y={}, trans_a=False, trans_b=False, kernel_name="matmulcube"):
    """CusMatMulCube"""
    return
@@ -0,0 +1,63 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""CusMatrixCombine"""
from mindspore.ops.op_info_register import op_info_register


@op_info_register("""{
    "op_name": "CusMatrixCombine",
    "imply_type": "TBE",
    "fusion_type": "OPAQUE",
    "async_flag": false,
    "binfile_name": "matrixcombine.so",
    "compute_cost": 10,
    "kernel_name": "CusMatrixCombine",
    "partial_flag": true,
    "attr": [
    ],
    "inputs": [
        {
            "index": 0,
            "dtype": [
                "float32"
            ],
            "format": [
                "DefaultFormat"
            ],
            "name": "x1",
            "need_compile": false,
            "param_type": "required",
            "shape": "all"
        }
    ],
    "outputs": [
        {
            "index": 0,
            "dtype": [
                "float32"
            ],
            "format": [
                "DefaultFormat"
            ],
            "name": "y",
            "need_compile": false,
            "param_type": "required",
            "shape": "all"
        }
    ]
}""")
def CusMatrixCombine(input_x, output, kernel_name="matrix_combine"):
    """CusMatrixCombine"""
    return
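A hedged NumPy model of what CusMatrixCombine appears to compute: assembling a batch of equally sized blocks into one block-diagonal matrix. The block-diagonal placement is an assumption; the (a, b, c) -> (a*b, a*c) output shape comes from the infer_shape rule in thor_ops below, and the production kernel is matrixcombine.so.

import numpy as np

def matrix_combine_ref(x):
    """Assumed reference: place the a blocks of shape (b, c) along the diagonal."""
    a, b, c = x.shape
    out = np.zeros((a * b, a * c), dtype=x.dtype)
    for i in range(a):
        out[i * b:(i + 1) * b, i * c:(i + 1) * c] = x[i]
    return out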
@@ -0,0 +1,63 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""CusTranspose02314"""
from mindspore.ops.op_info_register import op_info_register


@op_info_register("""{
    "op_name": "CusTranspose02314",
    "imply_type": "TBE",
    "fusion_type": "OPAQUE",
    "async_flag": false,
    "binfile_name": "transpose02314.so",
    "compute_cost": 10,
    "kernel_name": "CusTranspose02314",
    "partial_flag": true,
    "attr": [
    ],
    "inputs": [
        {
            "index": 0,
            "dtype": [
                "float16"
            ],
            "format": [
                "NC1HWC0"
            ],
            "name": "x1",
            "need_compile": false,
            "param_type": "required",
            "shape": "all"
        }
    ],
    "outputs": [
        {
            "index": 0,
            "dtype": [
                "float16"
            ],
            "format": [
                "DefaultFormat"
            ],
            "name": "y",
            "need_compile": false,
            "param_type": "required",
            "shape": "all"
        }
    ]
}""")
def CusTranspose02314(input_x, output, kernel_name="transpose02314"):
    """CusTranspose02314"""
    return
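Reading the op name as the axis permutation (0, 2, 3, 1, 4), here is a hedged NumPy model of CusTranspose02314 on an NC1HWC0 tensor. The permutation is inferred from the name and the (n * h * w, c1 * c0) output shape declared in thor_ops below; the production kernel is transpose02314.so.

import numpy as np

def transpose02314_ref(x):
    """Assumed reference: permute NC1HWC0 axes to (N, H, W, C1, C0), then flatten."""
    n, c1, h, w, c0 = x.shape
    return x.transpose(0, 2, 3, 1, 4).reshape(n * h * w, c1 * c0)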
@@ -0,0 +1,248 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""thor_ops"""
import mindspore as ms
from mindspore.ops import prim_attr_register, PrimitiveWithInfer
from mindspore.ops.composite import multitype_ops as C


class CusBatchMatMul(PrimitiveWithInfer):
    """CusBatchMatMul definition"""

    @prim_attr_register
    def __init__(self):
        """init CusBatchMatMul"""
        self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['y'])

    def get_bprop(self):
        def bprop(x1, x2, out, dout):
            return (C.zeros_like(x1), C.zeros_like(x2))
        return bprop

    def infer_shape(self, data1_shape, data2_shape):
        return data1_shape

    def infer_dtype(self, data1_dtype, data2_dtype):
        return data1_dtype


class CusCholeskyTrsm(PrimitiveWithInfer):
    """CusCholeskyTrsm definition"""

    @prim_attr_register
    def __init__(self):
        """init CusCholeskyTrsm"""
        self.init_prim_io_names(inputs=['x1'], outputs=['y'])

    def infer_shape(self, data1_shape):
        # the kernel operates on 128 x 128 diagonal blocks; smaller inputs use one 64 x 64 block
        m, _ = data1_shape
        if m >= 128:
            ll = [m // 128, 128, 128]
        else:
            ll = [1, 64, 64]
        return ll

    def infer_dtype(self, data1_dtype):
        return data1_dtype


class CusFusedAbsMax1(PrimitiveWithInfer):
    """CusFusedAbsMax1 definition"""

    @prim_attr_register
    def __init__(self, origin_shape=[-1, -1]):
        """init CusFusedAbsMax1"""
        self.init_prim_io_names(inputs=['x1'], outputs=['y'])
        self.origin_shape = origin_shape

    def get_bprop(self):
        def bprop(x, out, dout):
            return (C.zeros_like(x),)
        return bprop

    def infer_shape(self, data1_shape):
        if len(data1_shape) == 2:
            ll = [1,]
        else:
            ll = [32, 64]
        return ll

    def infer_dtype(self, data1_dtype):
        return data1_dtype


class CusImg2Col(PrimitiveWithInfer):
    """CusImg2Col definition"""

    @prim_attr_register
    def __init__(self, ksizes, strides, dilates=(1, 1, 1, 1), mode="NC1HWC0"):
        """init CusImg2Col"""
        self.init_prim_io_names(inputs=['x1'], outputs=['y'])
        self.ksizes = ksizes
        self.strides = strides
        self.dilates = dilates
        self.mode = mode

    def get_bprop(self):
        def bprop(x, out, dout):
            return (C.zeros_like(x),)
        return bprop

    def infer_shape(self, data1_shape):
        bs, c, h, w = data1_shape
        _, stride_h, stride_w, _ = self.strides
        _, k_w, k_h, _ = self.ksizes
        # assert m == n
        c0 = 16
        c1 = c // 16
        if c1 == 0:
            c1 = 1
        shape = [bs * int(h // stride_h) * int(w // stride_w), k_w * k_h * c1 * c0]
        return shape

    def infer_dtype(self, data1_dtype):
        return data1_dtype


class CusMatMulCubeDenseLeft(PrimitiveWithInfer):
    """CusMatMulCubeDenseLeft definition"""

    @prim_attr_register
    def __init__(self):
        """init CusMatMulCubeDenseLeft"""
        self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['y'])

    def get_bprop(self):
        def bprop(x1, x2, out, dout):
            return (C.zeros_like(x1), C.zeros_like(x2))
        return bprop

    def infer_shape(self, data1_shape, data2_shape):
        return data2_shape

    def infer_dtype(self, data1_dtype, data2_dtype):
        return ms.common.dtype.tensor_type(getattr(ms, "float16"))


class CusMatMulCubeFraczRightMul(PrimitiveWithInfer):
    """CusMatMulCubeFraczRightMul definition"""

    @prim_attr_register
    def __init__(self):
        """init CusMatMulCubeFraczRightMul"""
        self.init_prim_io_names(inputs=['x1', 'x2', 'x3'], outputs=['y'])

    def get_bprop(self):
        def bprop(x1, x2, x3, out, dout):
            return (C.zeros_like(x1), C.zeros_like(x2), C.zeros_like(x3))
        return bprop

    def infer_shape(self, data1_shape, data2_shape, data3_shape):
        return data1_shape

    def infer_dtype(self, data1_dtype, data2_dtype, data3_dtype):
        return ms.common.dtype.tensor_type(getattr(ms, "float32"))


class CusMatMulCube(PrimitiveWithInfer):
    """CusMatMulCube definition"""

    @prim_attr_register
    def __init__(self, transpose_a=False, transpose_b=False):
        """init CusMatMulCube"""
        self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['y'])
        self.transpose_a = transpose_a
        self.transpose_b = transpose_b

    def get_bprop(self):
        def bprop(x1, x2, out, dout):
            return (C.zeros_like(x1), C.zeros_like(x2))
        return bprop

    def infer_shape(self, data1_shape, data2_shape):
        # shape = [1, data1_shape[1], data2_shape[2], 16, 16]
        # return shape
        if self.transpose_a:
            k1, m = data1_shape
        else:
            m, k1 = data1_shape
        if self.transpose_b:
            n, k2 = data2_shape
        else:
            k2, n = data2_shape
        assert k1 == k2
        shape = [m, n]
        return shape

    def infer_dtype(self, data1_dtype, data2_dtype):
        return ms.common.dtype.tensor_type(getattr(ms, "float32"))


class CusMatrixCombine(PrimitiveWithInfer):
    """CusMatrixCombine definition"""

    @prim_attr_register
    def __init__(self):
        """init CusMatrixCombine"""
        self.init_prim_io_names(inputs=['x'], outputs=['y'])

    def get_bprop(self):
        def bprop(x, out, dout):
            return (C.zeros_like(x),)
        return bprop

    def infer_shape(self, data_shape):
        # a blocks of shape (b, c) are combined into one (a*b, a*c) matrix
        a, b, c = data_shape
        shape = [a * b, a * c]
        return shape

    def infer_dtype(self, data_dtype):
        return data_dtype


class CusTranspose02314(PrimitiveWithInfer):
    """CusTranspose02314 definition"""

    @prim_attr_register
    def __init__(self):
        """init CusTranspose02314"""
        self.init_prim_io_names(inputs=['x1'], outputs=['y'])

    def get_bprop(self):
        def bprop(x, out, dout):
            return (C.zeros_like(x),)
        return bprop

    def infer_shape(self, data1_shape):
        assert len(data1_shape) == 4
        n, c, h, w = data1_shape
        c0 = 16
        c1 = c // 16
        shape = (n * h * w, c1 * c0)
        return shape

    def infer_dtype(self, data1_dtype):
        return data1_dtype
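To make the shape and dtype contracts above concrete, a small sketch exercising the inference methods directly. Calling infer_shape by hand is illustrative only: in a real graph MindSpore invokes these during compilation, and executing the ops themselves requires the Ascend TBE kernels registered earlier.

matmul = CusMatMulCube(transpose_a=False, transpose_b=True)
# x1: (m, k), x2: (n, k) when transpose_b=True; k must match (the assert above).
print(matmul.infer_shape([64, 256], [128, 256]))  # -> [64, 128]

combine = CusMatrixCombine()
print(combine.infer_shape([8, 128, 128]))  # -> [1024, 1024]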