|
|
|
@@ -431,7 +431,7 @@ class Reshape(PrimitiveWithInfer): |
|
|
|
return out |
|
|
|
|
|
|
|
|
|
|
|
class Shape(Primitive): |
|
|
|
class Shape(PrimitiveWithInfer): |
|
|
|
""" |
|
|
|
Returns the shape of input tensor. |
|
|
|
|
|
|
|
@@ -453,6 +453,13 @@ class Shape(Primitive): |
|
|
|
def __init__(self): |
|
|
|
"""Initialize Shape""" |
|
|
|
|
|
|
|
def __infer__(self, x): |
|
|
|
validator.check_subclass("input_x", x['dtype'], mstype.tensor, self.name) |
|
|
|
out = {'shape': (), |
|
|
|
'dtype': mstype.tuple_, |
|
|
|
'value': tuple(x['shape'])} |
|
|
|
return out |
|
|
|
|
|
|
|
|
|
|
|
class DynamicShape(Primitive): |
|
|
|
""" |
|
|
|
|