@@ -18,8 +18,8 @@ from __future__ import division
 import os
 import numpy as np
-from PIL import Image
 from matplotlib.colors import rgb_to_hsv, hsv_to_rgb
+from PIL import Image
 import mindspore.dataset as de
 from mindspore.mindrecord import FileWriter
 import mindspore.dataset.transforms.vision.c_transforms as C
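# Usage sketch for the imports above: FileWriter serializes samples into a
# MindRecord file and de.MindDataset reads it back. The schema, file name, and
# sample data below are illustrative, and the MindDataset parameter name
# (dataset_file vs. dataset_files) varies across MindSpore versions.
from mindspore.mindrecord import FileWriter
import mindspore.dataset as de

writer = FileWriter(file_name="example.mindrecord", shard_num=1)
writer.add_schema({"image": {"type": "bytes"}, "label": {"type": "int32"}}, "example_schema")
writer.write_raw_data([{"image": b"\x00\x01", "label": 0}])
writer.commit()

dataset = de.MindDataset(dataset_file="example.mindrecord")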
@@ -16,6 +16,9 @@
 from __future__ import absolute_import as _abs
 import sys
 import os
+from .op_build import op_build
+from .message import compilewithjson
+
 def AKGAddPath():
     """_akg add path."""
@@ -58,6 +61,3 @@ class AKGMetaPathLoader:
 sys.meta_path.insert(0, AKGMetaPathFinder())
-
-from .op_build import op_build
-from .message import compilewithjson
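# Sketch of the PEP 302 meta-path pattern that AKGMetaPathFinder implements:
# a finder placed at the front of sys.meta_path can redirect one package
# namespace to another. The names below ("foo", "_vendor") are illustrative,
# not the real _akg internals.
import sys
import importlib

class RedirectFinder:
    """Redirect `import foo.*` to the vendored `_vendor.foo.*` modules."""

    def find_module(self, fullname, path=None):
        # Claim imports under the "foo." namespace; defer everything else.
        return self if fullname.startswith("foo.") else None

    def load_module(self, fullname):
        if fullname in sys.modules:
            return sys.modules[fullname]
        module = importlib.import_module("_vendor." + fullname)
        sys.modules[fullname] = module
        return module

sys.meta_path.insert(0, RedirectFinder())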
@@ -14,7 +14,6 @@
 # ============================================================================
 """FTRL"""
 from mindspore.ops import functional as F, composite as C, operations as P
-from mindspore.common.initializer import initializer
 from mindspore.common.parameter import Parameter
 from mindspore.common import Tensor
 import mindspore.common.dtype as mstype
@@ -23,6 +22,8 @@ from mindspore._checkparam import Rel
 from .optimizer import Optimizer, apply_decay, grad_scale
+
 ftrl_opt = C.MultitypeFuncGraph("ftrl_opt")
+
 @ftrl_opt.register("Function", "Tensor", "Number", "Number", "Number", "Tensor", "Tensor", "Tensor", "Tensor")
 def _tensor_run_opt(opt, learning_rate, l1, l2, lr_power, linear, gradient, weight, moment):
     """Apply ftrl optimizer to the weight parameter."""
@@ -30,8 +31,10 @@ def _tensor_run_opt(opt, learning_rate, l1, l2, lr_power, linear, gradient, weig
     success = F.depend(success, opt(weight, moment, linear, gradient, learning_rate, l1, l2, lr_power))
     return success
+
+
 def _check_param(initial_accum, learning_rate, lr_power, l1, l2, use_locking, loss_scale=1.0, weight_decay=0.0,
                  prim_name=None):
     """Check param."""
     validator.check_value_type("initial_accum", initial_accum, [float], prim_name)
     validator.check_number("initial_accum", initial_accum, 0.0, Rel.GE, prim_name)
@@ -104,7 +107,7 @@ class FTRL(Optimizer):
         self.lr_power = lr_power
         self.reciprocal_scale = 1.0 / loss_scale
         self.weight_decay = weight_decay
-        self.decay_tf = tuple((lambda:True)() for x in self.parameters)
+        self.decay_tf = tuple((lambda: True)() for x in self.parameters)
         self.hyper_map = C.HyperMap()
         self.opt = P.ApplyFtrl(use_locking=use_locking)
         self.one = Tensor(1, mstype.int32)
@@ -118,5 +121,6 @@ class FTRL(Optimizer):
         if self.reciprocal_scale != 1.0:
             grads = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale), grads)
         lr = self.learning_rate
-        success = self.hyper_map(F.partial(ftrl_opt, self.opt, lr, self.l1, self.l2, self.lr_power), linear, grads, params, moments)
+        success = self.hyper_map(F.partial(ftrl_opt, self.opt, lr, self.l1, self.l2, self.lr_power),
+                                 linear, grads, params, moments)
         return success
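# Usage sketch for the FTRL optimizer defined above (hyper-parameter values
# illustrative; MyNet is a placeholder Cell, and the export path of FTRL may
# differ across MindSpore versions):
from mindspore import nn
from mindspore.nn.optim import FTRL

net = MyNet()
opt = FTRL(params=net.trainable_params(), learning_rate=0.001,
           initial_accum=0.1, l1=0.0, l2=0.0, lr_power=-0.5)
loss_fn = nn.SoftmaxCrossEntropyWithLogits()
train_net = nn.TrainOneStepCell(nn.WithLossCell(net, loss_fn), opt)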
@@ -2063,7 +2063,7 @@ class LSTM(PrimitiveWithInfer):
         return (y_shape, h_shape, c_shape, reserved_shape, state_shape)

     def infer_dtype(self, x_dtype, h_dtype, c_dtype, w_dtype):
-        args = {'x': x_dtype, 'h': h_dtype, 'c': c_dtype, 'w': w_dtype}
+        args = {'x': x_dtype, 'h': h_dtype, 'c': c_dtype, 'w': w_dtype}
         validator.check_tensor_type_same(args, (mstype.float32, mstype.float16), self.name)
         return (x_dtype, x_dtype, x_dtype, x_dtype, x_dtype)
@@ -2691,8 +2691,8 @@ class ConfusionMulGrad(PrimitiveWithInfer):
     """

     @prim_attr_register
-    def __init__(self, axis = (), keep_dims = False):
-        self.init_prim_io_names(inputs = ["input0", "input1", "input2"], outputs = ["output0", "output1"])
+    def __init__(self, axis=(), keep_dims=False):
+        self.init_prim_io_names(inputs=["input0", "input1", "input2"], outputs=["output0", "output1"])
         self.axis_ = validator.check_value_type("axis", axis, [int, tuple, list], self.name)
         self.keep_dims_ = validator.check_value_type("keep_dims", keep_dims, [bool], self.name)
@@ -41,6 +41,7 @@ class OutputTo16(nn.Cell):
 def _do_keep_batchnorm_fp32(network):
+    """Do keep batchnorm fp32."""
     cells = network.name_cells()
     change = False
     for name in cells:
@@ -68,6 +69,7 @@ _config_level = {
 def _check_kwargs(key_words):
+    """Check kwargs."""
     for arg in key_words:
         if arg not in ['cast_model_type', 'keep_batchnorm_fp32', 'loss_scale_manager']:
             raise ValueError(f"Unsupported arg '{arg}'")
@@ -84,6 +86,7 @@ def _check_kwargs(key_words):
 def _add_loss_network(network, loss_fn, cast_model_type):
+    """Add loss network."""
     class WithLossCell(nn.Cell):
         "Wrap loss for amp. Cast network output back to float32"
@@ -683,13 +683,14 @@ class LossMonitor(Callback):
 class TimeMonitor(Callback):
+    """Time Monitor."""
     def __init__(self, data_size):
         super(TimeMonitor, self).__init__()
         self.data_size = data_size

     def epoch_begin(self, run_context):
         self.epoch_time = time.time()

     def epoch_end(self, run_context):
         epoch_mseconds = (time.time() - self.epoch_time) * 1000
         per_step_mseconds = epoch_mseconds / self.data_size
@@ -701,4 +702,3 @@ class TimeMonitor(Callback):
     def step_end(self, run_context):
         step_mseconds = (time.time() - self.step_time) * 1000
         print('step time', step_mseconds, flush=True)
-
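# Sketch of wiring the monitors above into a training run (net, loss_fn, opt,
# and train_dataset are placeholders):
from mindspore.train import Model
from mindspore.train.callback import LossMonitor, TimeMonitor

model = Model(net, loss_fn=loss_fn, optimizer=opt)
steps_per_epoch = train_dataset.get_dataset_size()
model.train(10, train_dataset,
            callbacks=[LossMonitor(), TimeMonitor(data_size=steps_per_epoch)])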
@@ -122,7 +122,7 @@ class Model:
     def _check_kwargs(self, kwargs):
         for arg in kwargs:
             if arg not in ['loss_scale_manager', 'keep_batchnorm_fp32']:
-                raise ValueError(f"Unsupport arg '{arg}'")
+                raise ValueError(f"Unsupport arg '{arg}'")

     def _build_train_network(self):
         """Build train network"""
@@ -130,17 +130,17 @@ class Model:
         if self._optimizer:
             if self._loss_scale_manager_set:
                 network = amp.build_train_network(network,
-                                                   self._optimizer,
-                                                   self._loss_fn,
-                                                   level=self._amp_level,
-                                                   loss_scale_manager=self._loss_scale_manager,
-                                                   keep_batchnorm_fp32=self._keep_bn_fp32)
+                                                  self._optimizer,
+                                                  self._loss_fn,
+                                                  level=self._amp_level,
+                                                  loss_scale_manager=self._loss_scale_manager,
+                                                  keep_batchnorm_fp32=self._keep_bn_fp32)
             else:
                 network = amp.build_train_network(network,
-                                                   self._optimizer,
-                                                   self._loss_fn,
-                                                   level=self._amp_level,
-                                                   keep_batchnorm_fp32=self._keep_bn_fp32)
+                                                  self._optimizer,
+                                                  self._loss_fn,
+                                                  level=self._amp_level,
+                                                  keep_batchnorm_fp32=self._keep_bn_fp32)
         elif self._loss_fn:
             network = nn.WithLossCell(network, self._loss_fn)
         # If need to check if loss_fn is not None, but optimizer is None
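# Sketch of how a caller reaches the first branch above: passing a
# loss_scale_manager kwarg to Model sets self._loss_scale_manager_set
# (values illustrative; net, loss_fn, and opt are placeholders):
from mindspore.train import Model
from mindspore.train.loss_scale_manager import FixedLossScaleManager

model = Model(net, loss_fn=loss_fn, optimizer=opt, amp_level="O2",
              loss_scale_manager=FixedLossScaleManager(1024.0),
              keep_batchnorm_fp32=True)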
@@ -273,14 +273,14 @@ class Model:
         # remove later to deal with loop sink
         need_wrap = False
         if not hasattr(train_dataset, '__ME_INITED__') and context.get_context("enable_loop_sink") \
-            and not context.get_context("enable_ge"):
+                and not context.get_context("enable_ge"):
             need_wrap = True

         dataset_helper = DatasetHelper(train_dataset)
         # remove later to deal with loop sink
         if need_wrap:
             self._train_network = nn.DataWrapper(self._train_network, *(dataset_helper.types_shapes()),
-                                                train_dataset.__ME_INITED__)
+                                                 train_dataset.__ME_INITED__)

         cb_params.train_network = self._train_network
         self._train_network.set_train()
@@ -440,7 +440,7 @@ class Model:
         # remove later to deal with loop sink
         need_wrap = False
         if not hasattr(valid_dataset, '__ME_INITED__') and context.get_context("enable_loop_sink") \
-            and not context.get_context("enable_ge"):
+                and not context.get_context("enable_ge"):
             need_wrap = True

         valid_dataset.__loop_size__ = 1
@@ -449,7 +449,7 @@ class Model:
         # remove later to deal with loop sink
         if need_wrap:
             self._eval_network = nn.DataWrapper(self._eval_network, *(dataset_helper.types_shapes()),
-                                               valid_dataset.__ME_INITED__)
+                                                valid_dataset.__ME_INITED__)

         self._eval_network.set_train(mode=False)
         self._eval_network.phase = 'eval'
@@ -174,8 +174,7 @@ test_sets = [
                           embedding_shape=[1, 128, 768],
                           use_one_hot_embeddings=True,
                           initializer_range=0.02), 1, 1), {
-            'init_param_with': lambda shp: np.ones(shp).astype(np.float32)
-        }),
+            'init_param_with': lambda shp: np.ones(shp).astype(np.float32)}),
         'desc_inputs': [input_ids],
         'desc_bprop': [[128]]}),
     ('EmbeddingLookup_multi_outputs_init_param', {
@@ -184,8 +183,7 @@ test_sets = [
                           embedding_shape=[1, 128, 768],
                           use_one_hot_embeddings=False,
                           initializer_range=0.02), {
-            'init_param_with': lambda shp: np.ones(shp).astype(np.float32)
-        }),
+            'init_param_with': lambda shp: np.ones(shp).astype(np.float32)}),
         'desc_inputs': [input_ids],
         'desc_bprop': [[1, 128, 768], [128]]}),
     ('EmbeddingLookup_multi_outputs_grad_with_no_sens', {
@@ -194,8 +192,7 @@ test_sets = [
                           embedding_shape=[1, 128, 768],
                           use_one_hot_embeddings=False,
                           initializer_range=0.02), {
-            'init_param_with': lambda shp: np.ones(shp).astype(np.float32)
-        }),
+            'init_param_with': lambda shp: np.ones(shp).astype(np.float32)}),
         'desc_inputs': [input_ids]}),
     ('GetMaskedLMOutput_grad_with_no_sens', {
         'block': GetMaskedLMOutput(BertConfig(batch_size=1)),
@@ -44,4 +44,4 @@ class CheckExceptionsEC(IExectorComponent):
                 raise Exception(f"Expect {e}, but got {sys.exc_info()[0]}")
             if error_kws and any(keyword not in str(exec_info.value) for keyword in error_kws):
                 raise ValueError('Error message `{}` does not contain all keywords `{}`'.format(
-                   str(exec_info.value), error_kws))
+                    str(exec_info.value), error_kws))
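# Standalone sketch of the same keyword check, driven through pytest.raises
# the way the component above does (function name illustrative):
import pytest

def assert_raises_with_keywords(func, expected_exception, error_kws):
    with pytest.raises(expected_exception) as exec_info:
        func()
    if error_kws and any(keyword not in str(exec_info.value) for keyword in error_kws):
        raise ValueError('Error message `{}` does not contain all keywords `{}`'.format(
            str(exec_info.value), error_kws))

# Example: int("abc") raises ValueError mentioning both keywords.
assert_raises_with_keywords(lambda: int("abc"), ValueError, ["invalid literal", "abc"])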