|
|
|
@@ -35,16 +35,18 @@ log.setLevel(level=logging.ERROR) |
|
|
|
relu_test = Primitive('relu_test') |
|
|
|
|
|
|
|
|
|
|
|
def test_ops_f1(x, y): |
|
|
|
foo = relu_test(x) |
|
|
|
return foo |
|
|
|
def test_ops_f1(x): |
|
|
|
test = relu_test(x) |
|
|
|
return test |
|
|
|
|
|
|
|
|
|
|
|
# use method2: create instance outside function use an operator with parameters |
|
|
|
class Conv_test(Primitive): |
|
|
|
@prim_attr_register |
|
|
|
def __init__(self, stride=0, pad=1): |
|
|
|
print('in conv_test init', self.stride) |
|
|
|
self.stride = stride |
|
|
|
self.pad = pad |
|
|
|
print('in conv_test init', self.stride, self.pad) |
|
|
|
|
|
|
|
def __call__(self, x=0, y=1, z=2): |
|
|
|
pass |
|
|
|
@@ -65,7 +67,7 @@ class ResNet(nn.Cell): |
|
|
|
self.weight = Parameter(tensor, name="weight") |
|
|
|
self.conv = Conv_test(3, 5) |
|
|
|
|
|
|
|
def construct(self, x, y, train="train"): |
|
|
|
def construct(self, x, y): |
|
|
|
return x + y * self.weight + self.conv(x) |
|
|
|
|
|
|
|
def get_params(self): |
|
|
|
@@ -78,7 +80,7 @@ class SimpleNet(nn.Cell): |
|
|
|
self.weight = Parameter(tensor, name="weight") |
|
|
|
self.network = network |
|
|
|
|
|
|
|
def construct(self, x, y, train="train"): |
|
|
|
def construct(self, x, y): |
|
|
|
return self.network(x) + self.weight * y |
|
|
|
|
|
|
|
def get_params(self): |
|
|
|
@@ -106,7 +108,7 @@ class SimpleNet_1(nn.Cell): |
|
|
|
super(SimpleNet_1, self).__init__() |
|
|
|
self.conv = Conv_test(2, 3) |
|
|
|
|
|
|
|
def construct(self, x, y, train="train"): |
|
|
|
def construct(self, x, y): |
|
|
|
return self.conv(x, y) |
|
|
|
|
|
|
|
def get_params(self): |
|
|
|
|