|
|
|
@@ -14,6 +14,7 @@ |
|
|
|
# ============================================================================ |
|
|
|
|
|
|
|
"""image_ops""" |
|
|
|
from ... import context |
|
|
|
from ..._checkparam import Validator as validator |
|
|
|
from ..._checkparam import Rel |
|
|
|
from ...common import dtype as mstype |
|
|
|
@@ -84,6 +85,7 @@ class CropAndResize(PrimitiveWithInfer): |
|
|
|
self.method = method |
|
|
|
validator.check_value_type("extrapolation_value", extrapolation_value, [float], self.name) |
|
|
|
self.extrapolation_value = extrapolation_value |
|
|
|
self.is_ge = context.get_context("enable_ge") |
|
|
|
|
|
|
|
def __infer__(self, x, boxes, box_index, crop_size): |
|
|
|
# get shape |
|
|
|
@@ -124,6 +126,9 @@ class CropAndResize(PrimitiveWithInfer): |
|
|
|
crop_height = crop_size_value[0] |
|
|
|
crop_width = crop_size_value[1] |
|
|
|
depth = x_shape[3] |
|
|
|
return {'shape': (num_boxes, crop_height, crop_width, depth), |
|
|
|
out_shape = (num_boxes, crop_height, crop_width, depth) |
|
|
|
if self.is_ge: |
|
|
|
out_shape = (num_boxes, x_shape[1], crop_height, crop_width) |
|
|
|
return {'shape': out_shape, |
|
|
|
'dtype': mstype.float32, |
|
|
|
'value': None} |