From c7d6997819befc437862d11368ee91d27b9cf584 Mon Sep 17 00:00:00 2001
From: caifubi
Date: Mon, 16 Nov 2020 16:24:52 +0800
Subject: [PATCH] pynative host device parallel

---
 .../gpu/batch_norm_relu_grad_fusion.cc   |  6 ++++
 .../ccsrc/backend/session/gpu_session.cc |  3 +-
 mindspore/nn/optim/lars.py               |  3 +-
 mindspore/nn/optim/optimizer.py          | 28 +++++++++++--------
 4 files changed, 27 insertions(+), 13 deletions(-)

diff --git a/mindspore/ccsrc/backend/optimizer/gpu/batch_norm_relu_grad_fusion.cc b/mindspore/ccsrc/backend/optimizer/gpu/batch_norm_relu_grad_fusion.cc
index e8dc539591..ae904e9e6b 100644
--- a/mindspore/ccsrc/backend/optimizer/gpu/batch_norm_relu_grad_fusion.cc
+++ b/mindspore/ccsrc/backend/optimizer/gpu/batch_norm_relu_grad_fusion.cc
@@ -24,6 +24,7 @@
 #include "utils/utils.h"
 #include "backend/optimizer/common/helper.h"
 #include "runtime/device/gpu/kernel_info_setter.h"
+#include "utils/ms_context.h"
 
 namespace mindspore {
 namespace opt {
@@ -41,6 +42,11 @@ const AnfNodePtr BatchNormReluGradFusion::Process(const FuncGraphPtr &graph, con
   auto format_attr = AnfAlgo::GetCNodePrimitive(node)->GetAttr("data_format");
   MS_EXCEPTION_IF_NULL(format_attr);
   auto format = GetValue<std::string>(format_attr);
+  auto ms_context = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(ms_context);
+  if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
+    return nullptr;
+  }
   if (AnfAlgo::GetInputFormat(node, 0) != kOpFormat_NHWC && format != "NHWC") {
     return nullptr;
   }
diff --git a/mindspore/ccsrc/backend/session/gpu_session.cc b/mindspore/ccsrc/backend/session/gpu_session.cc
index b21361e6ad..5ffb82bb1b 100644
--- a/mindspore/ccsrc/backend/session/gpu_session.cc
+++ b/mindspore/ccsrc/backend/session/gpu_session.cc
@@ -246,7 +246,8 @@ void GPUSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
       }
     }
     if (need_sync) {
-      if (AnfAlgo::IsParameterWeight(input_node->cast<ParameterPtr>())) {
+      if (AnfAlgo::IsParameterWeight(input_node->cast<ParameterPtr>()) ||
+          ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
         tensor->set_device_address(device_address);
       }
       MS_EXCEPTION_IF_NULL(device_address);
diff --git a/mindspore/nn/optim/lars.py b/mindspore/nn/optim/lars.py
index 65187c4cc9..a8b5b05da1 100755
--- a/mindspore/nn/optim/lars.py
+++ b/mindspore/nn/optim/lars.py
@@ -88,6 +88,7 @@ class LARS(Optimizer):
         self.learning_rate = Parameter(Tensor(0.0, dtype=mstype.float32), name="fake_lr")
         self.decay_flags = optimizer.decay_flags
         self.reciprocal_scale = optimizer.reciprocal_scale
+        self.need_scale = optimizer.need_scale
         self.hyper_map = C.HyperMap()
         self.lars = P.LARSUpdate(epsilon, coefficient, use_clip)
         self.cast = P.Cast()
@@ -133,7 +134,7 @@ class LARS(Optimizer):
         else:
             lr = self.learning_rate
 
-        if self.reciprocal_scale != 1.0:
+        if self.need_scale:
             gradients = self.hyper_map(F.partial(_grad_scale, self.reciprocal_scale), gradients)
 
         if self.is_group:
diff --git a/mindspore/nn/optim/optimizer.py b/mindspore/nn/optim/optimizer.py
index 54eceabcaa..4b0ff2d627 100755
--- a/mindspore/nn/optim/optimizer.py
+++ b/mindspore/nn/optim/optimizer.py
@@ -135,18 +135,21 @@ class Optimizer(Cell):
         if self.is_group:
             self.parameters = ParameterTuple(self.group_params)
             self.weight_decay = tuple(self.group_weight_decay)
+            self.weight_decay_tensor_tuple = tuple(Tensor(x, mstype.float32) for x in self.group_weight_decay)
             decay_filter = lambda x: x > 0
             self.decay_flags = tuple(decay_filter(x) for x in self.weight_decay)
             self.exec_weight_decay = any(self.decay_flags)
         else:
             self.parameters = ParameterTuple(parameters)
             self.weight_decay = weight_decay * loss_scale
+            self.weight_decay_tensor = Tensor(self.weight_decay, mstype.float32)
             decay_filter = lambda x: 'beta' not in x.name and 'gamma' not in x.name
             self.decay_flags = tuple(decay_filter(x) for x in self.parameters)
             self.exec_weight_decay = self.weight_decay > 0
         ps_filter = lambda x: x.is_param_ps
         self.ps_parameters = tuple(ps_filter(x) for x in self.parameters)
-        self.reciprocal_scale = 1.0 / loss_scale
+        self.reciprocal_scale = Tensor(1.0 / loss_scale, mstype.float32)
+        self.need_scale = loss_scale != 1.0
         self.param_length = len(self.parameters)
         self.map_ = C.Map()
         if context.get_auto_parallel_context("enable_parallel_optimizer"):
@@ -215,10 +218,10 @@
         if self.exec_weight_decay:
             params = self.parameters
             if self.is_group:
-                gradients = self.map_(F.partial(_apply_decay), self.weight_decay, self.decay_flags,
+                gradients = self.map_(F.partial(_apply_decay), self.weight_decay_tensor_tuple, self.decay_flags,
                                       params, gradients)
             else:
-                gradients = self.map_(F.partial(_apply_decay, self.weight_decay), self.decay_flags,
+                gradients = self.map_(F.partial(_apply_decay, self.weight_decay_tensor), self.decay_flags,
                                       params, gradients)
 
         return gradients
@@ -238,7 +241,7 @@
             tuple[Tensor], The gradients after loss scale.
 
         """
-        if self.reciprocal_scale != 1.0:
+        if self.need_scale:
             gradients = self.map_(F.partial(_grad_scale, self.reciprocal_scale), gradients)
 
         return gradients
@@ -522,11 +525,12 @@
 
 op_add = P.AddN()
 op_gather = P.GatherV2()
+op_mul = P.Mul()
 
 _apply_decay = C.MultitypeFuncGraph("apply_decay")
 
 
-@_apply_decay.register("Number", "Bool", "Tensor", "RowTensor")
+@_apply_decay.register("Tensor", "Bool", "Tensor", "RowTensor")
 def _tensor_apply_decay_with_sparse(weight_decay, if_apply, weight, gradient):
     """Get grad with weight_decay."""
     if if_apply:
@@ -537,11 +541,11 @@ def _tensor_apply_decay_with_sparse(weight_decay, if_apply, weight, gradient):
     return gradient
 
 
-@_apply_decay.register("Number", "Bool", "Tensor", "Tensor")
+@_apply_decay.register("Tensor", "Bool", "Tensor", "Tensor")
 def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
     """Get grad with weight_decay."""
     if if_apply:
-        return op_add((weight * weight_decay, gradient))
+        return op_add((op_mul(weight, weight_decay), gradient))
     return gradient
 
 
@@ -553,14 +557,16 @@
 def tensor_grad_scale(scale, grad):
     """Get grad with scale."""
     if scale == 1.0:
         return grad
-    return grad * scale
+    return op_mul(grad, scale)
 
+@_grad_scale.register("Tensor", "Tensor")
+def tensor_grad_scale_with_tensor(scale, grad):
+    """Get grad with scale."""
+    return op_mul(grad, scale)
 
-@_grad_scale.register("Number", "RowTensor")
+@_grad_scale.register("Tensor", "RowTensor")
 def tensor_grad_scale_with_sparse(scale, grad):
     """Get grad with scale."""
-    if scale == 1.0:
-        return grad
     return RowTensor(grad.indices, grad.values * scale, grad.dense_shape)
 
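
Note on the optimizer changes: `reciprocal_scale` and the weight-decay values
used to be Python floats, so `self.reciprocal_scale != 1.0` was a host-side
check and `weight * weight_decay` relied on scalar broadcasting. This patch
keeps those values on the device as Tensors (multiplied via the new `op_mul`)
and caches the host-side decision in `need_scale` while `loss_scale` is still
a plain float. Below is a minimal sketch of that host/device split, assuming
only the MindSpore 1.x `Tensor`/`P.Mul`/`context` APIs already used in this
diff; the script is illustrative, not part of the change:

    import mindspore.common.dtype as mstype
    from mindspore import Tensor, context
    from mindspore.ops import operations as P

    context.set_context(mode=context.PYNATIVE_MODE)

    op_mul = P.Mul()
    loss_scale = 128.0

    # Host-side flag: computed once from the plain Python float. Comparing the
    # Tensor form against 1.0 instead would build a device op, not a bool.
    need_scale = loss_scale != 1.0

    # Device-side value: wrapped as a Tensor so the scaling multiply runs as a
    # normal device kernel, matching the Tensor overloads registered above.
    reciprocal_scale = Tensor(1.0 / loss_scale, mstype.float32)

    grad = Tensor([2.0, 4.0, 8.0], mstype.float32)
    if need_scale:
        grad = op_mul(grad, reciprocal_scale)
    print(grad)  # expect values [0.015625, 0.03125, 0.0625]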