|
|
|
@@ -2695,6 +2695,7 @@ class ROIAlign(PrimitiveWithInfer): |
|
|
|
feature map coordinates. Suppose the height of a RoI is `ori_h` in the raw image and `fea_h` in the |
|
|
|
input feature map, the `spatial_scale` should be `fea_h / ori_h`. |
|
|
|
sample_num (int): Number of sampling points. Default: 2. |
|
|
|
roi_end_mode (int): Number must be 0 or 1. Default: 1. |
|
|
|
|
|
|
|
Inputs: |
|
|
|
- **features** (Tensor) - The input features, whose shape should be `(N, C, H, W)`. |
|
|
|
@@ -2717,16 +2718,19 @@ class ROIAlign(PrimitiveWithInfer): |
|
|
|
""" |
|
|
|
|
|
|
|
@prim_attr_register |
|
|
|
def __init__(self, pooled_height, pooled_width, spatial_scale, sample_num=2): |
|
|
|
def __init__(self, pooled_height, pooled_width, spatial_scale, sample_num=2, roi_end_mode=1): |
|
|
|
"""init ROIAlign""" |
|
|
|
validator.check_value_type("pooled_height", pooled_height, [int], self.name) |
|
|
|
validator.check_value_type("pooled_width", pooled_width, [int], self.name) |
|
|
|
validator.check_value_type("spatial_scale", spatial_scale, [float], self.name) |
|
|
|
validator.check_value_type("sample_num", sample_num, [int], self.name) |
|
|
|
validator.check_value_type("roi_end_mode", roi_end_mode, [int], self.name) |
|
|
|
validator.check_int_range("roi_end_mode", roi_end_mode, 0, 1, Rel.INC_BOTH, self.name) |
|
|
|
self.pooled_height = pooled_height |
|
|
|
self.pooled_width = pooled_width |
|
|
|
self.spatial_scale = spatial_scale |
|
|
|
self.sample_num = sample_num |
|
|
|
self.roi_end_mode = roi_end_mode |
|
|
|
|
|
|
|
def infer_shape(self, inputs_shape, rois_shape): |
|
|
|
return [rois_shape[0], inputs_shape[1], self.pooled_height, self.pooled_width] |
|
|
|
|