|
|
|
@@ -149,12 +149,22 @@ class StridedSliceGradGpuKernel : public GpuKernel { |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
auto shrink_axis_mask_str = static_cast<int64_t>(GetAttr<int64_t>(kernel_node, "shrink_axis_mask")); |
|
|
|
auto shrink_axis_mask = Dec2Bin(shrink_axis_mask_str); |
|
|
|
for (size_t l = 0; l < shrink_axis_mask.size(); l++) { |
|
|
|
if (shrink_axis_mask[l]) { |
|
|
|
end_[l] = end_[l] > begin_[l] ? begin_[l] + 1 : begin_[l] - 1; |
|
|
|
strides_[l] = end_[l] > begin_[l] ? 1 : -1; |
|
|
|
auto new_axis_mask_int = static_cast<int64_t>(GetAttr<int64_t>(kernel_node, "new_axis_mask")); |
|
|
|
auto new_axis_mask = Dec2Bin(new_axis_mask_int); |
|
|
|
for (size_t l = 0; l < new_axis_mask.size(); l++) { |
|
|
|
if (new_axis_mask[l]) { |
|
|
|
begin_[l] = 0; |
|
|
|
end_[l] = input_shape_[l]; |
|
|
|
strides_[l] = 1; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
auto shrink_axis_mask_int = static_cast<int64_t>(GetAttr<int64_t>(kernel_node, "shrink_axis_mask")); |
|
|
|
auto shrink_axis_mask = Dec2Bin(shrink_axis_mask_int); |
|
|
|
for (size_t m = 0; m < shrink_axis_mask.size(); m++) { |
|
|
|
if (shrink_axis_mask[m]) { |
|
|
|
end_[m] = end_[m] > begin_[m] ? begin_[m] + 1 : begin_[m] - 1; |
|
|
|
strides_[m] = end_[m] > begin_[m] ? 1 : -1; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|