|
|
|
@@ -19,6 +19,7 @@ import inspect as ins |
|
|
|
import os |
|
|
|
from functools import wraps |
|
|
|
from multiprocessing import cpu_count |
|
|
|
import numpy as np |
|
|
|
from mindspore._c_expression import typing |
|
|
|
from . import samplers |
|
|
|
from . import datasets |
|
|
|
@@ -1075,14 +1076,48 @@ def check_split(method): |
|
|
|
return new_method |
|
|
|
|
|
|
|
|
|
|
|
def check_list_or_ndarray(param, param_name): |
|
|
|
if (not isinstance(param, list)) and (not hasattr(param, 'tolist')): |
|
|
|
raise TypeError("Wrong input type for {0}, should be list, got {1}".format( |
|
|
|
def check_gnn_graphdata(method): |
|
|
|
"""check the input arguments of graphdata.""" |
|
|
|
|
|
|
|
@wraps(method) |
|
|
|
def new_method(*args, **kwargs): |
|
|
|
param_dict = make_param_dict(method, args, kwargs) |
|
|
|
|
|
|
|
# check dataset_file; required argument |
|
|
|
dataset_file = param_dict.get('dataset_file') |
|
|
|
if dataset_file is None: |
|
|
|
raise ValueError("dataset_file is not provided.") |
|
|
|
check_dataset_file(dataset_file) |
|
|
|
|
|
|
|
nreq_param_int = ['num_parallel_workers'] |
|
|
|
|
|
|
|
check_param_type(nreq_param_int, param_dict, int) |
|
|
|
|
|
|
|
return method(*args, **kwargs) |
|
|
|
|
|
|
|
return new_method |
|
|
|
|
|
|
|
|
|
|
|
def check_gnn_list_or_ndarray(param, param_name): |
|
|
|
"""Check if the input parameter is list or numpy.ndarray.""" |
|
|
|
|
|
|
|
if isinstance(param, list): |
|
|
|
for m in param: |
|
|
|
if not isinstance(m, int): |
|
|
|
raise TypeError( |
|
|
|
"Each membor in {0} should be of type int. Got {1}.".format(param_name, type(m))) |
|
|
|
elif isinstance(param, np.ndarray): |
|
|
|
if not param.dtype == np.int32: |
|
|
|
raise TypeError("Each membor in {0} should be of type int32. Got {1}.".format( |
|
|
|
param_name, param.dtype)) |
|
|
|
else: |
|
|
|
raise TypeError("Wrong input type for {0}, should be list or numpy.ndarray, got {1}".format( |
|
|
|
param_name, type(param))) |
|
|
|
|
|
|
|
|
|
|
|
def check_gnn_get_all_nodes(method): |
|
|
|
"""A wrapper that wrap a parameter checker to the GNN `get_all_nodes` function.""" |
|
|
|
|
|
|
|
@wraps(method) |
|
|
|
def new_method(*args, **kwargs): |
|
|
|
param_dict = make_param_dict(method, args, kwargs) |
|
|
|
@@ -1103,7 +1138,7 @@ def check_gnn_get_all_neighbors(method): |
|
|
|
param_dict = make_param_dict(method, args, kwargs) |
|
|
|
|
|
|
|
# check node_list; required argument |
|
|
|
check_list_or_ndarray(param_dict.get("node_list"), 'node_list') |
|
|
|
check_gnn_list_or_ndarray(param_dict.get("node_list"), 'node_list') |
|
|
|
|
|
|
|
# check neighbor_type; required argument |
|
|
|
check_type(param_dict.get("neighbor_type"), 'neighbor_type', int) |
|
|
|
@@ -1113,15 +1148,16 @@ def check_gnn_get_all_neighbors(method): |
|
|
|
return new_method |
|
|
|
|
|
|
|
|
|
|
|
def check_aligned_list(param, param_name): |
|
|
|
def check_aligned_list(param, param_name, membor_type): |
|
|
|
"""Check whether the structure of each member of the list is the same.""" |
|
|
|
|
|
|
|
if not isinstance(param, list): |
|
|
|
raise TypeError("Parameter {0} is not a list".format(param_name)) |
|
|
|
membor_have_list = None |
|
|
|
list_len = None |
|
|
|
for membor in param: |
|
|
|
if isinstance(membor, list): |
|
|
|
check_aligned_list(membor, param_name) |
|
|
|
check_aligned_list(membor, param_name, membor_type) |
|
|
|
if membor_have_list not in (None, True): |
|
|
|
raise TypeError("The type of each member of the parameter {0} is inconsistent".format( |
|
|
|
param_name)) |
|
|
|
@@ -1131,6 +1167,9 @@ def check_aligned_list(param, param_name): |
|
|
|
membor_have_list = True |
|
|
|
list_len = len(membor) |
|
|
|
else: |
|
|
|
if not isinstance(membor, membor_type): |
|
|
|
raise TypeError("Each membor in {0} should be of type int. Got {1}.".format( |
|
|
|
param_name, type(membor))) |
|
|
|
if membor_have_list not in (None, False): |
|
|
|
raise TypeError("The type of each member of the parameter {0} is inconsistent".format( |
|
|
|
param_name)) |
|
|
|
@@ -1139,18 +1178,26 @@ def check_aligned_list(param, param_name): |
|
|
|
|
|
|
|
def check_gnn_get_node_feature(method): |
|
|
|
"""A wrapper that wrap a parameter checker to the GNN `get_node_feature` function.""" |
|
|
|
|
|
|
|
@wraps(method) |
|
|
|
def new_method(*args, **kwargs): |
|
|
|
param_dict = make_param_dict(method, args, kwargs) |
|
|
|
|
|
|
|
# check node_list; required argument |
|
|
|
node_list = param_dict.get("node_list") |
|
|
|
check_list_or_ndarray(node_list, 'node_list') |
|
|
|
if isinstance(node_list, list): |
|
|
|
check_aligned_list(node_list, 'node_list') |
|
|
|
check_aligned_list(node_list, 'node_list', int) |
|
|
|
elif isinstance(node_list, np.ndarray): |
|
|
|
if not node_list.dtype == np.int32: |
|
|
|
raise TypeError("Each membor in {0} should be of type int32. Got {1}.".format( |
|
|
|
node_list, node_list.dtype)) |
|
|
|
else: |
|
|
|
raise TypeError("Wrong input type for {0}, should be list or numpy.ndarray, got {1}".format( |
|
|
|
'node_list', type(node_list))) |
|
|
|
|
|
|
|
# check feature_types; required argument |
|
|
|
check_list_or_ndarray(param_dict.get("feature_types"), 'feature_types') |
|
|
|
check_gnn_list_or_ndarray(param_dict.get( |
|
|
|
"feature_types"), 'feature_types') |
|
|
|
|
|
|
|
return method(*args, **kwargs) |
|
|
|
|
|
|
|
|