From: @wilfchen Reviewed-by: @limingqi107,@kisnwang Signed-off-by: @kisnwangtags/v1.1.0
| @@ -143,12 +143,22 @@ class StridedSliceGpuKernel : public GpuKernel { | |||||
| } | } | ||||
| } | } | ||||
| auto shrink_axis_mask_str = static_cast<int64_t>(GetAttr<int64_t>(kernel_node, "shrink_axis_mask")); | |||||
| auto shrink_axis_mask = Dec2Bin(shrink_axis_mask_str); | |||||
| for (size_t l = 0; l < shrink_axis_mask.size(); l++) { | |||||
| if (shrink_axis_mask[l]) { | |||||
| end_[l] = end_[l] > begin_[l] ? begin_[l] + 1 : begin_[l] - 1; | |||||
| strides_[l] = end_[l] > begin_[l] ? 1 : -1; | |||||
| auto new_axis_mask_int = static_cast<int64_t>(GetAttr<int64_t>(kernel_node, "new_axis_mask")); | |||||
| auto new_axis_mask = Dec2Bin(new_axis_mask_int); | |||||
| for (size_t l = 0; l < new_axis_mask.size(); l++) { | |||||
| if (new_axis_mask[l]) { | |||||
| begin_[l] = 0; | |||||
| end_[l] = input_shape_[l]; | |||||
| strides_[l] = 1; | |||||
| } | |||||
| } | |||||
| auto shrink_axis_mask_int = static_cast<int64_t>(GetAttr<int64_t>(kernel_node, "shrink_axis_mask")); | |||||
| auto shrink_axis_mask = Dec2Bin(shrink_axis_mask_int); | |||||
| for (size_t m = 0; m < shrink_axis_mask.size(); m++) { | |||||
| if (shrink_axis_mask[m]) { | |||||
| end_[m] = end_[m] > begin_[m] ? begin_[m] + 1 : begin_[m] - 1; | |||||
| strides_[m] = end_[m] > begin_[m] ? 1 : -1; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -149,12 +149,22 @@ class StridedSliceGradGpuKernel : public GpuKernel { | |||||
| } | } | ||||
| } | } | ||||
| auto shrink_axis_mask_str = static_cast<int64_t>(GetAttr<int64_t>(kernel_node, "shrink_axis_mask")); | |||||
| auto shrink_axis_mask = Dec2Bin(shrink_axis_mask_str); | |||||
| for (size_t l = 0; l < shrink_axis_mask.size(); l++) { | |||||
| if (shrink_axis_mask[l]) { | |||||
| end_[l] = end_[l] > begin_[l] ? begin_[l] + 1 : begin_[l] - 1; | |||||
| strides_[l] = end_[l] > begin_[l] ? 1 : -1; | |||||
| auto new_axis_mask_int = static_cast<int64_t>(GetAttr<int64_t>(kernel_node, "new_axis_mask")); | |||||
| auto new_axis_mask = Dec2Bin(new_axis_mask_int); | |||||
| for (size_t l = 0; l < new_axis_mask.size(); l++) { | |||||
| if (new_axis_mask[l]) { | |||||
| begin_[l] = 0; | |||||
| end_[l] = input_shape_[l]; | |||||
| strides_[l] = 1; | |||||
| } | |||||
| } | |||||
| auto shrink_axis_mask_int = static_cast<int64_t>(GetAttr<int64_t>(kernel_node, "shrink_axis_mask")); | |||||
| auto shrink_axis_mask = Dec2Bin(shrink_axis_mask_int); | |||||
| for (size_t m = 0; m < shrink_axis_mask.size(); m++) { | |||||
| if (shrink_axis_mask[m]) { | |||||
| end_[m] = end_[m] > begin_[m] ? begin_[m] + 1 : begin_[m] - 1; | |||||
| strides_[m] = end_[m] > begin_[m] ? 1 : -1; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||