
edit loss_scale for gpu

tags/v0.2.0-alpha
VectorSL · 5 years ago · parent commit 2ff6f0de46
3 changed files with 117 additions and 12 deletions:

  1. mindspore/nn/wrap/loss_scale.py       +29  -11
  2. mindspore/ops/operations/__init__.py   +5   -1
  3. mindspore/ops/operations/math_ops.py  +83   -0

mindspore/nn/wrap/loss_scale.py  (+29 -11)

@@ -25,6 +25,7 @@ from ...ops import operations as P
 from ...ops.operations import NPUGetFloatStatus, NPUAllocFloatStatus, NPUClearFloatStatus, ReduceSum, LessEqual, \
     ControlDepend
 from ...common import dtype as mstype
+import mindspore.context as context
 
 _grad_scale = C.MultitypeFuncGraph("grad_scale")
 reciprocal = P.Reciprocal()
@@ -34,6 +35,12 @@ reciprocal = P.Reciprocal()
 def tensor_grad_scale(scale, grad):
     return grad * F.cast(reciprocal(scale), F.dtype(grad))
 
+_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
+grad_overflow = P.FloatStatus()
+
+@_grad_overflow.register("Tensor")
+def _tensor_grad_overflow(grad):
+    return grad_overflow(grad)
 
 class DynamicLossScaleUpdateCell(Cell):
     r"""
@@ -195,9 +202,15 @@ class TrainOneStepWithLossScaleCell(Cell):
         self.optimizer = optimizer
         self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
         self.hyper_map = C.HyperMap()
-        self.alloc_status = NPUAllocFloatStatus()
-        self.get_status = NPUGetFloatStatus()
-        self.clear_status = NPUClearFloatStatus()
+        if context.get_context("device_target") == "GPU":
+            self.gpu_target = True
+            self.float_status = P.FloatStatus()
+            self.addn = P.AddN()
+        else:
+            self.gpu_target = False
+            self.alloc_status = NPUAllocFloatStatus()
+            self.get_status = NPUGetFloatStatus()
+            self.clear_status = NPUClearFloatStatus()
         self.reduce_sum = ReduceSum(keep_dims=False)
         self.base = Tensor(1, mstype.float32)
         self.less_equal = LessEqual()
@@ -222,10 +235,11 @@
     def construct(self, data, label, sens=None):
         weights = self.weights
         loss = self.network(data, label)
-        # init overflow buffer
-        init = self.alloc_status()
-        # clear overflow buffer
-        self.clear_status(init)
+        if not self.gpu_target:
+            # init overflow buffer
+            init = self.alloc_status()
+            # clear overflow buffer
+            self.clear_status(init)
         if sens is None:
             scaling_sens = self.loss_scale
         else:
@@ -235,10 +249,14 @@
         if self.reducer_flag:
             # apply grad reducer on grads
             grads = self.grad_reducer(grads)
-        # get the overflow buffer
-        self.get_status(init)
-        # sum overflow buffer elements, 0: not overflow, >0: overflow
-        flag_sum = self.reduce_sum(init, (0,))
+        if not self.gpu_target:
+            # get the overflow buffer
+            self.get_status(init)
+            # sum overflow buffer elements, 0: not overflow, >0: overflow
+            flag_sum = self.reduce_sum(init, (0,))
+        else:
+            flag_sum = self.hyper_map(F.partial(_grad_overflow), grads)
+            flag_sum = self.addn(flag_sum)
         if self.is_distributed:
             # sum overflow flag over devices
             flag_reduce = self.allreduce(flag_sum)
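
The GPU branch above works without Ascend's float-status buffer: FloatStatus maps each gradient to a (1,)-shaped flag that is greater than zero when the gradient contains nan or inf, and AddN folds those flags into a single overflow indicator. A minimal sketch of that check in isolation, assuming a device target with FloatStatus kernels (GPU here), eager PyNative-style execution, and illustrative gradient values:

import numpy as np
from mindspore import Tensor
from mindspore.ops import operations as P

float_status = P.FloatStatus()  # per-tensor flag: (1,)-shaped, >0 if nan/inf present
addn = P.AddN()                 # sums the per-gradient flags

# Illustrative gradients; the second one deliberately contains inf.
grads = (Tensor(np.array([1.0, 2.0], np.float32)),
         Tensor(np.array([np.inf, 0.0], np.float32)))

flags = [float_status(g) for g in grads]
flag_sum = addn(flags)                 # 0: no overflow, >0: overflow
overflow = flag_sum.asnumpy()[0] > 0   # True for these values

This mirrors the hyper_map(F.partial(_grad_overflow), grads) call in construct, which applies the same FloatStatus check across the gradient tuple inside the graph.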


mindspore/ops/operations/__init__.py  (+5 -1)

@@ -44,7 +44,7 @@ from .math_ops import (Abs, ACos, AddN, AssignAdd, AssignSub, Atan2, BatchMatMul,
                        LogicalNot, LogicalOr, MatMul, Maximum,
                        Minimum, Mul, Neg, NMSWithMask, NotEqual,
                        NPUAllocFloatStatus, NPUClearFloatStatus,
-                       NPUGetFloatStatus, Pow, RealDiv,
+                       NPUGetFloatStatus, Pow, RealDiv, IsNan, IsInf, IsFinite, FloatStatus,
                        Reciprocal, CumSum,
                        Sin, Sqrt, Rsqrt,
                        Square, Sub, TensorAdd, Sign, Round)
@@ -151,6 +151,10 @@ __all__ = [
     'Neg',
     'Slice',
     'DType',
+    'IsNan',
+    'IsInf',
+    'IsFinite',
+    'FloatStatus',
     'NPUAllocFloatStatus',
     'NPUGetFloatStatus',
     'NPUClearFloatStatus',
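
With the two edits above, the new primitives become part of the public operations namespace; a quick import smoke test (names taken directly from the diff):

from mindspore.ops import operations as P
from mindspore.ops.operations import IsNan, IsInf, IsFinite, FloatStatus

assert all(hasattr(P, name) for name in ('IsNan', 'IsInf', 'IsFinite', 'FloatStatus'))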


mindspore/ops/operations/math_ops.py  (+83 -0)

@@ -1557,6 +1557,89 @@ class LogicalOr(_LogicBinaryOp):
     def infer_dtype(self, x_dtype, y_dtype):
         return _LogicBinaryOp.do_infer_dtype(x_dtype, y_dtype, (mstype.bool_,))
 
+class IsNan(PrimitiveWithInfer):
+    """
+    Judging which elements are nan for each position
+    Inputs:
+        - **input_x** (Tensor) - The input tensor.
+
+    Outputs:
+        Tensor, has the same shape of input.
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        """init IsNan"""
+        self.init_prim_io_names(inputs=['x'], outputs=['output'])
+
+    def infer_shape(self, x_shape):
+        return x_shape
+
+    def infer_dtype(self, x_dtype):
+        return mstype.bool_
+
+class IsInf(PrimitiveWithInfer):
+    """
+    Judging which elements are inf or -inf for each position
+    Inputs:
+        - **input_x** (Tensor) - The input tensor.
+
+    Outputs:
+        Tensor, has the same shape of input.
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        """init IsInf"""
+        self.init_prim_io_names(inputs=['x'], outputs=['output'])
+
+    def infer_shape(self, x_shape):
+        return x_shape
+
+    def infer_dtype(self, x_dtype):
+        return mstype.bool_
+
+class IsFinite(PrimitiveWithInfer):
+    """
+    Judging which elements are finite for each position
+    Inputs:
+        - **input_x** (Tensor) - The input tensor.
+
+    Outputs:
+        Tensor, has the same shape of input.
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        """init IsFinite"""
+        self.init_prim_io_names(inputs=['x'], outputs=['output'])
+
+    def infer_shape(self, x_shape):
+        return x_shape
+
+    def infer_dtype(self, x_dtype):
+        return mstype.bool_
+
+class FloatStatus(PrimitiveWithInfer):
+    """
+    Determine if the elements contains nan, inf or -inf
+    Inputs:
+        - **input_x** (Tensor) - The input tensor.
+
+    Outputs:
+        Tensor, has the shape of `(1,)`.
+    """
+
+    @prim_attr_register
+    def __init__(self):
+        """init FloatStatus"""
+        self.init_prim_io_names(inputs=['x'], outputs=['output'])
+
+    def infer_shape(self, x_shape):
+        return [1]
+
+    def infer_dtype(self, x_dtype):
+        return x_dtype
+
 class NPUAllocFloatStatus(PrimitiveWithInfer):
     """

