|
|
|
@@ -43,6 +43,10 @@ class StridedSliceGpuCommon { |
|
|
|
strides_ = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, "strides"); |
|
|
|
|
|
|
|
for (size_t i = 0; i < MAX_DIMS; i++) { |
|
|
|
if (i >= input_shape_.size()) { |
|
|
|
input_shape_.push_back(1); |
|
|
|
} |
|
|
|
|
|
|
|
if (i < begin_.size()) { |
|
|
|
int64_t dim = input_shape_[i]; |
|
|
|
begin_[i] = std::min(begin_[i] < 0 ? std::max(begin_[i] + dim, static_cast<int64_t>(0)) : begin_[i], dim - 1); |
|
|
|
@@ -60,10 +64,6 @@ class StridedSliceGpuCommon { |
|
|
|
if (i >= strides_.size()) { |
|
|
|
strides_.push_back(1); |
|
|
|
} |
|
|
|
|
|
|
|
if (i >= input_shape_.size()) { |
|
|
|
input_shape_.push_back(1); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
@@ -71,7 +71,7 @@ class StridedSliceGpuCommon { |
|
|
|
auto begin_mask_int = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "begin_mask"); |
|
|
|
auto begin_mask = Dec2Bin(begin_mask_int); |
|
|
|
for (size_t i = 0; i < begin_mask.size(); i++) { |
|
|
|
if (begin_mask[i]) { |
|
|
|
if (begin_mask[i] && i < MAX_DIMS) { |
|
|
|
begin_[i] = 0; |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -79,7 +79,7 @@ class StridedSliceGpuCommon { |
|
|
|
auto end_mask_int = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "end_mask"); |
|
|
|
auto end_mask = Dec2Bin(end_mask_int); |
|
|
|
for (size_t j = 0; j < end_mask.size(); j++) { |
|
|
|
if (end_mask[j]) { |
|
|
|
if (end_mask[j] && j < MAX_DIMS) { |
|
|
|
end_[j] = input_shape_[j]; |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -87,7 +87,7 @@ class StridedSliceGpuCommon { |
|
|
|
auto ellipsis_mask_int = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "ellipsis_mask"); |
|
|
|
auto ellipsis_mask = Dec2Bin(ellipsis_mask_int); |
|
|
|
for (size_t k = 0; k < ellipsis_mask.size(); k++) { |
|
|
|
if (ellipsis_mask[k]) { |
|
|
|
if (ellipsis_mask[k] && k < MAX_DIMS) { |
|
|
|
begin_[k] = 0; |
|
|
|
end_[k] = input_shape_[k]; |
|
|
|
strides_[k] = 1; |
|
|
|
@@ -97,7 +97,7 @@ class StridedSliceGpuCommon { |
|
|
|
auto new_axis_mask_int = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "new_axis_mask"); |
|
|
|
auto new_axis_mask = Dec2Bin(new_axis_mask_int); |
|
|
|
for (size_t l = 0; l < new_axis_mask.size(); l++) { |
|
|
|
if (new_axis_mask[l]) { |
|
|
|
if (new_axis_mask[l] && l < MAX_DIMS) { |
|
|
|
begin_[l] = 0; |
|
|
|
end_[l] = input_shape_[l]; |
|
|
|
strides_[l] = 1; |
|
|
|
@@ -107,7 +107,7 @@ class StridedSliceGpuCommon { |
|
|
|
auto shrink_axis_mask_int = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "shrink_axis_mask"); |
|
|
|
auto shrink_axis_mask = Dec2Bin(shrink_axis_mask_int); |
|
|
|
for (size_t m = 0; m < shrink_axis_mask.size(); m++) { |
|
|
|
if (shrink_axis_mask[m]) { |
|
|
|
if (shrink_axis_mask[m] && m < MAX_DIMS) { |
|
|
|
end_[m] = end_[m] > begin_[m] ? begin_[m] + 1 : begin_[m] - 1; |
|
|
|
strides_[m] = end_[m] > begin_[m] ? 1 : -1; |
|
|
|
} |
|
|
|
|