Browse Source

remove CTCLossV2

tags/v1.1.0
baihuawei 5 years ago
parent
commit
fea928e976
11 changed files with 31 additions and 383 deletions
  1. +0
    -31
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/ctclossv2_gpu_kernel.cc
  2. +0
    -192
      mindspore/ccsrc/backend/kernel_compiler/gpu/nn/ctclossv2_gpu_kernel.h
  3. +0
    -13
      mindspore/ops/_grad/grad_nn_ops.py
  4. +1
    -1
      mindspore/ops/operations/__init__.py
  5. +0
    -54
      mindspore/ops/operations/nn_ops.py
  6. +4
    -5
      model_zoo/official/cv/warpctc/eval.py
  7. +3
    -9
      model_zoo/official/cv/warpctc/src/dataset.py
  8. +0
    -22
      model_zoo/official/cv/warpctc/src/loss.py
  9. +1
    -1
      model_zoo/official/cv/warpctc/src/metric.py
  10. +6
    -8
      model_zoo/official/cv/warpctc/train.py
  11. +16
    -47
      tests/st/ops/gpu/test_ctcloss_op.py

+ 0
- 31
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/ctclossv2_gpu_kernel.cc View File

@@ -1,31 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "backend/kernel_compiler/gpu/nn/ctclossv2_gpu_kernel.h"

namespace mindspore {
namespace kernel {
// Registers CtcLossV2GpuKernel<float> as the GPU kernel for the CTCLossV2 op.
// Declared signature: one float32 input followed by three int32 inputs, and
// two float32 outputs (the kernel writes outputs 0 and 1 as costs and grads).
MS_REG_GPU_KERNEL_ONE(CTCLossV2,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
CtcLossV2GpuKernel, float)
} // namespace kernel
} // namespace mindspore

+ 0
- 192
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/ctclossv2_gpu_kernel.h View File

@@ -1,192 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CTCLOSS_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CTCLOSS_GPU_KERNEL_H_

#include <cuda_runtime_api.h>
#include <vector>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "runtime/device/gpu/gpu_memory_allocator.h"

namespace mindspore {
namespace kernel {
/**
 * GPU kernel that evaluates the CTC loss and its gradient with cuDNN's cudnnCTCLoss.
 *
 * Inputs (device memory):  0: probs (float, 3-D), 1: labels (int32, 1-D or 2-D),
 *                          2: input_lengths (int32), 3: label_lengths (int32).
 * Outputs (device memory): 0: costs (float, one value per probs_dims_[1] entry),
 *                          1: grads (float, same element count as probs).
 *
 * The cuDNN entry points used here take labels and lengths as host pointers, so every
 * Launch copies those three inputs to pinned host memory and blocks on the stream.
 */
template <typename T>
class CtcLossV2GpuKernel : public GpuKernel {
 public:
  CtcLossV2GpuKernel()
      : cudnn_handle_(nullptr),
        probs_desc_(nullptr),
        ctcloss_desc_(nullptr),
        label_size_(0),
        input_lengths_size_(0),
        label_lengths_size_(0) {}
  ~CtcLossV2GpuKernel() override { DestroyResource(); }

  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

  /**
   * Runs the loss/gradient computation for one step.
   *
   * @param inputs     device buffers in the order documented on the class.
   * @param outputs    device buffers for costs (0) and grads (1).
   * @param stream_ptr CUDA stream used for the device-to-host copies and final sync.
   * @return true on success; failures are reported by exception from the CHECK macros.
   */
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    float *probs = GetDeviceAddress<float>(inputs, 0);
    float *costs = GetDeviceAddress<float>(outputs, 0);
    float *grads = GetDeviceAddress<float>(outputs, 1);
    cudaStream_t stream = reinterpret_cast<cudaStream_t>(stream_ptr);

    // All per-launch buffers start out null so the cleanup path can release whatever
    // subset was actually allocated when a failure interrupts the sequence below.
    int *labels_host = nullptr;
    int *no_blank_labels_host = nullptr;
    void *input_lengths_host = nullptr;
    void *label_lengths_host = nullptr;
    void *workspace = nullptr;

    try {
      // Copy labels/input_lengths/label_lengths to host as cudnn7.x.x requires.
      AllocHostMem(&labels_host, &no_blank_labels_host, &input_lengths_host, &label_lengths_host, inputs);
      CopyToHostSync(labels_host, no_blank_labels_host, input_lengths_host, label_lengths_host, inputs, stream);

      int *label_lengths = static_cast<int *>(label_lengths_host);
      int *input_lengths = static_cast<int *>(input_lengths_host);

      size_t workspace_size = 0;
      CHECK_CUDNN_RET_WITH_EXCEPT(
        cudnnGetCTCLossWorkspaceSize(cudnn_handle_, probs_desc_, probs_desc_, no_blank_labels_host, label_lengths,
                                     input_lengths, CUDNN_CTC_LOSS_ALGO_DETERMINISTIC, ctcloss_desc_,
                                     &workspace_size),
        "cudnnGetCTCLossWorkspaceSize failed.");
      workspace = device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(workspace_size);
      if (workspace == nullptr) {
        MS_LOG(EXCEPTION) << "Failed to alloc workspace, size: " << workspace_size;
      }

      CHECK_CUDNN_RET_WITH_EXCEPT(
        cudnnCTCLoss(cudnn_handle_, probs_desc_, probs, no_blank_labels_host, label_lengths, input_lengths, costs,
                     probs_desc_, grads, CUDNN_CTC_LOSS_ALGO_DETERMINISTIC, ctcloss_desc_, workspace,
                     workspace_size),
        "cudnnCtcLoss failed.");
      CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(stream), "cudaStreamSynchronize failed.");
    } catch (...) {
      // Release what was allocated, then surface the original failure. A secondary
      // error raised while cleaning up is dropped so it cannot mask the root cause.
      try {
        ReleaseLaunchBuffers(workspace, labels_host, no_blank_labels_host, input_lengths_host, label_lengths_host);
      } catch (...) {
      }
      throw;
    }

    ReleaseLaunchBuffers(workspace, labels_host, no_blank_labels_host, input_lengths_host, label_lengths_host);
    return true;
  }

  /**
   * Reads the inferred input shapes of kernel_node, validates their ranks, records the
   * byte size of every input/output and configures the cuDNN descriptors.
   *
   * @param kernel_node graph node this kernel instance is bound to.
   * @return true on success; unsupported shapes raise an exception.
   */
  bool Init(const CNodePtr &kernel_node) override {
    InitResource();
    auto probs_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
    if (probs_shape.size() != 3) {
      MS_LOG(EXCEPTION) << "probs dims: " << probs_shape.size() << " not support.";
    }
    for (size_t i = 0; i < 3; ++i) {
      // cuDNN descriptors take int dimensions; the inferred shape is unsigned.
      probs_dims_[i] = static_cast<int>(probs_shape[i]);
    }

    auto labels_dims = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
    if (labels_dims.size() != 1 && labels_dims.size() != 2) {
      MS_LOG(EXCEPTION) << "labels dims: " << labels_dims.size() << " not support.";
    }
    label_size_ = sizeof(int);
    for (auto dim : labels_dims) {
      label_size_ *= dim;
    }

    auto input_length_dims = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2);
    if (input_length_dims.empty()) {
      // Guard the [0] access below against a scalar/unknown shape.
      MS_LOG(EXCEPTION) << "input_lengths dims: " << input_length_dims.size() << " not support.";
    }
    input_lengths_size_ = input_length_dims[0] * sizeof(int);

    auto label_length_dims = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 3);
    if (label_length_dims.empty()) {
      MS_LOG(EXCEPTION) << "label_lengths dims: " << label_length_dims.size() << " not support.";
    }
    label_lengths_size_ = label_length_dims[0] * sizeof(int);

    CHECK_CUDNN_RET_WITH_EXCEPT(
      cudnnSetTensorNdDescriptorEx(probs_desc_, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 3, probs_dims_),
      "cudnnSetTensorNdDescriptorEx failed.");
    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetCTCLossDescriptorEx(ctcloss_desc_, CUDNN_DATA_FLOAT,
                                                            CUDNN_LOSS_NORMALIZATION_SOFTMAX, CUDNN_PROPAGATE_NAN),
                                "cudnnSetCTCLossDescriptorEx failed.");
    InitSizeLists();
    return true;
  }

 protected:
  // Fetches the shared cuDNN handle and creates the descriptors owned by this kernel.
  void InitResource() override {
    cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle();
    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateTensorDescriptor(&probs_desc_), "cudnnCreateTensorDescriptor failed.");
    CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreateCTCLossDescriptor(&ctcloss_desc_), "cudnnCreateCTCLossDescriptor failed.");
  }

  // Publishes the byte sizes of the four inputs and two outputs.
  void InitSizeLists() override {
    // Widen each dimension before multiplying so the element count is computed in
    // size_t instead of overflowing in int for large tensors.
    const size_t probs_bytes = static_cast<size_t>(probs_dims_[0]) * static_cast<size_t>(probs_dims_[1]) *
                               static_cast<size_t>(probs_dims_[2]) * sizeof(float);
    input_size_list_.push_back(probs_bytes);
    input_size_list_.push_back(label_size_);
    input_size_list_.push_back(input_lengths_size_);
    input_size_list_.push_back(label_lengths_size_);

    // costs holds one float per entry of the second probs dimension; grads mirrors probs.
    output_size_list_.push_back(static_cast<size_t>(probs_dims_[1]) * sizeof(float));
    output_size_list_.push_back(probs_bytes);
  }

 private:
  // Destroys the descriptors created in InitResource; failures are only logged.
  void DestroyResource() noexcept {
    CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyCTCLossDescriptor(ctcloss_desc_), "cudnnDestroyCTCLossDescriptor failed.");
    CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroyTensorDescriptor(probs_desc_), "cudnnDestroyTensorDescriptor failed.");
  }

  // Allocates the pinned host staging buffers, sized from the matching device inputs.
  void AllocHostMem(int **labels_host, int **no_blank_labels_host, void **input_lengths_host, void **label_lengths_host,
                    const std::vector<AddressPtr> &inputs) {
    CHECK_CUDA_RET_WITH_EXCEPT(cudaMallocHost(labels_host, inputs[1]->size), "cudaMallocHost failed.");
    CHECK_CUDA_RET_WITH_EXCEPT(cudaMallocHost(no_blank_labels_host, inputs[1]->size), "cudaMallocHost failed.");
    CHECK_CUDA_RET_WITH_EXCEPT(cudaMallocHost(input_lengths_host, inputs[2]->size), "cudaMallocHost failed.");
    CHECK_CUDA_RET_WITH_EXCEPT(cudaMallocHost(label_lengths_host, inputs[3]->size), "cudaMallocHost failed.");
  }

  // Frees the pinned host staging buffers; null pointers are passed straight to cudaFreeHost.
  void FreeHostMem(int *labels_host, int *no_blank_labels_host, void *input_lengths_host, void *label_lengths_host) {
    CHECK_CUDA_RET_WITH_EXCEPT(cudaFreeHost(label_lengths_host), "cudaFreeHost failed.");
    CHECK_CUDA_RET_WITH_EXCEPT(cudaFreeHost(input_lengths_host), "cudaFreeHost failed.");
    CHECK_CUDA_RET_WITH_EXCEPT(cudaFreeHost(labels_host), "cudaFreeHost failed.");
    CHECK_CUDA_RET_WITH_EXCEPT(cudaFreeHost(no_blank_labels_host), "cudaFreeHost failed.");
  }

  // Returns the device workspace (when one was obtained) and the host staging buffers.
  void ReleaseLaunchBuffers(void *workspace, int *labels_host, int *no_blank_labels_host, void *input_lengths_host,
                            void *label_lengths_host) {
    if (workspace != nullptr) {
      device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(workspace);
    }
    FreeHostMem(labels_host, no_blank_labels_host, input_lengths_host, label_lengths_host);
  }

  // Copies labels, input lengths and label lengths from device to host, waits for the
  // copies to finish, then compacts the labels by dropping every entry equal to 0.
  void CopyToHostSync(int *labels_host, int *no_blank_labels_host, void *input_lengths_host, void *label_lengths_host,
                      const std::vector<AddressPtr> &inputs, cudaStream_t stream) {
    CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(stream), "cudaStreamSynchronize failed.");
    CHECK_CUDA_RET_WITH_EXCEPT(
      cudaMemcpyAsync(labels_host, inputs[1]->addr, inputs[1]->size, cudaMemcpyDeviceToHost, stream),
      "cudaMemcpyAsync failed.");
    CHECK_CUDA_RET_WITH_EXCEPT(
      cudaMemcpyAsync(input_lengths_host, inputs[2]->addr, inputs[2]->size, cudaMemcpyDeviceToHost, stream),
      "cudaMemcpyAsync failed.");
    CHECK_CUDA_RET_WITH_EXCEPT(
      cudaMemcpyAsync(label_lengths_host, inputs[3]->addr, inputs[3]->size, cudaMemcpyDeviceToHost, stream),
      "cudaMemcpyAsync failed.");
    CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(stream), "cudaStreamSynchronize failed.");

    // remove blank element: only non-zero labels are forwarded, packed at the front.
    const size_t label_count = inputs[1]->size / sizeof(int);
    size_t kept = 0;
    for (size_t i = 0; i < label_count; ++i) {
      if (labels_host[i] != 0) {
        no_blank_labels_host[kept++] = labels_host[i];
      }
    }
  }

  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;

  cudnnHandle_t cudnn_handle_;
  cudnnTensorDescriptor_t probs_desc_;
  cudnnCTCLossDescriptor_t ctcloss_desc_;
  int probs_dims_[3] = {0};
  // Byte sizes of the labels / input_lengths / label_lengths inputs.
  size_t label_size_;
  size_t input_lengths_size_;
  size_t label_lengths_size_;
};
} // namespace kernel
} // namespace mindspore

#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_CTCLOSS_GPU_KERNEL_H_

+ 0
- 13
mindspore/ops/_grad/grad_nn_ops.py View File

@@ -987,19 +987,6 @@ def get_bprop_ctc_loss(self):
return bprop return bprop




@bprop_getters.register(P.CTCLossV2)
def get_bprop_ctc_loss_v2(self):
    """Build the backward function for the `CTCLossV2` operation."""
    expand_dims = P.ExpandDims()

    def bprop(inputs, labels, input_lengths, labels_lengths, out, dout):
        # The forward pass already produced the gradient as its second output;
        # scale it by the incoming sensitivity of the loss, broadcast over the last axis.
        loss_sens = expand_dims(dout[0], -1)
        grad_inputs = out[1] * loss_sens
        # The integer inputs receive zero gradients.
        return grad_inputs, zeros_like(labels), zeros_like(input_lengths), zeros_like(labels_lengths)

    return bprop


@bprop_getters.register(P.BasicLSTMCell) @bprop_getters.register(P.BasicLSTMCell)
def get_bprop_basic_lstm_cell(self): def get_bprop_basic_lstm_cell(self):
"""Grad definition for `BasicLSTMCell` operation.""" """Grad definition for `BasicLSTMCell` operation."""


+ 1
- 1
mindspore/ops/operations/__init__.py View File

@@ -64,7 +64,7 @@ from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, Appl
DropoutDoMask, Dropout, DropoutDoMask, Dropout,
DropoutGenMask, Flatten, FusedBatchNorm, FusedBatchNormEx, BNTrainingReduce, BNTrainingUpdate, DropoutGenMask, Flatten, FusedBatchNorm, FusedBatchNormEx, BNTrainingReduce, BNTrainingUpdate,
Gelu, Elu, Gelu, Elu,
GetNext, L2Normalize, LayerNorm, L2Loss, CTCLoss, CTCLossV2, CTCGreedyDecoder,
GetNext, L2Normalize, LayerNorm, L2Loss, CTCLoss, CTCGreedyDecoder,
LogSoftmax, LogSoftmax,
MaxPool, DataFormatDimMap, MaxPool, DataFormatDimMap,
AvgPool, Conv2DBackpropInput, AvgPool, Conv2DBackpropInput,


+ 0
- 54
mindspore/ops/operations/nn_ops.py View File

@@ -5605,57 +5605,3 @@ class LRN(PrimitiveWithInfer):
def infer_shape(self, x_shape): def infer_shape(self, x_shape):
validator.check_integer("x_shape", len(x_shape), 4, Rel.EQ, self.name) validator.check_integer("x_shape", len(x_shape), 4, Rel.EQ, self.name)
return x_shape return x_shape


class CTCLossV2(PrimitiveWithInfer):
    r"""
    Calculates the CTC (Connectionist Temporal Classification) loss and the gradient.

    Note:
        - cuDNN uses label value 0 for the `blank`.

    Inputs:
        - **inputs** (Tensor) - The input Tensor must be a `3-D` tensor whose shape is
          :math:`(max\_time, batch\_size, num\_class)`. `num_class` must be `num_labels + 1` classes, `num_labels`
          indicates the number of actual labels. Blank labels are reserved. The type must be float32.
        - **labels** (Tensor) - The labels Tensor must be a `1-D` tensor whose shape is
          :math:`(\sum{label\_lengths})`
          or a `2-D` tensor whose shape is
          :math:`(batch\_size, \max{label\_lengths})`.
          The type must be int32.
        - **input_lengths** (Tensor) - A `1-D` tensor whose shape is :math:`(batch\_size,)`, holding the length
          of each input sequence in the batch. The type must be int32.
        - **label_lengths** (Tensor) - A `1-D` tensor containing the label sequence lengths, with the shape of
          :math:`(batch\_size,)`. The type must be int32. Each value in the tensor must not be greater than
          `max_time`.

    Outputs:
        - **loss** (Tensor) - A tensor containing log-probabilities, the shape is :math:`(batch\_size,)`, has the
          same type with `inputs`.
        - **gradient** (Tensor) - The gradient of `loss`, has the same type and shape with `inputs`.

    Examples:
        >>> inputs = Tensor(np.random.random((5, 2, 3)), mindspore.float32)
        >>> labels = Tensor(np.array([[1, 2], [2, 0]]), mindspore.int32)
        >>> input_lengths = Tensor(np.array([5, 5]), mindspore.int32)
        >>> label_lengths = Tensor(np.array([2, 1]), mindspore.int32)
        >>> ctc_loss = P.CTCLossV2()
        >>> output = ctc_loss(inputs, labels, input_lengths, label_lengths)
    """

    @prim_attr_register
    def __init__(self):
        pass

    def infer_dtype(self, input_dtype, labels_dtype, input_lengths_dtype, label_lengths_dtype):
        """Check that `inputs` is float32 and the other three inputs are int32; both outputs are float32."""
        validator.check_tensor_type_same({"input": input_dtype}, (mstype.float32,), self.name)
        validator.check_tensor_type_same({"labels": labels_dtype}, (mstype.int32,), self.name)
        validator.check_tensor_type_same({"input_lengths": input_lengths_dtype}, (mstype.int32,), self.name)
        # Report the argument under its real name so the error message matches the signature.
        validator.check_tensor_type_same({"label_lengths": label_lengths_dtype}, (mstype.int32,), self.name)
        return mstype.float32, mstype.float32

    def infer_shape(self, input_shape, labels_shape, input_lengths_shape, label_lengths_shape):
        """Check the input ranks and batch agreement; returns `((batch_size,), input_shape)`."""
        validator.check_integer("input shape", len(input_shape), 3, Rel.EQ, self.name)
        validator.check_number_range("labels shape", len(labels_shape), 1, 2, Rel.INC_BOTH, self.name)
        validator.check_integer("input lengths shape", len(input_lengths_shape), 1, Rel.EQ, self.name)
        validator.check_integer("label lengths shape", len(label_lengths_shape), 1, Rel.EQ, self.name)
        # Both length vectors must have one entry per batch element (axis 1 of `inputs`).
        validator.check_integer("input[1]", input_shape[1], input_lengths_shape[0], Rel.EQ, self.name)
        validator.check_integer("input[1]", input_shape[1], label_lengths_shape[0], Rel.EQ, self.name)
        return (input_shape[1],), input_shape

+ 4
- 5
model_zoo/official/cv/warpctc/eval.py View File

@@ -21,7 +21,7 @@ from mindspore.common import set_seed
from mindspore.train.model import Model from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net


from src.loss import CTCLoss, CTCLossV2
from src.loss import CTCLoss
from src.config import config as cf from src.config import config as cf
from src.dataset import create_dataset from src.dataset import create_dataset
from src.warpctc import StackedRNN, StackedRNNForGPU from src.warpctc import StackedRNN, StackedRNNForGPU
@@ -49,13 +49,12 @@ if __name__ == '__main__':
batch_size=cf.batch_size, batch_size=cf.batch_size,
device_target=args_opt.platform) device_target=args_opt.platform)
step_size = dataset.get_dataset_size() step_size = dataset.get_dataset_size()
loss = CTCLoss(max_sequence_length=cf.captcha_width,
max_label_length=max_captcha_digits,
batch_size=cf.batch_size)
if args_opt.platform == 'Ascend': if args_opt.platform == 'Ascend':
loss = CTCLoss(max_sequence_length=cf.captcha_width,
max_label_length=max_captcha_digits,
batch_size=cf.batch_size)
net = StackedRNN(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size) net = StackedRNN(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size)
else: else:
loss = CTCLossV2(max_sequence_length=cf.captcha_width, batch_size=cf.batch_size)
net = StackedRNNForGPU(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size) net = StackedRNNForGPU(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size)


# load checkpoint # load checkpoint


+ 3
- 9
model_zoo/official/cv/warpctc/src/dataset.py View File

@@ -41,7 +41,7 @@ class _CaptchaDataset:
self.img_names = [i for i in os.listdir(img_root_dir) if i.endswith('.png')] self.img_names = [i for i in os.listdir(img_root_dir) if i.endswith('.png')]
self.max_captcha_digits = max_captcha_digits self.max_captcha_digits = max_captcha_digits
self.target = device_target self.target = device_target
self.blank = 10 if self.target == 'Ascend' else 0
self.blank = 10
self.label_length = [len(os.path.splitext(n)[0].split('-')[-1]) for n in self.img_names] self.label_length = [len(os.path.splitext(n)[0].split('-')[-1]) for n in self.img_names]


def __len__(self): def __len__(self):
@@ -55,14 +55,8 @@ class _CaptchaDataset:
image = np.array(im) image = np.array(im)
label_str = os.path.splitext(img_name)[0] label_str = os.path.splitext(img_name)[0]
label_str = label_str[label_str.find('-') + 1:] label_str = label_str[label_str.find('-') + 1:]
if self.target == 'Ascend':
label = [int(i) for i in label_str]
label.extend([int(self.blank)] * (self.max_captcha_digits - len(label)))
else:
label = [int(i) + 1 for i in label_str]
length = len(label)
label.extend([int(self.blank)] * (self.max_captcha_digits - len(label)))
label.append(length)
label = [int(i) for i in label_str]
label.extend([int(self.blank)] * (self.max_captcha_digits - len(label)))
label = np.array(label) label = np.array(label)
return image, label return image, label




+ 0
- 22
model_zoo/official/cv/warpctc/src/loss.py View File

@@ -47,25 +47,3 @@ class CTCLoss(_Loss):
labels_values = self.reshape(label, (-1,)) labels_values = self.reshape(label, (-1,))
loss, _ = self.ctc_loss(logit, self.labels_indices, labels_values, self.sequence_length) loss, _ = self.ctc_loss(logit, self.labels_indices, labels_values, self.sequence_length)
return loss return loss


class CTCLossV2(_Loss):
    """
    Loss cell wrapping the `P.CTCLossV2` primitive.

    The label tensor passed to `construct` carries the label length of each sample in its
    last column; the preceding columns hold the label values.

    Args:
        max_sequence_length(int): max number of sequence length. For captcha images, the value is equal to image width
        batch_size(int): batch size of input logits
    """

    def __init__(self, max_sequence_length, batch_size):
        super(CTCLossV2, self).__init__()
        # Every sample in the batch is given the same input length.
        lengths = np.array([max_sequence_length] * batch_size)
        self.input_length = Tensor(lengths, mstype.int32)
        self.reshape = P.Reshape()
        self.ctc_loss = P.CTCLossV2()

    def construct(self, logit, label):
        # Split the packed label tensor into values (all but last column) and lengths (last column).
        loss, _ = self.ctc_loss(logit, label[:, :-1], self.input_length, label[:, -1])
        return loss

+ 1
- 1
model_zoo/official/cv/warpctc/src/metric.py View File

@@ -27,7 +27,7 @@ class WarpCTCAccuracy(nn.Metric):
self._total_num = 0 self._total_num = 0
self._count = 0 self._count = 0
self.device_target = device_target self.device_target = device_target
self.blank = 10 if device_target == 'Ascend' else 0
self.blank = 10


def clear(self): def clear(self):
self._correct_num = 0 self._correct_num = 0


+ 6
- 8
model_zoo/official/cv/warpctc/train.py View File

@@ -25,7 +25,7 @@ from mindspore.nn.wrap import WithLossCell
from mindspore.train.callback import TimeMonitor, LossMonitor, CheckpointConfig, ModelCheckpoint from mindspore.train.callback import TimeMonitor, LossMonitor, CheckpointConfig, ModelCheckpoint
from mindspore.communication.management import init, get_group_size, get_rank from mindspore.communication.management import init, get_group_size, get_rank


from src.loss import CTCLoss, CTCLossV2
from src.loss import CTCLoss
from src.config import config as cf from src.config import config as cf
from src.dataset import create_dataset from src.dataset import create_dataset
from src.warpctc import StackedRNN, StackedRNNForGPU from src.warpctc import StackedRNN, StackedRNNForGPU
@@ -58,7 +58,7 @@ if __name__ == '__main__':
rank = int(os.environ.get("RANK_ID")) rank = int(os.environ.get("RANK_ID"))
else: else:
init() init()
lr_scale = 0.5
lr_scale = 1
device_num = get_group_size() device_num = get_group_size()
rank = get_rank() rank = get_rank()
context.reset_auto_parallel_context() context.reset_auto_parallel_context()
@@ -78,16 +78,14 @@ if __name__ == '__main__':
# define lr # define lr
lr_init = cf.learning_rate if not args_opt.run_distribute else cf.learning_rate * device_num * lr_scale lr_init = cf.learning_rate if not args_opt.run_distribute else cf.learning_rate * device_num * lr_scale
lr = get_lr(cf.epoch_size, step_size, lr_init) lr = get_lr(cf.epoch_size, step_size, lr_init)
loss = CTCLoss(max_sequence_length=cf.captcha_width,
max_label_length=max_captcha_digits,
batch_size=cf.batch_size)
if args_opt.platform == 'Ascend': if args_opt.platform == 'Ascend':
loss = CTCLoss(max_sequence_length=cf.captcha_width,
max_label_length=max_captcha_digits,
batch_size=cf.batch_size)
net = StackedRNN(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size) net = StackedRNN(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size)
opt = nn.SGD(params=net.trainable_params(), learning_rate=lr, momentum=cf.momentum)
else: else:
loss = CTCLossV2(max_sequence_length=cf.captcha_width, batch_size=cf.batch_size)
net = StackedRNNForGPU(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size) net = StackedRNNForGPU(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size)
opt = nn.Momentum(params=net.trainable_params(), learning_rate=lr, momentum=cf.momentum)
opt = nn.SGD(params=net.trainable_params(), learning_rate=lr, momentum=cf.momentum)


net = WithLossCell(net, loss) net = WithLossCell(net, loss)
net = TrainOneStepCellWithGradClip(net, opt).set_train() net = TrainOneStepCellWithGradClip(net, opt).set_train()


+ 16
- 47
tests/st/ops/gpu/test_ctcloss_op.py View File

@@ -23,28 +23,28 @@ from mindspore.ops import operations as P
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from mindspore.ops.composite import GradOperation from mindspore.ops.composite import GradOperation



class Net(nn.Cell): class Net(nn.Cell):
def __init__(self): def __init__(self):
super(Net, self).__init__() super(Net, self).__init__()
self.loss = P.CTCLossV2()
self.loss = P.CTCLoss()
self.div = P.RealDiv() self.div = P.RealDiv()
self.cast = P.Cast()
self.mean = P.ReduceMean() self.mean = P.ReduceMean()


def construct(self, probs, label, input_length, label_length):
x, _ = self.loss(probs, label, input_length, label_length)
x = self.div(x, self.cast(label_length, mstype.float32))
def construct(self, probs, label, input_length, indices):
x, _ = self.loss(probs, indices, label, input_length)
x = self.mean(x) x = self.mean(x)
return x return x



class GradData(nn.Cell): class GradData(nn.Cell):
def __init__(self, network): def __init__(self, network):
super(GradData, self).__init__() super(GradData, self).__init__()
self.grad = GradOperation(get_all=True, sens_param=False) self.grad = GradOperation(get_all=True, sens_param=False)
self.network = network self.network = network


def construct(self, probs, labels, input_lengths, label_lengths):
return self.grad(self.network)(probs, labels, input_lengths, label_lengths)
def construct(self, probs, indices, labels, input_lengths):
return self.grad(self.network)(probs, indices, labels, input_lengths)




@pytest.mark.level0 @pytest.mark.level0
@@ -71,49 +71,18 @@ def test_ctcloss():
[-3.3031, -3.0087, -1.9982, -1.9081, -3.8731, -2.8764, -2.2485, -2.3808, -1.4283, -2.1625], [-3.3031, -3.0087, -1.9982, -1.9081, -3.8731, -2.8764, -2.2485, -2.3808, -1.4283, -2.1625],
[-2.4516, -3.2394, -4.2053, -4.3541, -2.5229, -4.0717, -1.4894, -2.3151, -1.1098, -2.3465]]], [-2.4516, -3.2394, -4.2053, -4.3541, -2.5229, -4.0717, -1.4894, -2.3151, -1.1098, -2.3465]]],
dtype=mstype.float32) dtype=mstype.float32)
labels = Tensor([9, 4, 6, 4, 7, 1, 4, 6, 6, 8], dtype=mstype.int32)
labels = Tensor([3, 4, 6, 4, 7, 1, 4, 6, 6, 8], dtype=mstype.int32)
indices = [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2], [2, 0], [2, 1], [2, 2], [2, 3]]
indices = Tensor(indices, dtype=mstype.int64)
input_lengths = Tensor([5, 5, 5], dtype=mstype.int32) input_lengths = Tensor([5, 5, 5], dtype=mstype.int32)
label_lengths = Tensor([3, 3, 4], dtype=mstype.int32)


context.set_context(mode=context.GRAPH_MODE, device_target="GPU") context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
net = Net() net = Net()
ctc_loss = net(probs, labels, input_lengths, label_lengths)
expect_loss = [2.4099]
ctc_loss = net(probs, labels, input_lengths, indices)
expect_loss = [9.083767]
assert np.allclose(ctc_loss.asnumpy(), expect_loss) assert np.allclose(ctc_loss.asnumpy(), expect_loss)


grad = GradData(net)(probs, labels, input_lengths, label_lengths)
expect_grad = [[[8.8442e-05, 1.1065e-03, 3.5867e-03, 2.1896e-03, 6.1646e-03,
3.6738e-03, 1.6262e-03, 3.5610e-02, 9.1258e-05, -5.4134e-02],
[-3.7523e-03, 3.9386e-03, 7.9623e-04, 3.1132e-02, -6.2954e-02,
9.4143e-03, 7.6425e-03, 1.7902e-03, 7.4211e-03, 4.5719e-03],
[6.7778e-03, 1.6178e-02, 1.0344e-02, 1.5173e-03, -6.5840e-02,
8.1707e-03, 6.9674e-03, 4.1814e-03, 3.6026e-03, 8.0991e-03]],

[[-1.2581e-02, 3.1057e-03, 4.9517e-03, 1.3301e-03, -2.6320e-02,
1.5568e-02, 1.4305e-02, 9.6671e-03, 1.7262e-02, -2.7292e-02],
[-1.5566e-02, 3.3126e-03, 2.6887e-02, 6.2993e-03, -3.9716e-02,
1.1420e-02, 7.4531e-03, -1.4252e-02, 8.5603e-03, 5.6048e-03],
[3.3483e-03, 2.0579e-02, 3.7231e-03, 1.5832e-03, 2.4837e-03,
3.2909e-03, -7.7267e-02, 1.3861e-02, 1.3558e-02, 1.4840e-02]],

[[-8.0007e-03, 1.2751e-02, 4.3901e-02, 5.8435e-03, -7.2627e-02,
1.4647e-02, -8.0584e-03, 4.4595e-03, 6.5557e-03, 5.2891e-04],
[-3.6006e-02, 1.5308e-03, 9.3225e-03, 1.0969e-03, -2.5098e-03,
2.0260e-02, 2.3419e-02, -3.0053e-02, 1.8809e-03, 1.1059e-02],
[-7.7639e-02, 1.8533e-02, 2.0764e-03, 5.9706e-03, 5.6150e-03,
5.6868e-03, 5.2854e-03, 9.8085e-03, 2.0360e-02, 4.3053e-03]],

[[-2.6776e-02, 1.1113e-02, 3.8314e-03, 3.9986e-02, -1.6020e-02,
1.1579e-02, -4.1635e-02, 5.5992e-03, 2.7429e-03, 9.5786e-03],
[-6.8619e-03, -6.4066e-03, 1.0888e-02, 1.4201e-02, 1.4413e-03,
1.3225e-02, 8.0039e-03, -4.9191e-02, 5.6352e-03, 9.0651e-03],
[5.1026e-03, 1.9343e-03, 3.2506e-03, 1.0347e-03, 2.9837e-02,
1.7121e-03, -5.9261e-02, 9.1443e-04, 8.3608e-03, 7.1146e-03]],

[[-2.0848e-02, 7.0754e-03, 2.7633e-03, 2.4447e-03, 3.1520e-02,
7.5401e-03, -5.8895e-02, 8.9559e-04, 5.7796e-03, 2.1724e-02],
[-1.3499e-03, -1.0019e-01, 1.5064e-02, 1.6485e-02, 2.3104e-03,
6.2597e-03, 1.1729e-02, 1.0275e-02, 2.6635e-02, 1.2782e-02],
[7.1796e-03, 3.2656e-03, 1.2430e-03, 1.0712e-03, 6.6856e-03,
1.4207e-03, 1.8792e-02, 8.2297e-03, -5.5865e-02, 7.9753e-03]]]
assert np.allclose(grad[0].asnumpy(), expect_grad, atol=1e-5)
grad = GradData(net)(probs, labels, input_lengths, indices)
grad = P.ReduceMean()(grad[0])
expect_grad = [-5.9604646e-09]
assert np.allclose(grad.asnumpy(), expect_grad, atol=1e-5)

Loading…
Cancel
Save