Browse Source

Support host–device parallel execution in PyNative mode

tags/v1.1.0
caifubi 5 years ago
parent
commit
c7d6997819
4 changed files with 27 additions and 13 deletions
  1. +6
    -0
      mindspore/ccsrc/backend/optimizer/gpu/batch_norm_relu_grad_fusion.cc
  2. +2
    -1
      mindspore/ccsrc/backend/session/gpu_session.cc
  3. +2
    -1
      mindspore/nn/optim/lars.py
  4. +17
    -11
      mindspore/nn/optim/optimizer.py

+ 6
- 0
mindspore/ccsrc/backend/optimizer/gpu/batch_norm_relu_grad_fusion.cc View File

@@ -24,6 +24,7 @@
#include "utils/utils.h" #include "utils/utils.h"
#include "backend/optimizer/common/helper.h" #include "backend/optimizer/common/helper.h"
#include "runtime/device/gpu/kernel_info_setter.h" #include "runtime/device/gpu/kernel_info_setter.h"
#include "utils/ms_context.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
@@ -41,6 +42,11 @@ const AnfNodePtr BatchNormReluGradFusion::Process(const FuncGraphPtr &graph, con
auto format_attr = AnfAlgo::GetCNodePrimitive(node)->GetAttr("data_format"); auto format_attr = AnfAlgo::GetCNodePrimitive(node)->GetAttr("data_format");
MS_EXCEPTION_IF_NULL(format_attr); MS_EXCEPTION_IF_NULL(format_attr);
auto format = GetValue<std::string>(format_attr); auto format = GetValue<std::string>(format_attr);
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
return nullptr;
}
if (AnfAlgo::GetInputFormat(node, 0) != kOpFormat_NHWC && format != "NHWC") { if (AnfAlgo::GetInputFormat(node, 0) != kOpFormat_NHWC && format != "NHWC") {
return nullptr; return nullptr;
} }


+ 2
- 1
mindspore/ccsrc/backend/session/gpu_session.cc View File

@@ -246,7 +246,8 @@ void GPUSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
} }
} }
if (need_sync) { if (need_sync) {
if (AnfAlgo::IsParameterWeight(input_node->cast<ParameterPtr>())) {
if (AnfAlgo::IsParameterWeight(input_node->cast<ParameterPtr>()) ||
ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
tensor->set_device_address(device_address); tensor->set_device_address(device_address);
} }
MS_EXCEPTION_IF_NULL(device_address); MS_EXCEPTION_IF_NULL(device_address);


+ 2
- 1
mindspore/nn/optim/lars.py View File

@@ -88,6 +88,7 @@ class LARS(Optimizer):
self.learning_rate = Parameter(Tensor(0.0, dtype=mstype.float32), name="fake_lr") self.learning_rate = Parameter(Tensor(0.0, dtype=mstype.float32), name="fake_lr")
self.decay_flags = optimizer.decay_flags self.decay_flags = optimizer.decay_flags
self.reciprocal_scale = optimizer.reciprocal_scale self.reciprocal_scale = optimizer.reciprocal_scale
self.need_scale = optimizer.need_scale
self.hyper_map = C.HyperMap() self.hyper_map = C.HyperMap()
self.lars = P.LARSUpdate(epsilon, coefficient, use_clip) self.lars = P.LARSUpdate(epsilon, coefficient, use_clip)
self.cast = P.Cast() self.cast = P.Cast()
@@ -133,7 +134,7 @@ class LARS(Optimizer):
else: else:
lr = self.learning_rate lr = self.learning_rate


if self.reciprocal_scale != 1.0:
if self.need_scale:
gradients = self.hyper_map(F.partial(_grad_scale, self.reciprocal_scale), gradients) gradients = self.hyper_map(F.partial(_grad_scale, self.reciprocal_scale), gradients)


if self.is_group: if self.is_group:


+ 17
- 11
mindspore/nn/optim/optimizer.py View File

@@ -135,18 +135,21 @@ class Optimizer(Cell):
if self.is_group: if self.is_group:
self.parameters = ParameterTuple(self.group_params) self.parameters = ParameterTuple(self.group_params)
self.weight_decay = tuple(self.group_weight_decay) self.weight_decay = tuple(self.group_weight_decay)
self.weight_decay_tensor_tuple = tuple(Tensor(x, mstype.float32) for x in self.group_weight_decay)
decay_filter = lambda x: x > 0 decay_filter = lambda x: x > 0
self.decay_flags = tuple(decay_filter(x) for x in self.weight_decay) self.decay_flags = tuple(decay_filter(x) for x in self.weight_decay)
self.exec_weight_decay = any(self.decay_flags) self.exec_weight_decay = any(self.decay_flags)
else: else:
self.parameters = ParameterTuple(parameters) self.parameters = ParameterTuple(parameters)
self.weight_decay = weight_decay * loss_scale self.weight_decay = weight_decay * loss_scale
self.weight_decay_tensor = Tensor(self.weight_decay, mstype.float32)
decay_filter = lambda x: 'beta' not in x.name and 'gamma' not in x.name decay_filter = lambda x: 'beta' not in x.name and 'gamma' not in x.name
self.decay_flags = tuple(decay_filter(x) for x in self.parameters) self.decay_flags = tuple(decay_filter(x) for x in self.parameters)
self.exec_weight_decay = self.weight_decay > 0 self.exec_weight_decay = self.weight_decay > 0
ps_filter = lambda x: x.is_param_ps ps_filter = lambda x: x.is_param_ps
self.ps_parameters = tuple(ps_filter(x) for x in self.parameters) self.ps_parameters = tuple(ps_filter(x) for x in self.parameters)
self.reciprocal_scale = 1.0 / loss_scale
self.reciprocal_scale = Tensor(1.0 / loss_scale, mstype.float32)
self.need_scale = loss_scale != 1.0
self.param_length = len(self.parameters) self.param_length = len(self.parameters)
self.map_ = C.Map() self.map_ = C.Map()
if context.get_auto_parallel_context("enable_parallel_optimizer"): if context.get_auto_parallel_context("enable_parallel_optimizer"):
@@ -215,10 +218,10 @@ class Optimizer(Cell):
if self.exec_weight_decay: if self.exec_weight_decay:
params = self.parameters params = self.parameters
if self.is_group: if self.is_group:
gradients = self.map_(F.partial(_apply_decay), self.weight_decay, self.decay_flags,
gradients = self.map_(F.partial(_apply_decay), self.weight_decay_tensor_tuple, self.decay_flags,
params, gradients) params, gradients)
else: else:
gradients = self.map_(F.partial(_apply_decay, self.weight_decay), self.decay_flags,
gradients = self.map_(F.partial(_apply_decay, self.weight_decay_tensor), self.decay_flags,
params, gradients) params, gradients)


return gradients return gradients
@@ -238,7 +241,7 @@ class Optimizer(Cell):
tuple[Tensor], The gradients after loss scale. tuple[Tensor], The gradients after loss scale.


""" """
if self.reciprocal_scale != 1.0:
if self.need_scale:
gradients = self.map_(F.partial(_grad_scale, self.reciprocal_scale), gradients) gradients = self.map_(F.partial(_grad_scale, self.reciprocal_scale), gradients)


return gradients return gradients
@@ -522,11 +525,12 @@ class Optimizer(Cell):


op_add = P.AddN() op_add = P.AddN()
op_gather = P.GatherV2() op_gather = P.GatherV2()
op_mul = P.Mul()


_apply_decay = C.MultitypeFuncGraph("apply_decay") _apply_decay = C.MultitypeFuncGraph("apply_decay")




@_apply_decay.register("Number", "Bool", "Tensor", "RowTensor")
@_apply_decay.register("Tensor", "Bool", "Tensor", "RowTensor")
def _tensor_apply_decay_with_sparse(weight_decay, if_apply, weight, gradient): def _tensor_apply_decay_with_sparse(weight_decay, if_apply, weight, gradient):
"""Get grad with weight_decay.""" """Get grad with weight_decay."""
if if_apply: if if_apply:
@@ -537,11 +541,11 @@ def _tensor_apply_decay_with_sparse(weight_decay, if_apply, weight, gradient):
return gradient return gradient




@_apply_decay.register("Tensor", "Bool", "Tensor", "Tensor")
def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
    """Apply weight decay to a dense gradient.

    Registered for a Tensor `weight_decay` (this commit switches the
    registration from "Number" to "Tensor" so the decay factor stays a
    graph Tensor instead of a Python scalar).

    Args:
        weight_decay: decay coefficient (Tensor) — assumed scalar-like,
            broadcastable against `weight`; TODO confirm with callers.
        if_apply: bool flag; when False the gradient passes through untouched.
        weight: the parameter the gradient belongs to.
        gradient: dense gradient Tensor.

    Returns:
        `gradient + weight * weight_decay` when `if_apply` is True,
        otherwise the unmodified `gradient`.
    """
    if if_apply:
        # op_mul (P.Mul) keeps the scaling as a graph op, matching the
        # Tensor-typed weight_decay introduced by this change.
        return op_add((op_mul(weight, weight_decay), gradient))
    return gradient




@@ -553,14 +557,16 @@ def tensor_grad_scale(scale, grad):
"""Get grad with scale.""" """Get grad with scale."""
if scale == 1.0: if scale == 1.0:
return grad return grad
return grad * scale
return op_mul(grad, scale)


@_grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale_with_tensor(scale, grad):
    """Scale a dense gradient by a Tensor-valued loss-scale factor.

    Unlike the "Number" overload above, no `scale == 1.0` short-circuit is
    possible here: `scale` is a graph Tensor, so the multiply is always
    emitted. Uses the module-level `op_mul` (P.Mul) so the operation stays
    a graph primitive.
    """
    return op_mul(grad, scale)


@_grad_scale.register("Tensor", "RowTensor")
def tensor_grad_scale_with_sparse(scale, grad):
    """Scale a sparse (RowTensor) gradient by a Tensor loss-scale factor.

    The registration switches from "Number" to "Tensor" in this commit, and
    the old `if scale == 1.0: return grad` fast path is dropped: with a
    Tensor `scale` that comparison cannot be resolved at trace time, so the
    values are always multiplied.

    Returns:
        A new RowTensor with the same indices and dense_shape, whose values
        are `grad.values * scale`.
    """
    return RowTensor(grad.indices, grad.values * scale, grad.dense_shape)






Loading…
Cancel
Save