Browse Source

!8806 GPU Stridedslice support `new_axis_mask`

From: @wilfchen
Reviewed-by: @limingqi107,@kisnwang
Signed-off-by: @kisnwang
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
ecdaeebd43
2 changed files with 32 additions and 12 deletions
  1. +16
    -6
      mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_gpu_kernel.h
  2. +16
    -6
      mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_grad_gpu_kernel.h

+ 16
- 6
mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_gpu_kernel.h View File

@@ -143,12 +143,22 @@ class StridedSliceGpuKernel : public GpuKernel {
} }
} }


auto shrink_axis_mask_str = static_cast<int64_t>(GetAttr<int64_t>(kernel_node, "shrink_axis_mask"));
auto shrink_axis_mask = Dec2Bin(shrink_axis_mask_str);
for (size_t l = 0; l < shrink_axis_mask.size(); l++) {
if (shrink_axis_mask[l]) {
end_[l] = end_[l] > begin_[l] ? begin_[l] + 1 : begin_[l] - 1;
strides_[l] = end_[l] > begin_[l] ? 1 : -1;
auto new_axis_mask_int = static_cast<int64_t>(GetAttr<int64_t>(kernel_node, "new_axis_mask"));
auto new_axis_mask = Dec2Bin(new_axis_mask_int);
for (size_t l = 0; l < new_axis_mask.size(); l++) {
if (new_axis_mask[l]) {
begin_[l] = 0;
end_[l] = input_shape_[l];
strides_[l] = 1;
}
}

auto shrink_axis_mask_int = static_cast<int64_t>(GetAttr<int64_t>(kernel_node, "shrink_axis_mask"));
auto shrink_axis_mask = Dec2Bin(shrink_axis_mask_int);
for (size_t m = 0; m < shrink_axis_mask.size(); m++) {
if (shrink_axis_mask[m]) {
end_[m] = end_[m] > begin_[m] ? begin_[m] + 1 : begin_[m] - 1;
strides_[m] = end_[m] > begin_[m] ? 1 : -1;
} }
} }
} }


+ 16
- 6
mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/strided_slice_grad_gpu_kernel.h View File

@@ -149,12 +149,22 @@ class StridedSliceGradGpuKernel : public GpuKernel {
} }
} }


auto shrink_axis_mask_str = static_cast<int64_t>(GetAttr<int64_t>(kernel_node, "shrink_axis_mask"));
auto shrink_axis_mask = Dec2Bin(shrink_axis_mask_str);
for (size_t l = 0; l < shrink_axis_mask.size(); l++) {
if (shrink_axis_mask[l]) {
end_[l] = end_[l] > begin_[l] ? begin_[l] + 1 : begin_[l] - 1;
strides_[l] = end_[l] > begin_[l] ? 1 : -1;
auto new_axis_mask_int = static_cast<int64_t>(GetAttr<int64_t>(kernel_node, "new_axis_mask"));
auto new_axis_mask = Dec2Bin(new_axis_mask_int);
for (size_t l = 0; l < new_axis_mask.size(); l++) {
if (new_axis_mask[l]) {
begin_[l] = 0;
end_[l] = input_shape_[l];
strides_[l] = 1;
}
}

auto shrink_axis_mask_int = static_cast<int64_t>(GetAttr<int64_t>(kernel_node, "shrink_axis_mask"));
auto shrink_axis_mask = Dec2Bin(shrink_axis_mask_int);
for (size_t m = 0; m < shrink_axis_mask.size(); m++) {
if (shrink_axis_mask[m]) {
end_[m] = end_[m] > begin_[m] ? begin_[m] + 1 : begin_[m] - 1;
strides_[m] = end_[m] > begin_[m] ? 1 : -1;
} }
} }
} }


Loading…
Cancel
Save