|
|
|
@@ -121,6 +121,16 @@ class NetForFlatten0D(nn.Cell): |
|
|
|
return self.flatten(x) |
|
|
|
|
|
|
|
|
|
|
|
class NetForFlattenComposed(nn.Cell):
    """Compose Flatten with elementwise ops so the flatten gradient is exercised."""

    def __init__(self):
        super().__init__()
        # Bare Flatten primitive; the composition happens in construct().
        self.flatten = P.Flatten()

    def construct(self, x, y):
        # Feed Flatten a computed (non-leaf) input and keep the output
        # connected to a second input so backward covers the flatten grad.
        doubled = x + x
        return self.flatten(doubled) + y
|
|
|
|
|
|
|
|
|
|
|
class ArgmaxNet(nn.Cell): |
|
|
|
def __init__(self): |
|
|
|
super(ArgmaxNet, self).__init__() |
|
|
|
@@ -695,7 +705,7 @@ test_case_nn_ops = [ |
|
|
|
('Flatten', { |
|
|
|
'block': P.Flatten(), |
|
|
|
'desc_inputs': [[128, 32, 32, 64]], |
|
|
|
-        'desc_bprop': [[128 * 32 * 8 * 16]]}),
+        'desc_bprop': [[128, 65536]]}),
|
|
|
('LogSoftmax', { |
|
|
|
'block': P.LogSoftmax(), |
|
|
|
'desc_inputs': [[64, 2]], |
|
|
|
@@ -893,6 +903,11 @@ test_case_nn_ops = [ |
|
|
|
'desc_inputs': [Tensor(np.ones([8]).astype(np.int32)), Tensor(np.ones([8, 3]).astype(np.int32))], |
|
|
|
'desc_bprop': [Tensor(np.ones([8, 3]).astype(np.int32))], |
|
|
|
'skip': ['backward']}), |
|
|
|
('Flatten_3', { |
|
|
|
'block': NetForFlattenComposed(), |
|
|
|
'desc_inputs': [Tensor(np.ones([2, 3, 4]).astype(np.int32)), Tensor(np.ones([2, 12]).astype(np.int32))], |
|
|
|
'desc_bprop': [Tensor(np.ones([2, 12]).astype(np.int32))], |
|
|
|
'skip': []}), |
|
|
|
('ArgmaxNet', { |
|
|
|
'block': ArgmaxNet(), |
|
|
|
'desc_inputs': [Tensor(np.array([[128, 32, 32, 64], [128, 32, 32, 64]]).astype(np.float16))], |
|
|
|
|