| @@ -459,27 +459,27 @@ class Parser: | |||
| logger.debug("ops info = %r", ops_info) | |||
| return ops_info | |||
| def analyze_super(self, father_class_node, subclass_instance): | |||
| def analyze_super(self, class_type_node, subclass_instance): | |||
| """Analyze super and return a class instance.""" | |||
| father_class = None | |||
| if father_class_node is None: | |||
| father_class = type(subclass_instance) | |||
| if isinstance(father_class_node, ast.Name): | |||
| father_class_name = getattr(father_class_node, 'id') | |||
| father_class = self.global_namespace[father_class_name] | |||
| if isinstance(father_class_node, ast.Attribute): | |||
| value = getattr(father_class_node, 'value') | |||
| attr = getattr(father_class_node, 'attr') | |||
| module_name = getattr(value, 'id') | |||
| father_class_module = self.global_namespace[module_name] | |||
| father_class = getattr(father_class_module, attr) | |||
| if father_class is None: | |||
| raise ValueError("When call 'super', the father class is None.") | |||
| if not isinstance(subclass_instance, father_class): | |||
| sub_class = type(subclass_instance) | |||
| if class_type_node is None: | |||
| return super(sub_class, subclass_instance) | |||
| if isinstance(class_type_node, ast.Name): | |||
| class_name = getattr(class_type_node, 'id') | |||
| elif isinstance(class_type_node, ast.Attribute): | |||
| class_name = getattr(class_type_node, 'attr') | |||
| else: | |||
| raise ValueError(f"When call 'super', the first arg should be a class type, " | |||
| f"but got {class_type_node.__class__.__name__}.") | |||
| target_father_class = None | |||
| for class_element in sub_class.mro(): | |||
| if class_element.__name__ == class_name: | |||
| target_father_class = class_element | |||
| break | |||
| if target_father_class is None: | |||
| raise ValueError("When call 'super', the second arg should be an instance of first arg.") | |||
| target_class_instance = super(father_class, subclass_instance) | |||
| return target_class_instance | |||
| return super(target_father_class, subclass_instance) | |||
| def get_location(self, node): | |||
| """ | |||
| @@ -58,6 +58,7 @@ class Cell: | |||
| >>> def construct(self, x): | |||
| >>> return self.relu(x) | |||
| """ | |||
| def __init__(self, auto_prefix=True, flags=None): | |||
| self._params = OrderedDict() | |||
| self._cells = OrderedDict() | |||
| @@ -888,6 +889,7 @@ class Cell: | |||
| for param in params: | |||
| param.set_param_ps(init_in_server) | |||
| class GraphKernel(Cell): | |||
| """ | |||
| Base class for GraphKernel. | |||
| @@ -904,6 +906,7 @@ class GraphKernel(Cell): | |||
| >>> def construct(self, x): | |||
| >>> return self.max(P.Fill()(P.DType()(x), P.Shape()(x), 0.0), x) | |||
| """ | |||
| def __init__(self, auto_prefix=True, pips=None): | |||
| super(GraphKernel, self).__init__(auto_prefix, pips) | |||
| class_name = self.__class__.__name__ | |||
| @@ -92,7 +92,7 @@ class Net(nn.Cell): | |||
| def test_single_super(): | |||
| single_net = SingleSubNet(2, 3) | |||
| context.set_context(mode=context.GRAPH_MODE, save_graphs=True) | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| x = Tensor(np.ones([1, 2, 3], np.int32)) | |||
| y = Tensor(np.ones([1, 2, 3], np.int32)) | |||
| single_net(x, y) | |||
| @@ -100,7 +100,7 @@ def test_single_super(): | |||
| def test_mul_super(): | |||
| mul_net = MulSubNet(2, 3, 4) | |||
| context.set_context(mode=context.GRAPH_MODE, save_graphs=True) | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| x = Tensor(np.ones([1, 2, 3], np.int32)) | |||
| y = Tensor(np.ones([1, 2, 3], np.int32)) | |||
| mul_net(x, y) | |||
| @@ -108,9 +108,41 @@ def test_mul_super(): | |||
| def test_super_cell(): | |||
| net = Net(2) | |||
| context.set_context(mode=context.GRAPH_MODE, save_graphs=True) | |||
| context.set_context(mode=context.GRAPH_MODE) | |||
| x = Tensor(np.ones([1, 2, 3], np.int32)) | |||
| y = Tensor(np.ones([1, 2, 3], np.int32)) | |||
| with pytest.raises(RuntimeError) as er: | |||
| net(x, y) | |||
| assert "Unsupported syntax 'Raise'" in str(er.value) | |||
| def test_single_super_in(): | |||
| class FatherNetIn(nn.Cell): | |||
| def __init__(self, x): | |||
| super(FatherNetIn, self).__init__(x) | |||
| self.x = x | |||
| def construct(self, x, y): | |||
| return self.x * x | |||
| def test_father(self, x): | |||
| return self.x + x | |||
| class SingleSubNetIN(FatherNetIn): | |||
| def __init__(self, x, z): | |||
| super(SingleSubNetIN, self).__init__(x) | |||
| self.z = z | |||
| def construct(self, x, y): | |||
| ret_father_construct = super().construct(x, y) | |||
| ret_father_test = super(SingleSubNetIN, self).test_father(x) | |||
| ret_father_x = super(SingleSubNetIN, self).x | |||
| ret_sub_z = self.z | |||
| return ret_father_construct, ret_father_test, ret_father_x, ret_sub_z | |||
| single_net_in = SingleSubNetIN(2, 3) | |||
| context.set_context(mode=context.GRAPH_MODE, save_graphs=True) | |||
| x = Tensor(np.ones([1, 2, 3], np.int32)) | |||
| y = Tensor(np.ones([1, 2, 3], np.int32)) | |||
| single_net_in(x, y) | |||