Browse Source

!1915 GPU support bert finetune

Merge pull request !1915 from VectorSL/bert-finetune
tags/v0.5.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
a5b816e67e
8 changed files with 77 additions and 219 deletions
  1. +16
    -1
      mindspore/ccsrc/kernel/gpu/cuda_impl/broadcast_impl.cu
  2. +1
    -0
      mindspore/ccsrc/kernel/gpu/cuda_impl/broadcast_impl.cuh
  3. +13
    -0
      mindspore/ccsrc/kernel/gpu/math/broadcast_gpu_kernel.cc
  4. +1
    -1
      mindspore/ccsrc/kernel/gpu/math/broadcast_gpu_kernel.h
  5. +0
    -33
      mindspore/ccsrc/kernel/gpu/math/tensoradd_gpu_kernel.cc
  6. +0
    -171
      mindspore/ccsrc/kernel/gpu/math/tensoradd_gpu_kernel.h
  7. +13
    -2
      model_zoo/bert/finetune.py
  8. +33
    -11
      model_zoo/bert/src/utils.py

+ 16
- 1
mindspore/ccsrc/kernel/gpu/cuda_impl/broadcast_impl.cu View File

@@ -64,6 +64,11 @@ struct SubFunc {
__device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return (lhs - rhs); } __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return (lhs - rhs); }
}; };


template <typename T, typename S>
struct AddFunc {
__device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return (lhs + rhs); }
};

template <> template <>
struct PowerFunc<half, bool> { struct PowerFunc<half, bool> {
// invalid branch // invalid branch
@@ -118,6 +123,9 @@ __global__ void BroadcastKernel(const int l0, const int l1, const int l2, const
case BROADCAST_TYPE_SUB: case BROADCAST_TYPE_SUB:
return BroadcastOperator<T, S, SubFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, return BroadcastOperator<T, S, SubFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
output); output);
case BROADCAST_TYPE_ADD:
return BroadcastOperator<T, S, AddFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
output);
} }
} }


@@ -157,6 +165,8 @@ __global__ void NoBroadcastKernel(const int nums, enum BroadcastOpType op, const
return NoBroadcastOperator<T, S, MulFunc<T, S>>(nums, input0, input1, output); return NoBroadcastOperator<T, S, MulFunc<T, S>>(nums, input0, input1, output);
case BROADCAST_TYPE_SUB: case BROADCAST_TYPE_SUB:
return NoBroadcastOperator<T, S, SubFunc<T, S>>(nums, input0, input1, output); return NoBroadcastOperator<T, S, SubFunc<T, S>>(nums, input0, input1, output);
case BROADCAST_TYPE_ADD:
return NoBroadcastOperator<T, S, AddFunc<T, S>>(nums, input0, input1, output);
} }
} }


@@ -182,7 +192,10 @@ template void Broadcast(const int &l0, const int &l1, const int &l2, const int &
const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
enum BroadcastOpType op, const half *input0, const half *input1, half *output, enum BroadcastOpType op, const half *input0, const half *input1, half *output,
cudaStream_t stream); cudaStream_t stream);

template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
enum BroadcastOpType op, const int *input0, const int *input1, int *output,
cudaStream_t stream);
template void NoBroadcast(const int &nums, enum BroadcastOpType op, const float *input0, const float *input1, template void NoBroadcast(const int &nums, enum BroadcastOpType op, const float *input0, const float *input1,
bool *output, cudaStream_t stream); bool *output, cudaStream_t stream);
template void NoBroadcast(const int &nums, enum BroadcastOpType op, const float *input0, const float *input1, template void NoBroadcast(const int &nums, enum BroadcastOpType op, const float *input0, const float *input1,
@@ -191,3 +204,5 @@ template void NoBroadcast(const int &nums, enum BroadcastOpType op, const half *
bool *output, cudaStream_t stream); bool *output, cudaStream_t stream);
template void NoBroadcast(const int &nums, enum BroadcastOpType op, const half *input0, const half *input1, template void NoBroadcast(const int &nums, enum BroadcastOpType op, const half *input0, const half *input1,
half *output, cudaStream_t stream); half *output, cudaStream_t stream);
template void NoBroadcast(const int &nums, enum BroadcastOpType op, const int *input0, const int *input1,
int *output, cudaStream_t stream);

+ 1
- 0
mindspore/ccsrc/kernel/gpu/cuda_impl/broadcast_impl.cuh View File

@@ -28,6 +28,7 @@ enum BroadcastOpType {
BROADCAST_TYPE_REALDIV = 5, BROADCAST_TYPE_REALDIV = 5,
BROADCAST_TYPE_MUL = 6, BROADCAST_TYPE_MUL = 6,
BROADCAST_TYPE_SUB = 7, BROADCAST_TYPE_SUB = 7,
BROADCAST_TYPE_ADD = 8,
BROADCAST_TYPE_INVALID = 0xffffffff, BROADCAST_TYPE_INVALID = 0xffffffff,
}; };




+ 13
- 0
mindspore/ccsrc/kernel/gpu/math/broadcast_gpu_kernel.cc View File

@@ -47,6 +47,10 @@ MS_REG_GPU_KERNEL_TWO(
MS_REG_GPU_KERNEL_TWO( MS_REG_GPU_KERNEL_TWO(
Sub, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), Sub, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BroadcastOpGpuKernel, float, float) BroadcastOpGpuKernel, float, float)
MS_REG_GPU_KERNEL_TWO(
TensorAdd,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BroadcastOpGpuKernel, float, float)


// fp16 // fp16
MS_REG_GPU_KERNEL_TWO( MS_REG_GPU_KERNEL_TWO(
@@ -77,5 +81,14 @@ MS_REG_GPU_KERNEL_TWO(
MS_REG_GPU_KERNEL_TWO( MS_REG_GPU_KERNEL_TWO(
Sub, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), Sub, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
BroadcastOpGpuKernel, half, half) BroadcastOpGpuKernel, half, half)
MS_REG_GPU_KERNEL_TWO(
TensorAdd,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
BroadcastOpGpuKernel, half, half)

// int32
MS_REG_GPU_KERNEL_TWO(
TensorAdd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
BroadcastOpGpuKernel, int, int)
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore

+ 1
- 1
mindspore/ccsrc/kernel/gpu/math/broadcast_gpu_kernel.h View File

@@ -98,7 +98,7 @@ class BroadcastOpGpuKernel : public GpuKernel {
static std::map<std::string, BroadcastOpType> kBroadcastTypeMap = { static std::map<std::string, BroadcastOpType> kBroadcastTypeMap = {
{"Greater", BROADCAST_TYPE_GREATER}, {"Less", BROADCAST_TYPE_LESS}, {"Maximum", BROADCAST_TYPE_MAXIMUM}, {"Greater", BROADCAST_TYPE_GREATER}, {"Less", BROADCAST_TYPE_LESS}, {"Maximum", BROADCAST_TYPE_MAXIMUM},
{"Minimum", BROADCAST_TYPE_MINIMUM}, {"Pow", BROADCAST_TYPE_POWER}, {"RealDiv", BROADCAST_TYPE_REALDIV}, {"Minimum", BROADCAST_TYPE_MINIMUM}, {"Pow", BROADCAST_TYPE_POWER}, {"RealDiv", BROADCAST_TYPE_REALDIV},
{"Mul", BROADCAST_TYPE_MUL}, {"Sub", BROADCAST_TYPE_SUB},
{"Mul", BROADCAST_TYPE_MUL}, {"Sub", BROADCAST_TYPE_SUB}, {"TensorAdd", BROADCAST_TYPE_ADD},
}; };


auto iter = kBroadcastTypeMap.find(kernel_name); auto iter = kBroadcastTypeMap.find(kernel_name);


+ 0
- 33
mindspore/ccsrc/kernel/gpu/math/tensoradd_gpu_kernel.cc View File

@@ -1,33 +0,0 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "kernel/gpu/math/tensoradd_gpu_kernel.h"

namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(
TensorAdd,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
TensorAddGpuFwdKernel, float)
MS_REG_GPU_KERNEL_ONE(
TensorAdd,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
TensorAddGpuFwdKernel, half)
MS_REG_GPU_KERNEL_ONE(
TensorAdd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
TensorAddGpuFwdKernel, int)
} // namespace kernel
} // namespace mindspore

+ 0
- 171
mindspore/ccsrc/kernel/gpu/math/tensoradd_gpu_kernel.h View File

@@ -1,171 +0,0 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_KERNEL_GPU_TENSORADD_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_TENSORADD_GPU_KERNEL_H_

#include <vector>
#include "kernel/gpu/gpu_kernel.h"
#include "kernel/gpu/gpu_kernel_factory.h"
#include "kernel/gpu/kernel_constants.h"
namespace mindspore {
namespace kernel {
template <typename T>
class TensorAddGpuFwdKernel : public GpuKernel {
public:
TensorAddGpuFwdKernel()
: cudnn_handle_(nullptr),
inputA_descriptor_(nullptr),
inputB_descriptor_(nullptr),
opTensor_descriptor_(nullptr),
cudnn_data_type_(CUDNN_DATA_FLOAT),
input_size_(0),
output_size_(0),
workspace_size_(0),
is_null_input_(false) {}
~TensorAddGpuFwdKernel() override { DestroyResource(); }

const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *) {
if (is_null_input_) {
return true;
}
T *input_addr = GetDeviceAddress<T>(inputs, 0);
T *input_addr2 = GetDeviceAddress<T>(inputs, 1);
T *output_addr = GetDeviceAddress<T>(outputs, 0);
const float alpha = 1;
const float beta = 0;
// A + B = C. [ C = op(alpha1[0] * A, alpha2[0] * B) + beta[0] * C ]
// InputA must match the corresponding dimension of the destination tensor outC, and each dimension of the inputB
// must match the corresponding dimension of outC or must be equal to 1.
if (inputs[0]->size > inputs[1]->size) {
CHECK_CUDNN_RET_WITH_EXCEPT(
cudnnOpTensor(cudnn_handle_, opTensor_descriptor_, &alpha, inputA_descriptor_, input_addr, &alpha,
inputB_descriptor_, input_addr2, &beta, inputA_descriptor_, output_addr),
"cudnnOpTensor Add failed");
} else {
CHECK_CUDNN_RET_WITH_EXCEPT(
cudnnOpTensor(cudnn_handle_, opTensor_descriptor_, &alpha, inputB_descriptor_, input_addr2, &alpha,
inputA_descriptor_, input_addr, &beta, inputB_descriptor_, output_addr),
"cudnnOpTensor Add failed");
}
return true;
}
bool Init(const CNodePtr &kernel_node) {
InitResource();
cudnn_data_type_ = kCudnnDtypeMap[TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))];
if (cudnn_data_type_ == CUDNN_DATA_INT32) {
cudnn_data_type_ = CUDNN_DATA_FLOAT;
}
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 2) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but cudnnAddTensor needs 2 inputs.";
return false;
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) {
MS_LOG(ERROR) << "Output number is " << output_num << ", but cudnnAddTensor needs 1 output.";
return false;
}
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
auto input_shapeB = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
is_null_input_ = CHECK_NULL_INPUT(input_shape) || CHECK_NULL_INPUT(input_shapeB);
if (is_null_input_) {
MS_LOG(WARNING) << "TensorAddGpuFwdKernel input is null";
InitSizeLists();
return true;
}
std::vector<int> shapeA;
std::vector<int> shapeB;
std::vector<int> shapeOut;
ShapeNdTo4d(input_shape, &shapeA);
ShapeNdTo4d(input_shapeB, &shapeB);
ShapeNdTo4d(output_shape, &shapeOut);
CheckBroadcast4TensorOp(shapeA, shapeB, shapeOut);
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(inputA_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_,
shapeA[0], shapeA[1], shapeA[2], shapeA[3]),
"cudnnSetTensor4dDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(inputB_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_,
shapeB[0], shapeB[1], shapeB[2], shapeB[3]),
"cudnnSetTensor4dDescriptor failed");

CHECK_CUDNN_RET_WITH_EXCEPT(
cudnnSetOpTensorDescriptor(opTensor_descriptor_, CUDNN_OP_TENSOR_ADD, CUDNN_DATA_FLOAT, CUDNN_NOT_PROPAGATE_NAN),
"cudnnSetOpTensorDescriptor failed");

InitSizeLists();
return true;
}

protected:
void InitResource() {
cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&inputA_descriptor_), "cudnnCreateTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&inputB_descriptor_), "cudnnCreateTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateOpTensorDescriptor(&opTensor_descriptor_),
"cudnnCreateOpTensorDescriptor failed");
}
void InitSizeLists() {
if (!is_null_input_) {
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(inputA_descriptor_, &input_size_),
"cudnnGetTensorSizeInBytes failed");
input_size_list_.push_back(input_size_);
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(inputB_descriptor_, &output_size_),
"cudnnGetTensorSizeInBytes failed");
}
input_size_list_.push_back(output_size_);

if (output_size_ > input_size_) {
output_size_list_.push_back(output_size_);
} else {
output_size_list_.push_back(input_size_);
}
workspace_size_list_.push_back(workspace_size_);

return;
}

private:
void DestroyResource() noexcept {
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(inputA_descriptor_), "cudnnDestroyTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(inputB_descriptor_), "cudnnDestroyTensorDescriptor failed");
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyOpTensorDescriptor(opTensor_descriptor_),
"cudnnDestroyOpTensorDescriptor failed");
}
cudnnHandle_t cudnn_handle_;
cudnnTensorDescriptor_t inputA_descriptor_;
cudnnTensorDescriptor_t inputB_descriptor_;
cudnnOpTensorDescriptor_t opTensor_descriptor_;
cudnnDataType_t cudnn_data_type_;

std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;

size_t input_size_;
size_t output_size_;
size_t workspace_size_;
bool is_null_input_;
};
} // namespace kernel
} // namespace mindspore

#endif // MINDSPORE_CCSRC_KERNEL_GPU_TENSORADD_GPU_KERNEL_H_

+ 13
- 2
model_zoo/bert/finetune.py View File

@@ -18,6 +18,7 @@ Bert finetune script.
''' '''


import os import os
import argparse
from src.utils import BertFinetuneCell, BertCLS, BertNER, BertSquad, BertSquadCell from src.utils import BertFinetuneCell, BertCLS, BertNER, BertSquad, BertSquadCell
from src.finetune_config import cfg, bert_net_cfg, tag_to_index from src.finetune_config import cfg, bert_net_cfg, tag_to_index
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
@@ -98,8 +99,14 @@ def test_train():
''' '''
finetune function finetune function
''' '''
devid = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=devid)
target = args_opt.device_target
if target == "Ascend":
devid = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=devid)
elif target == "GPU":
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
else:
raise Exception("Target error, GPU or Ascend is supported.")
#BertCLSTrain for classification #BertCLSTrain for classification
#BertNERTrain for sequence labeling #BertNERTrain for sequence labeling
if cfg.task == 'NER': if cfg.task == 'NER':
@@ -151,5 +158,9 @@ def test_train():
model = Model(netwithgrads) model = Model(netwithgrads)
model.train(cfg.epoch_num, dataset, callbacks=[LossCallBack(), ckpoint_cb]) model.train(cfg.epoch_num, dataset, callbacks=[LossCallBack(), ckpoint_cb])



parser = argparse.ArgumentParser(description='Bert finetune')
parser.add_argument('--device_target', type=str, default='Ascend', help='Device target')
args_opt = parser.parse_args()
if __name__ == "__main__": if __name__ == "__main__":
test_train() test_train()

+ 33
- 11
model_zoo/bert/src/utils.py View File

@@ -42,6 +42,13 @@ reciprocal = P.Reciprocal()
def tensor_grad_scale(scale, grad): def tensor_grad_scale(scale, grad):
return grad * reciprocal(scale) return grad * reciprocal(scale)


_grad_overflow = C.MultitypeFuncGraph("_grad_overflow")
grad_overflow = P.FloatStatus()

@_grad_overflow.register("Tensor")
def _tensor_grad_overflow(grad):
return grad_overflow(grad)

class BertFinetuneCell(nn.Cell): class BertFinetuneCell(nn.Cell):
""" """
Especifically defined for finetuning where only four inputs tensor are needed. Especifically defined for finetuning where only four inputs tensor are needed.
@@ -67,9 +74,16 @@ class BertFinetuneCell(nn.Cell):
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
self.cast = P.Cast() self.cast = P.Cast()
self.alloc_status = P.NPUAllocFloatStatus()
self.get_status = P.NPUGetFloatStatus()
self.clear_before_grad = P.NPUClearFloatStatus()
self.gpu_target = False
if context.get_context("device_target") == "GPU":
self.gpu_target = True
self.float_status = P.FloatStatus()
self.addn = P.AddN()
self.reshape = P.Reshape()
else:
self.alloc_status = P.NPUAllocFloatStatus()
self.get_status = P.NPUGetFloatStatus()
self.clear_before_grad = P.NPUClearFloatStatus()
self.reduce_sum = P.ReduceSum(keep_dims=False) self.reduce_sum = P.ReduceSum(keep_dims=False)
self.depend_parameter_use = P.ControlDepend(depend_mode=1) self.depend_parameter_use = P.ControlDepend(depend_mode=1)
self.base = Tensor(1, mstype.float32) self.base = Tensor(1, mstype.float32)
@@ -90,7 +104,7 @@ class BertFinetuneCell(nn.Cell):




weights = self.weights weights = self.weights
init = self.alloc_status()
init = False
loss = self.network(input_ids, loss = self.network(input_ids,
input_mask, input_mask,
token_type_id, token_type_id,
@@ -99,28 +113,36 @@ class BertFinetuneCell(nn.Cell):
scaling_sens = self.loss_scale scaling_sens = self.loss_scale
else: else:
scaling_sens = sens scaling_sens = sens

if not self.gpu_target:
init = self.alloc_status()
clear_before_grad = self.clear_before_grad(init)
F.control_depend(loss, init)
self.depend_parameter_use(clear_before_grad, scaling_sens)
grads = self.grad(self.network, weights)(input_ids, grads = self.grad(self.network, weights)(input_ids,
input_mask, input_mask,
token_type_id, token_type_id,
label_ids, label_ids,
self.cast(scaling_sens, self.cast(scaling_sens,
mstype.float32)) mstype.float32))
clear_before_grad = self.clear_before_grad(init)
F.control_depend(loss, init)
self.depend_parameter_use(clear_before_grad, scaling_sens)
grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads) grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads)
grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads)
if self.reducer_flag: if self.reducer_flag:
grads = self.grad_reducer(grads) grads = self.grad_reducer(grads)
flag = self.get_status(init)
flag_sum = self.reduce_sum(init, (0,))
if not self.gpu_target:
flag = self.get_status(init)
flag_sum = self.reduce_sum(init, (0,))
F.control_depend(grads, flag)
F.control_depend(flag, flag_sum)
else:
flag_sum = self.hyper_map(F.partial(_grad_overflow), grads)
flag_sum = self.addn(flag_sum)
flag_sum = self.reshape(flag_sum, (()))
if self.is_distributed: if self.is_distributed:
flag_reduce = self.allreduce(flag_sum) flag_reduce = self.allreduce(flag_sum)
cond = self.less_equal(self.base, flag_reduce) cond = self.less_equal(self.base, flag_reduce)
else: else:
cond = self.less_equal(self.base, flag_sum) cond = self.less_equal(self.base, flag_sum)
F.control_depend(grads, flag)
F.control_depend(flag, flag_sum)
overflow = cond overflow = cond
if sens is None: if sens is None:
overflow = self.loss_scaling_manager(self.loss_scale, cond) overflow = self.loss_scaling_manager(self.loss_scale, cond)


Loading…
Cancel
Save