Merge pull request !1915 from VectorSL/bert-finetunetags/v0.5.0-beta
| @@ -64,6 +64,11 @@ struct SubFunc { | |||
| __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return (lhs - rhs); } | |||
| }; | |||
| template <typename T, typename S> | |||
| struct AddFunc { | |||
| __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return (lhs + rhs); } | |||
| }; | |||
| template <> | |||
| struct PowerFunc<half, bool> { | |||
| // invalid branch | |||
| @@ -118,6 +123,9 @@ __global__ void BroadcastKernel(const int l0, const int l1, const int l2, const | |||
| case BROADCAST_TYPE_SUB: | |||
| return BroadcastOperator<T, S, SubFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, | |||
| output); | |||
| case BROADCAST_TYPE_ADD: | |||
| return BroadcastOperator<T, S, AddFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1, | |||
| output); | |||
| } | |||
| } | |||
| @@ -157,6 +165,8 @@ __global__ void NoBroadcastKernel(const int nums, enum BroadcastOpType op, const | |||
| return NoBroadcastOperator<T, S, MulFunc<T, S>>(nums, input0, input1, output); | |||
| case BROADCAST_TYPE_SUB: | |||
| return NoBroadcastOperator<T, S, SubFunc<T, S>>(nums, input0, input1, output); | |||
| case BROADCAST_TYPE_ADD: | |||
| return NoBroadcastOperator<T, S, AddFunc<T, S>>(nums, input0, input1, output); | |||
| } | |||
| } | |||
| @@ -182,7 +192,10 @@ template void Broadcast(const int &l0, const int &l1, const int &l2, const int & | |||
| const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, | |||
| enum BroadcastOpType op, const half *input0, const half *input1, half *output, | |||
| cudaStream_t stream); | |||
| template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, | |||
| const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, | |||
| enum BroadcastOpType op, const int *input0, const int *input1, int *output, | |||
| cudaStream_t stream); | |||
| template void NoBroadcast(const int &nums, enum BroadcastOpType op, const float *input0, const float *input1, | |||
| bool *output, cudaStream_t stream); | |||
| template void NoBroadcast(const int &nums, enum BroadcastOpType op, const float *input0, const float *input1, | |||
| @@ -191,3 +204,5 @@ template void NoBroadcast(const int &nums, enum BroadcastOpType op, const half * | |||
| bool *output, cudaStream_t stream); | |||
| template void NoBroadcast(const int &nums, enum BroadcastOpType op, const half *input0, const half *input1, | |||
| half *output, cudaStream_t stream); | |||
| template void NoBroadcast(const int &nums, enum BroadcastOpType op, const int *input0, const int *input1, | |||
| int *output, cudaStream_t stream); | |||
| @@ -28,6 +28,7 @@ enum BroadcastOpType { | |||
| BROADCAST_TYPE_REALDIV = 5, | |||
| BROADCAST_TYPE_MUL = 6, | |||
| BROADCAST_TYPE_SUB = 7, | |||
| BROADCAST_TYPE_ADD = 8, | |||
| BROADCAST_TYPE_INVALID = 0xffffffff, | |||
| }; | |||
| @@ -47,6 +47,10 @@ MS_REG_GPU_KERNEL_TWO( | |||
| MS_REG_GPU_KERNEL_TWO( | |||
| Sub, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| BroadcastOpGpuKernel, float, float) | |||
| MS_REG_GPU_KERNEL_TWO( | |||
| TensorAdd, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| BroadcastOpGpuKernel, float, float) | |||
| // fp16 | |||
| MS_REG_GPU_KERNEL_TWO( | |||
| @@ -77,5 +81,14 @@ MS_REG_GPU_KERNEL_TWO( | |||
| MS_REG_GPU_KERNEL_TWO( | |||
| Sub, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| BroadcastOpGpuKernel, half, half) | |||
| MS_REG_GPU_KERNEL_TWO( | |||
| TensorAdd, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| BroadcastOpGpuKernel, half, half) | |||
| // int32 | |||
| MS_REG_GPU_KERNEL_TWO( | |||
| TensorAdd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | |||
| BroadcastOpGpuKernel, int, int) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -98,7 +98,7 @@ class BroadcastOpGpuKernel : public GpuKernel { | |||
| static std::map<std::string, BroadcastOpType> kBroadcastTypeMap = { | |||
| {"Greater", BROADCAST_TYPE_GREATER}, {"Less", BROADCAST_TYPE_LESS}, {"Maximum", BROADCAST_TYPE_MAXIMUM}, | |||
| {"Minimum", BROADCAST_TYPE_MINIMUM}, {"Pow", BROADCAST_TYPE_POWER}, {"RealDiv", BROADCAST_TYPE_REALDIV}, | |||
| {"Mul", BROADCAST_TYPE_MUL}, {"Sub", BROADCAST_TYPE_SUB}, | |||
| {"Mul", BROADCAST_TYPE_MUL}, {"Sub", BROADCAST_TYPE_SUB}, {"TensorAdd", BROADCAST_TYPE_ADD}, | |||
| }; | |||
| auto iter = kBroadcastTypeMap.find(kernel_name); | |||
| @@ -1,33 +0,0 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "kernel/gpu/math/tensoradd_gpu_kernel.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| MS_REG_GPU_KERNEL_ONE( | |||
| TensorAdd, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| TensorAddGpuFwdKernel, float) | |||
| MS_REG_GPU_KERNEL_ONE( | |||
| TensorAdd, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| TensorAddGpuFwdKernel, half) | |||
| MS_REG_GPU_KERNEL_ONE( | |||
| TensorAdd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | |||
| TensorAddGpuFwdKernel, int) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -1,171 +0,0 @@ | |||
| /** | |||
| * Copyright 2019 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_KERNEL_GPU_TENSORADD_GPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_KERNEL_GPU_TENSORADD_GPU_KERNEL_H_ | |||
| #include <vector> | |||
| #include "kernel/gpu/gpu_kernel.h" | |||
| #include "kernel/gpu/gpu_kernel_factory.h" | |||
| #include "kernel/gpu/kernel_constants.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| template <typename T> | |||
| class TensorAddGpuFwdKernel : public GpuKernel { | |||
| public: | |||
| TensorAddGpuFwdKernel() | |||
| : cudnn_handle_(nullptr), | |||
| inputA_descriptor_(nullptr), | |||
| inputB_descriptor_(nullptr), | |||
| opTensor_descriptor_(nullptr), | |||
| cudnn_data_type_(CUDNN_DATA_FLOAT), | |||
| input_size_(0), | |||
| output_size_(0), | |||
| workspace_size_(0), | |||
| is_null_input_(false) {} | |||
| ~TensorAddGpuFwdKernel() override { DestroyResource(); } | |||
| const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } | |||
| const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } | |||
| const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, | |||
| const std::vector<AddressPtr> &outputs, void *) { | |||
| if (is_null_input_) { | |||
| return true; | |||
| } | |||
| T *input_addr = GetDeviceAddress<T>(inputs, 0); | |||
| T *input_addr2 = GetDeviceAddress<T>(inputs, 1); | |||
| T *output_addr = GetDeviceAddress<T>(outputs, 0); | |||
| const float alpha = 1; | |||
| const float beta = 0; | |||
| // A + B = C. [ C = op(alpha1[0] * A, alpha2[0] * B) + beta[0] * C ] | |||
| // InputA must match the corresponding dimension of the destination tensor outC, and each dimension of the inputB | |||
| // must match the corresponding dimension of outC or must be equal to 1. | |||
| if (inputs[0]->size > inputs[1]->size) { | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| cudnnOpTensor(cudnn_handle_, opTensor_descriptor_, &alpha, inputA_descriptor_, input_addr, &alpha, | |||
| inputB_descriptor_, input_addr2, &beta, inputA_descriptor_, output_addr), | |||
| "cudnnOpTensor Add failed"); | |||
| } else { | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| cudnnOpTensor(cudnn_handle_, opTensor_descriptor_, &alpha, inputB_descriptor_, input_addr2, &alpha, | |||
| inputA_descriptor_, input_addr, &beta, inputB_descriptor_, output_addr), | |||
| "cudnnOpTensor Add failed"); | |||
| } | |||
| return true; | |||
| } | |||
| bool Init(const CNodePtr &kernel_node) { | |||
| InitResource(); | |||
| cudnn_data_type_ = kCudnnDtypeMap[TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))]; | |||
| if (cudnn_data_type_ == CUDNN_DATA_INT32) { | |||
| cudnn_data_type_ = CUDNN_DATA_FLOAT; | |||
| } | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| if (input_num != 2) { | |||
| MS_LOG(ERROR) << "Input number is " << input_num << ", but cudnnAddTensor needs 2 inputs."; | |||
| return false; | |||
| } | |||
| size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); | |||
| if (output_num != 1) { | |||
| MS_LOG(ERROR) << "Output number is " << output_num << ", but cudnnAddTensor needs 1 output."; | |||
| return false; | |||
| } | |||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| auto input_shapeB = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); | |||
| auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); | |||
| is_null_input_ = CHECK_NULL_INPUT(input_shape) || CHECK_NULL_INPUT(input_shapeB); | |||
| if (is_null_input_) { | |||
| MS_LOG(WARNING) << "TensorAddGpuFwdKernel input is null"; | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| std::vector<int> shapeA; | |||
| std::vector<int> shapeB; | |||
| std::vector<int> shapeOut; | |||
| ShapeNdTo4d(input_shape, &shapeA); | |||
| ShapeNdTo4d(input_shapeB, &shapeB); | |||
| ShapeNdTo4d(output_shape, &shapeOut); | |||
| CheckBroadcast4TensorOp(shapeA, shapeB, shapeOut); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(inputA_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, | |||
| shapeA[0], shapeA[1], shapeA[2], shapeA[3]), | |||
| "cudnnSetTensor4dDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(inputB_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_, | |||
| shapeB[0], shapeB[1], shapeB[2], shapeB[3]), | |||
| "cudnnSetTensor4dDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT( | |||
| cudnnSetOpTensorDescriptor(opTensor_descriptor_, CUDNN_OP_TENSOR_ADD, CUDNN_DATA_FLOAT, CUDNN_NOT_PROPAGATE_NAN), | |||
| "cudnnSetOpTensorDescriptor failed"); | |||
| InitSizeLists(); | |||
| return true; | |||
| } | |||
| protected: | |||
| void InitResource() { | |||
| cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&inputA_descriptor_), "cudnnCreateTensorDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&inputB_descriptor_), "cudnnCreateTensorDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateOpTensorDescriptor(&opTensor_descriptor_), | |||
| "cudnnCreateOpTensorDescriptor failed"); | |||
| } | |||
| void InitSizeLists() { | |||
| if (!is_null_input_) { | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(inputA_descriptor_, &input_size_), | |||
| "cudnnGetTensorSizeInBytes failed"); | |||
| input_size_list_.push_back(input_size_); | |||
| CHECK_CUDNN_RET_WITH_EXCEPT(cudnnGetTensorSizeInBytes(inputB_descriptor_, &output_size_), | |||
| "cudnnGetTensorSizeInBytes failed"); | |||
| } | |||
| input_size_list_.push_back(output_size_); | |||
| if (output_size_ > input_size_) { | |||
| output_size_list_.push_back(output_size_); | |||
| } else { | |||
| output_size_list_.push_back(input_size_); | |||
| } | |||
| workspace_size_list_.push_back(workspace_size_); | |||
| return; | |||
| } | |||
| private: | |||
| void DestroyResource() noexcept { | |||
| CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(inputA_descriptor_), "cudnnDestroyTensorDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(inputB_descriptor_), "cudnnDestroyTensorDescriptor failed"); | |||
| CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyOpTensorDescriptor(opTensor_descriptor_), | |||
| "cudnnDestroyOpTensorDescriptor failed"); | |||
| } | |||
| cudnnHandle_t cudnn_handle_; | |||
| cudnnTensorDescriptor_t inputA_descriptor_; | |||
| cudnnTensorDescriptor_t inputB_descriptor_; | |||
| cudnnOpTensorDescriptor_t opTensor_descriptor_; | |||
| cudnnDataType_t cudnn_data_type_; | |||
| std::vector<size_t> input_size_list_; | |||
| std::vector<size_t> output_size_list_; | |||
| std::vector<size_t> workspace_size_list_; | |||
| size_t input_size_; | |||
| size_t output_size_; | |||
| size_t workspace_size_; | |||
| bool is_null_input_; | |||
| }; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_KERNEL_GPU_TENSORADD_GPU_KERNEL_H_ | |||
| @@ -18,6 +18,7 @@ Bert finetune script. | |||
| ''' | |||
| import os | |||
| import argparse | |||
| from src.utils import BertFinetuneCell, BertCLS, BertNER, BertSquad, BertSquadCell | |||
| from src.finetune_config import cfg, bert_net_cfg, tag_to_index | |||
| import mindspore.common.dtype as mstype | |||
| @@ -98,8 +99,14 @@ def test_train(): | |||
| ''' | |||
| finetune function | |||
| ''' | |||
| devid = int(os.getenv('DEVICE_ID')) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=devid) | |||
| target = args_opt.device_target | |||
| if target == "Ascend": | |||
| devid = int(os.getenv('DEVICE_ID')) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=devid) | |||
| elif target == "GPU": | |||
| context.set_context(mode=context.GRAPH_MODE, device_target="GPU") | |||
| else: | |||
| raise Exception("Target error, GPU or Ascend is supported.") | |||
| #BertCLSTrain for classification | |||
| #BertNERTrain for sequence labeling | |||
| if cfg.task == 'NER': | |||
| @@ -151,5 +158,9 @@ def test_train(): | |||
| model = Model(netwithgrads) | |||
| model.train(cfg.epoch_num, dataset, callbacks=[LossCallBack(), ckpoint_cb]) | |||
| parser = argparse.ArgumentParser(description='Bert finetune') | |||
| parser.add_argument('--device_target', type=str, default='Ascend', help='Device target') | |||
| args_opt = parser.parse_args() | |||
| if __name__ == "__main__": | |||
| test_train() | |||
| @@ -42,6 +42,13 @@ reciprocal = P.Reciprocal() | |||
| def tensor_grad_scale(scale, grad): | |||
| return grad * reciprocal(scale) | |||
| _grad_overflow = C.MultitypeFuncGraph("_grad_overflow") | |||
| grad_overflow = P.FloatStatus() | |||
| @_grad_overflow.register("Tensor") | |||
| def _tensor_grad_overflow(grad): | |||
| return grad_overflow(grad) | |||
| class BertFinetuneCell(nn.Cell): | |||
| """ | |||
| Especifically defined for finetuning where only four inputs tensor are needed. | |||
| @@ -67,9 +74,16 @@ class BertFinetuneCell(nn.Cell): | |||
| self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) | |||
| self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE) | |||
| self.cast = P.Cast() | |||
| self.alloc_status = P.NPUAllocFloatStatus() | |||
| self.get_status = P.NPUGetFloatStatus() | |||
| self.clear_before_grad = P.NPUClearFloatStatus() | |||
| self.gpu_target = False | |||
| if context.get_context("device_target") == "GPU": | |||
| self.gpu_target = True | |||
| self.float_status = P.FloatStatus() | |||
| self.addn = P.AddN() | |||
| self.reshape = P.Reshape() | |||
| else: | |||
| self.alloc_status = P.NPUAllocFloatStatus() | |||
| self.get_status = P.NPUGetFloatStatus() | |||
| self.clear_before_grad = P.NPUClearFloatStatus() | |||
| self.reduce_sum = P.ReduceSum(keep_dims=False) | |||
| self.depend_parameter_use = P.ControlDepend(depend_mode=1) | |||
| self.base = Tensor(1, mstype.float32) | |||
| @@ -90,7 +104,7 @@ class BertFinetuneCell(nn.Cell): | |||
| weights = self.weights | |||
| init = self.alloc_status() | |||
| init = False | |||
| loss = self.network(input_ids, | |||
| input_mask, | |||
| token_type_id, | |||
| @@ -99,28 +113,36 @@ class BertFinetuneCell(nn.Cell): | |||
| scaling_sens = self.loss_scale | |||
| else: | |||
| scaling_sens = sens | |||
| if not self.gpu_target: | |||
| init = self.alloc_status() | |||
| clear_before_grad = self.clear_before_grad(init) | |||
| F.control_depend(loss, init) | |||
| self.depend_parameter_use(clear_before_grad, scaling_sens) | |||
| grads = self.grad(self.network, weights)(input_ids, | |||
| input_mask, | |||
| token_type_id, | |||
| label_ids, | |||
| self.cast(scaling_sens, | |||
| mstype.float32)) | |||
| clear_before_grad = self.clear_before_grad(init) | |||
| F.control_depend(loss, init) | |||
| self.depend_parameter_use(clear_before_grad, scaling_sens) | |||
| grads = self.hyper_map(F.partial(grad_scale, scaling_sens), grads) | |||
| grads = self.hyper_map(F.partial(clip_grad, GRADIENT_CLIP_TYPE, GRADIENT_CLIP_VALUE), grads) | |||
| if self.reducer_flag: | |||
| grads = self.grad_reducer(grads) | |||
| flag = self.get_status(init) | |||
| flag_sum = self.reduce_sum(init, (0,)) | |||
| if not self.gpu_target: | |||
| flag = self.get_status(init) | |||
| flag_sum = self.reduce_sum(init, (0,)) | |||
| F.control_depend(grads, flag) | |||
| F.control_depend(flag, flag_sum) | |||
| else: | |||
| flag_sum = self.hyper_map(F.partial(_grad_overflow), grads) | |||
| flag_sum = self.addn(flag_sum) | |||
| flag_sum = self.reshape(flag_sum, (())) | |||
| if self.is_distributed: | |||
| flag_reduce = self.allreduce(flag_sum) | |||
| cond = self.less_equal(self.base, flag_reduce) | |||
| else: | |||
| cond = self.less_equal(self.base, flag_sum) | |||
| F.control_depend(grads, flag) | |||
| F.control_depend(flag, flag_sum) | |||
| overflow = cond | |||
| if sens is None: | |||
| overflow = self.loss_scaling_manager(self.loss_scale, cond) | |||