Browse Source

fix bug in gpu/array

tags/v1.4.0
zhaoting 4 years ago
parent
commit
f356eaca03
9 changed files with 72 additions and 18 deletions
  1. +4
    -2
      mindspore/ccsrc/backend/kernel_compiler/cpu/iou_cpu_kernel.cc
  2. +5
    -1
      mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmaxandminwithvalue_gpu_kernel.h
  3. +8
    -3
      mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/concatv2_gpu_kernel.h
  4. +11
    -3
      mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gather_gpu_kernel.h
  5. +11
    -2
      mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gather_grad_gpu_kernel.h
  6. +6
    -1
      mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gatherv2_gpu_kernel.h
  7. +5
    -1
      mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/resize_nearest_neighbor_gpu_kernel.h
  8. +12
    -0
      mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/resize_nearest_neighbor_grad_gpu_kernel.h
  9. +10
    -5
      mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/split_gpu_kernel.h

+ 4
- 2
mindspore/ccsrc/backend/kernel_compiler/cpu/iou_cpu_kernel.cc View File

@@ -64,7 +64,7 @@ bool IOUCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs, cons

// multithreading
auto task = [&anchor_boxes, &gt_boxes, &iou_score, this](size_t start, size_t end) {
const T ZERO = T(1);
const T ZERO = T(0);
const T ONE = T(1);
const T EPS = T(1e-10);
constexpr size_t Y0_SHIFT = 1;
@@ -77,7 +77,9 @@ bool IOUCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs, cons
T I_y0 = std::max(anchor_boxes[idx1 + Y0_SHIFT], gt_boxes[idx2 + Y0_SHIFT]);
T I_x1 = std::min(anchor_boxes[idx1 + X1_SHIFT], gt_boxes[idx2 + X1_SHIFT]);
T I_y1 = std::min(anchor_boxes[idx1 + Y1_SHIFT], gt_boxes[idx2 + Y1_SHIFT]);
T overlaps = std::max(ZERO, (I_x1 - I_x0 + ONE) * (I_y1 - I_y0 + ONE));
T overlaps_w = std::max(ZERO, (I_x1 - I_x0 + ONE));
T overlaps_h = std::max(ZERO, (I_y1 - I_y0 + ONE));
T overlaps = overlaps_w * overlaps_h;
T area1 = (anchor_boxes[idx1 + X1_SHIFT] - anchor_boxes[idx1] + ONE) *
(anchor_boxes[idx1 + Y1_SHIFT] - anchor_boxes[idx1 + Y0_SHIFT] + ONE);
T area2 = (gt_boxes[idx2 + X1_SHIFT] - gt_boxes[idx2] + ONE) *


+ 5
- 1
mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/argmaxandminwithvalue_gpu_kernel.h View File

@@ -50,8 +50,12 @@ class ArgMaxAndMinWithValueGpuKernel : public GpuKernel {
small_ = (kernel_name == "ArgMinWithValue") ? true : false;
std::vector<size_t> shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 1);
int64_t dims = shape.size();
int64_t dims = SizeToLong(shape.size());
int64_t axis = GetAttr<int64_t>(kernel_node, "axis");
if (axis < -dims || axis >= dims) {
MS_LOG(ERROR) << "axis must be in the range [-rank, rank)";
return false;
}
if (axis < 0) {
axis += dims;
}


+ 8
- 3
mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/concatv2_gpu_kernel.h View File

@@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -66,10 +66,15 @@ class ConcatV2GpuFwdKernel : public GpuKernel {
if (!CheckParam(kernel_node)) {
return false;
}
auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
int dims = SizeToInt(input_shape.size());
axis_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "axis"));
if (axis_ < -dims || axis_ >= dims) {
MS_LOG(ERROR) << "axis must be in the range [-rank, rank)";
return false;
}
if (axis_ < 0) {
auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
axis_ += SizeToInt(input_shape.size());
axis_ += dims;
}
auto origin_data_format = AnfAlgo::GetOriginDataFormat(kernel_node);
auto input_format = AnfAlgo::GetInputFormat(kernel_node, 0);


+ 11
- 3
mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gather_gpu_kernel.h View File

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -54,10 +54,18 @@ class GatherGpuFwdKernel : public GpuKernel {
input_shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
index_shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
output_shapes_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);

if (input_shapes_.size() != index_shapes_.size() || input_shapes_.size() != output_shapes_.size()) {
MS_LOG(ERROR) << "The shape of input, index and output should be same.";
return false;
}
int dims = SizeToInt(input_shapes_.size());
axis_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "dim"));
if (axis_ < -dims || axis_ >= dims) {
MS_LOG(ERROR) << "axis must be in the range [-rank, rank)";
return false;
}
if (axis_ < 0) {
axis_ = axis_ + SizeToInt(input_shapes_.size());
axis_ += dims;
}
Reshape();
InitSizeLists();


+ 11
- 2
mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gather_grad_gpu_kernel.h View File

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -55,9 +55,18 @@ class GatherGradGpuKernel : public GpuKernel {
grad_shapes_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
output_shapes_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);

if (grad_shapes_.size() != index_shapes_.size() || grad_shapes_.size() != output_shapes_.size()) {
MS_LOG(ERROR) << "The shape of grad, index and output should be same.";
return false;
}
int dims = SizeToInt(grad_shapes_.size());
axis_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "dim"));
if (axis_ < -dims || axis_ >= dims) {
MS_LOG(ERROR) << "axis must be in the range [-rank, rank)";
return false;
}
if (axis_ < 0) {
axis_ = axis_ + SizeToInt(index_shapes_.size());
axis_ += dims;
}

Reshape();


+ 6
- 1
mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/gatherv2_gpu_kernel.h View File

@@ -76,7 +76,12 @@ class GatherV2GpuFwdKernel : public GpuKernel {
indices_shapes_ = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 1);
output_shapes_ = AnfAlgo::GetOutputRealDeviceShapeIfExist(kernel_node, 0);
if (!is_dynamic_shape_) {
int dims = SizeToInt(input_shapes_.size());
axis_ = static_cast<int>(GetAttr<int64_t>(kernel_node, "axis"));
if (axis_ < -dims || axis_ >= dims) {
MS_LOG(ERROR) << "axis must be in the range [-rank, rank)";
return false;
}
Reshape();
}
InitSizeLists();
@@ -113,7 +118,7 @@ class GatherV2GpuFwdKernel : public GpuKernel {
axis_ = axis_ + SizeToInt(input_shapes_.size());
}
size_t dim_before_axis = 1;
for (size_t i = 0; i < IntToSize(axis_); i++) {
for (size_t i = 0; i < std::min(IntToSize(axis_), output_shapes_.size()); i++) {
dim_before_axis *= output_shapes_[i];
}
size_t dim_of_indices = 1;


+ 5
- 1
mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/resize_nearest_neighbor_gpu_kernel.h View File

@@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2019-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -66,6 +66,10 @@ class ResizeNearestNeighborGpuKernel : public GpuKernel {
<< RESIZENEARESTNEIGHBOR_DIMENSION << "-D inputs.";
return false;
}
if (shape_size_ != output_shape.size()) {
MS_LOG(ERROR) << "The dim of input and output must be same.";
return false;
}
input_size_ = 1;
for (size_t i = 0; i < shape_size_; i++) {
input_size_ *= input_shape[i];


+ 12
- 0
mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/resize_nearest_neighbor_grad_gpu_kernel.h View File

@@ -66,15 +66,27 @@ class ResizeNearestNeighborGradGpuKernel : public GpuKernel {
<< RESIZENEARESTNEIGHBORGRAD_DIMENSION << "-D inputs.";
return false;
}
if (shape_size_ != output_shape.size()) {
MS_LOG(ERROR) << "The dim of input and output must be same.";
return false;
}
input_size_ = 1;
for (size_t i = 0; i < shape_size_; i++) {
input_size_ *= input_shape[i];
if (input_shape[i] == 0) {
MS_LOG(ERROR) << "The shape of input has 0.";
return false;
}
input_shape_.push_back(input_shape[i]);
}
input_size_ *= sizeof(T);
output_size_ = 1;
for (size_t i = 0; i < shape_size_; i++) {
output_size_ *= output_shape[i];
if (input_shape[i] == 0) {
MS_LOG(ERROR) << "The shape of output has 0.";
return false;
}
output_shape_.push_back(output_shape[i]);
}
output_size_ *= sizeof(T);


+ 10
- 5
mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/split_gpu_kernel.h View File

@@ -52,10 +52,14 @@ class SplitGpuFwdKernel : public GpuKernel {

bool Init(const CNodePtr &kernel_node) override {
kernel_node_ = kernel_node;
auto input_shape = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0);
int dims = SizeToInt(input_shape.size());
axis_ = static_cast<int64_t>(GetAttr<int64_t>(kernel_node, "axis"));
if (axis_ < -dims || axis_ >= dims) {
MS_LOG(EXCEPTION) << "axis must be in the range [-rank, rank)";
}
if (axis_ < 0) {
auto input_shape = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0);
axis_ += SizeToInt(input_shape.size());
axis_ += dims;
}

auto origin_data_format = AnfAlgo::GetOriginDataFormat(kernel_node);
@@ -67,8 +71,6 @@ class SplitGpuFwdKernel : public GpuKernel {
if (!CheckParam(kernel_node)) {
return false;
}

auto input_shape = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0);
input_size_ = 1;
all_size_before_axis_ = 1;
all_size_axis_ = 1;
@@ -122,7 +124,10 @@ class SplitGpuFwdKernel : public GpuKernel {
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
int dims = SizeToInt(input_shape.size());
int output_num = SizeToInt(AnfAlgo::GetOutputTensorNum(kernel_node));

if (output_num <= 0) {
MS_LOG(ERROR) << "Output number is " << output_num << ", must > 0.";
return false;
}
if (input_num != 1) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but Split needs 1 input.";
return false;


Loading…
Cancel
Save