
pynative host device parallel

caifubi committed 5 years ago · commit c7d6997819 · tags/v1.1.0

4 changed files with 27 additions and 13 deletions:

  1. mindspore/ccsrc/backend/optimizer/gpu/batch_norm_relu_grad_fusion.cc (+6, -0)
  2. mindspore/ccsrc/backend/session/gpu_session.cc (+2, -1)
  3. mindspore/nn/optim/lars.py (+2, -1)
  4. mindspore/nn/optim/optimizer.py (+17, -11)

mindspore/ccsrc/backend/optimizer/gpu/batch_norm_relu_grad_fusion.cc (+6, -0)

@@ -24,6 +24,7 @@
 #include "utils/utils.h"
 #include "backend/optimizer/common/helper.h"
 #include "runtime/device/gpu/kernel_info_setter.h"
+#include "utils/ms_context.h"
 
 namespace mindspore {
 namespace opt {
@@ -41,6 +42,11 @@ const AnfNodePtr BatchNormReluGradFusion::Process(const FuncGraphPtr &graph, con
   auto format_attr = AnfAlgo::GetCNodePrimitive(node)->GetAttr("data_format");
   MS_EXCEPTION_IF_NULL(format_attr);
   auto format = GetValue<std::string>(format_attr);
+  auto ms_context = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(ms_context);
+  if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
+    return nullptr;
+  }
   if (AnfAlgo::GetInputFormat(node, 0) != kOpFormat_NHWC && format != "NHWC") {
     return nullptr;
   }
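
Note: the guard above disables this fusion pass entirely in PyNative mode, where kernels are compiled and launched one op at a time. A minimal sketch of how that mode is selected from Python (standard context API; the script is illustrative, not part of this commit):

    # Illustrative: MS_CTX_EXECUTION_MODE on the C++ side mirrors this
    # Python-level setting. With PYNATIVE_MODE active, Process() now
    # returns nullptr and the BatchNorm/ReLU grad ops stay unfused.
    import mindspore.context as context

    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")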


mindspore/ccsrc/backend/session/gpu_session.cc (+2, -1)

@@ -246,7 +246,8 @@ void GPUSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
       }
     }
     if (need_sync) {
-      if (AnfAlgo::IsParameterWeight(input_node->cast<ParameterPtr>())) {
+      if (AnfAlgo::IsParameterWeight(input_node->cast<ParameterPtr>()) ||
+          ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode) {
        tensor->set_device_address(device_address);
      }
      MS_EXCEPTION_IF_NULL(device_address);
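
Note: in PyNative mode every input tensor now keeps its device address after a sync, not just weight parameters, so repeated single-op launches reuse the device memory binding instead of re-establishing it each step. An illustrative PyNative loop that exercises this path (assumes a GPU build; the network is a placeholder):

    # Illustrative only: repeated forward calls in PyNative mode go through
    # LoadInputData; with this change the input tensor keeps its device
    # address after the first launch instead of being treated as host-only.
    import numpy as np
    import mindspore.nn as nn
    import mindspore.context as context
    from mindspore import Tensor

    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    net = nn.Dense(4, 3)
    x = Tensor(np.ones((2, 4), np.float32))
    for _ in range(2):
        out = net(x)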


mindspore/nn/optim/lars.py (+2, -1)

@@ -88,6 +88,7 @@ class LARS(Optimizer):
         self.learning_rate = Parameter(Tensor(0.0, dtype=mstype.float32), name="fake_lr")
         self.decay_flags = optimizer.decay_flags
         self.reciprocal_scale = optimizer.reciprocal_scale
+        self.need_scale = optimizer.need_scale
         self.hyper_map = C.HyperMap()
         self.lars = P.LARSUpdate(epsilon, coefficient, use_clip)
         self.cast = P.Cast()
@@ -133,7 +134,7 @@ class LARS(Optimizer):
         else:
             lr = self.learning_rate
 
-        if self.reciprocal_scale != 1.0:
+        if self.need_scale:
             gradients = self.hyper_map(F.partial(_grad_scale, self.reciprocal_scale), gradients)
 
         if self.is_group:
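
Note: `reciprocal_scale` is now a Tensor (see the optimizer.py change below), so the old `!= 1.0` comparison would no longer be a cheap host-side check; LARS instead reuses the `need_scale` boolean computed once by the wrapped optimizer. A usage sketch (standard LARS wrapper API; hyperparameter values are placeholders):

    # Illustrative: LARS inherits reciprocal_scale and the new need_scale
    # flag from the inner optimizer, so gradient un-scaling matches it.
    import mindspore.nn as nn

    net = nn.Dense(16, 10)
    inner = nn.Momentum(net.trainable_params(), learning_rate=0.1,
                        momentum=0.9, loss_scale=1024.0)
    opt = nn.LARS(inner, epsilon=1e-05, coefficient=0.001)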


mindspore/nn/optim/optimizer.py (+17, -11)

@@ -135,18 +135,21 @@ class Optimizer(Cell):
         if self.is_group:
             self.parameters = ParameterTuple(self.group_params)
             self.weight_decay = tuple(self.group_weight_decay)
+            self.weight_decay_tensor_tuple = tuple(Tensor(x, mstype.float32) for x in self.group_weight_decay)
             decay_filter = lambda x: x > 0
             self.decay_flags = tuple(decay_filter(x) for x in self.weight_decay)
             self.exec_weight_decay = any(self.decay_flags)
         else:
             self.parameters = ParameterTuple(parameters)
             self.weight_decay = weight_decay * loss_scale
+            self.weight_decay_tensor = Tensor(self.weight_decay, mstype.float32)
             decay_filter = lambda x: 'beta' not in x.name and 'gamma' not in x.name
             self.decay_flags = tuple(decay_filter(x) for x in self.parameters)
             self.exec_weight_decay = self.weight_decay > 0
         ps_filter = lambda x: x.is_param_ps
         self.ps_parameters = tuple(ps_filter(x) for x in self.parameters)
-        self.reciprocal_scale = 1.0 / loss_scale
+        self.reciprocal_scale = Tensor(1.0 / loss_scale, mstype.float32)
+        self.need_scale = loss_scale != 1.0
         self.param_length = len(self.parameters)
         self.map_ = C.Map()
         if context.get_auto_parallel_context("enable_parallel_optimizer"):
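
Note: the constructor now materializes the weight-decay values and `1.0 / loss_scale` as float32 Tensors once, instead of passing Python floats into the mapped graphs on every step; `need_scale` records the `loss_scale != 1.0` decision while the value is still a plain number. A standalone sketch of the pattern (hypothetical variable names, not the class itself):

    # Hypothetical micro-example: hoist a constant scalar into a Tensor once
    # rather than re-wrapping the Python float at each optimizer step.
    from mindspore import Tensor
    import mindspore.common.dtype as mstype

    loss_scale = 1024.0
    need_scale = loss_scale != 1.0                               # plain bool, decided once
    reciprocal_scale = Tensor(1.0 / loss_scale, mstype.float32)  # built once, reused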
@@ -215,10 +218,10 @@
         if self.exec_weight_decay:
             params = self.parameters
             if self.is_group:
-                gradients = self.map_(F.partial(_apply_decay), self.weight_decay, self.decay_flags,
+                gradients = self.map_(F.partial(_apply_decay), self.weight_decay_tensor_tuple, self.decay_flags,
                                       params, gradients)
             else:
-                gradients = self.map_(F.partial(_apply_decay, self.weight_decay), self.decay_flags,
+                gradients = self.map_(F.partial(_apply_decay, self.weight_decay_tensor), self.decay_flags,
                                       params, gradients)
 
         return gradients
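
Note: `decay_weight` keeps its `C.Map` + `F.partial` structure; only the bound constants switch to the pre-built Tensor forms, which routes dispatch to the "Tensor"-typed `_apply_decay` overloads below. A hypothetical minimal Cell showing the same call pattern in isolation (names and values are illustrative):

    # Hypothetical sketch of the Map/partial pattern: bind the Tensor
    # constant once, then map the multitype graph across per-parameter
    # tuples of flags, weights, and gradients.
    import mindspore.nn as nn
    from mindspore import Tensor
    import mindspore.common.dtype as mstype
    from mindspore.ops import composite as C
    from mindspore.ops import functional as F

    _decay = C.MultitypeFuncGraph("decay")

    @_decay.register("Tensor", "Bool", "Tensor", "Tensor")
    def _decay_impl(weight_decay, if_apply, weight, gradient):
        """Add weight * weight_decay to the gradient when flagged."""
        if if_apply:
            return gradient + weight * weight_decay
        return gradient

    class ApplyDecay(nn.Cell):
        def __init__(self):
            super(ApplyDecay, self).__init__()
            self.map_ = C.Map()
            self.weight_decay = Tensor(0.01, mstype.float32)
            self.flags = (True, False)

        def construct(self, params, grads):
            return self.map_(F.partial(_decay, self.weight_decay),
                             self.flags, params, grads)

Calling `ApplyDecay()(params, grads)` with matching-length tuples returns the decayed gradients.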
@@ -238,7 +241,7 @@
             tuple[Tensor], The gradients after loss scale.
 
         """
-        if self.reciprocal_scale != 1.0:
+        if self.need_scale:
             gradients = self.map_(F.partial(_grad_scale, self.reciprocal_scale), gradients)
 
         return gradients
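
Note: `scale_grad` keeps its shape; only the branch condition moves from comparing the (now Tensor-valued) `reciprocal_scale` against 1.0 to the precomputed `need_scale` flag. A standalone analogue (hypothetical helper function, not the class method itself):

    # Hypothetical analogue of scale_grad: un-scale a tuple of gradients by
    # a pre-built reciprocal Tensor only when scaling is active.
    import numpy as np
    from mindspore import Tensor
    import mindspore.common.dtype as mstype
    from mindspore.ops import operations as P

    mul = P.Mul()

    def scale_grad(gradients, reciprocal_scale, need_scale):
        if need_scale:
            return tuple(mul(g, reciprocal_scale) for g in gradients)
        return gradients

    grads = (Tensor(np.ones((2, 2), np.float32)),)
    out = scale_grad(grads, Tensor(1.0 / 1024.0, mstype.float32), True)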
@@ -522,11 +525,12 @@
 
 op_add = P.AddN()
 op_gather = P.GatherV2()
+op_mul = P.Mul()
 
 _apply_decay = C.MultitypeFuncGraph("apply_decay")
 
 
-@_apply_decay.register("Number", "Bool", "Tensor", "RowTensor")
+@_apply_decay.register("Tensor", "Bool", "Tensor", "RowTensor")
 def _tensor_apply_decay_with_sparse(weight_decay, if_apply, weight, gradient):
     """Get grad with weight_decay."""
     if if_apply:
@@ -537,11 +541,11 @@ def _tensor_apply_decay_with_sparse(weight_decay, if_apply, weight, gradient):
     return gradient
 
 
-@_apply_decay.register("Number", "Bool", "Tensor", "Tensor")
+@_apply_decay.register("Tensor", "Bool", "Tensor", "Tensor")
 def _tensor_apply_decay(weight_decay, if_apply, weight, gradient):
     """Get grad with weight_decay."""
     if if_apply:
-        return op_add((weight * weight_decay, gradient))
+        return op_add((op_mul(weight, weight_decay), gradient))
     return gradient
 
 
@@ -553,14 +557,16 @@ def tensor_grad_scale(scale, grad):
     """Get grad with scale."""
     if scale == 1.0:
         return grad
-    return grad * scale
+    return op_mul(grad, scale)
 
+
+@_grad_scale.register("Tensor", "Tensor")
+def tensor_grad_scale_with_tensor(scale, grad):
+    """Get grad with scale."""
+    return op_mul(grad, scale)
 
-@_grad_scale.register("Number", "RowTensor")
+@_grad_scale.register("Tensor", "RowTensor")
 def tensor_grad_scale_with_sparse(scale, grad):
     """Get grad with scale."""
     if scale == 1.0:
         return grad
     return RowTensor(grad.indices, grad.values * scale, grad.dense_shape)
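
Note: the registration changes are the visible half of the scalar-to-Tensor move: both multitype graphs gain "Tensor" overloads for the scale/decay argument, and `op_mul` replaces the bare `*` operator. A self-contained sketch of how MultitypeFuncGraph selects an overload by argument type (the graph name and bodies are illustrative; the composite API is real):

    # Illustrative MultitypeFuncGraph: the overload that runs is picked from
    # the runtime types of the arguments, which is why passing the scale as
    # a Tensor requires the "Tensor"-typed registrations above.
    from mindspore.ops import composite as C
    from mindspore.ops import operations as P

    op_mul = P.Mul()
    _scale = C.MultitypeFuncGraph("scale")

    @_scale.register("Number", "Tensor")
    def _scale_by_number(s, x):
        """Chosen when the scale is a Python number."""
        return x * s

    @_scale.register("Tensor", "Tensor")
    def _scale_by_tensor(s, x):
        """Chosen when the scale is a Tensor."""
        return op_mul(x, s)

These graphs are applied element-wise with C.Map or C.HyperMap plus F.partial, exactly as scale_grad and decay_weight do above.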



