2. fix test_graphdata_distributed.py failing randomlytags/v0.7.0-beta
| @@ -3217,12 +3217,14 @@ def _generator_worker_loop(dataset, idx_queue, result_queue, eoe, eof): | |||||
| while True: | while True: | ||||
| # Fetch index, block | # Fetch index, block | ||||
| try: | try: | ||||
| idx = idx_queue.get(timeout=10) | |||||
| # Index is generated very fast, so the timeout is very short | |||||
| idx = idx_queue.get(timeout=0.01) | |||||
| except KeyboardInterrupt: | except KeyboardInterrupt: | ||||
| raise Exception("Generator worker receives KeyboardInterrupt") | raise Exception("Generator worker receives KeyboardInterrupt") | ||||
| except queue.Empty: | except queue.Empty: | ||||
| if eof.is_set() or eoe.is_set(): | if eof.is_set() or eoe.is_set(): | ||||
| raise Exception("Generator worker receives queue.Empty") | |||||
| return | |||||
| # If eoe or eof is not set, continue to get data from idx_queue | |||||
| continue | continue | ||||
| if idx is None: | if idx is None: | ||||
| # When the queue is out of scope from master process, a None item can be fetched from the queue. | # When the queue is out of scope from master process, a None item can be fetched from the queue. | ||||
| @@ -3234,10 +3236,17 @@ def _generator_worker_loop(dataset, idx_queue, result_queue, eoe, eof): | |||||
| # Fetch data, any exception from __getitem__ will terminate worker and timeout master process | # Fetch data, any exception from __getitem__ will terminate worker and timeout master process | ||||
| result = dataset[idx] | result = dataset[idx] | ||||
| # Send data, block | # Send data, block | ||||
| try: | |||||
| result_queue.put(result) | |||||
| except KeyboardInterrupt: | |||||
| raise Exception("Generator worker receives KeyboardInterrupt") | |||||
| while True: | |||||
| try: | |||||
| result_queue.put(result, timeout=5) | |||||
| except KeyboardInterrupt: | |||||
| raise Exception("Generator worker receives KeyboardInterrupt") | |||||
| except queue.Full: | |||||
| if eof.is_set(): | |||||
| return | |||||
| # If eof is not set, continue to put data to result_queue | |||||
| continue | |||||
| break | |||||
| del result, idx | del result, idx | ||||
| if eoe.is_set() and idx_queue.empty(): | if eoe.is_set() and idx_queue.empty(): | ||||
| return | return | ||||
| @@ -929,10 +929,10 @@ def check_split(method): | |||||
| def check_hostname(hostname): | def check_hostname(hostname): | ||||
| if len(hostname) > 255: | |||||
| if not hostname or len(hostname) > 255: | |||||
| return False | return False | ||||
| if hostname[-1] == ".": | if hostname[-1] == ".": | ||||
| hostname = hostname[:-1] # strip exactly one dot from the right, if present | |||||
| hostname = hostname[:-1] # strip exactly one dot from the right, if present | |||||
| allowed = re.compile("(?!-)[A-Z\\d-]{1,63}(?<!-)$", re.IGNORECASE) | allowed = re.compile("(?!-)[A-Z\\d-]{1,63}(?<!-)$", re.IGNORECASE) | ||||
| return all(allowed.match(x) for x in hostname.split(".")) | return all(allowed.match(x) for x in hostname.split(".")) | ||||
| @@ -952,7 +952,7 @@ def check_gnn_graphdata(method): | |||||
| raise ValueError("The hostname is illegal") | raise ValueError("The hostname is illegal") | ||||
| type_check(working_mode, (str,), "working_mode") | type_check(working_mode, (str,), "working_mode") | ||||
| if working_mode not in {'local', 'client', 'server'}: | if working_mode not in {'local', 'client', 'server'}: | ||||
| raise ValueError("Invalid working mode") | |||||
| raise ValueError("Invalid working mode, please enter 'local', 'client' or 'server'") | |||||
| type_check(port, (int,), "port") | type_check(port, (int,), "port") | ||||
| check_value(port, (1024, 65535), "port") | check_value(port, (1024, 65535), "port") | ||||
| type_check(num_client, (int,), "num_client") | type_check(num_client, (int,), "num_client") | ||||
| @@ -23,12 +23,12 @@ from mindspore import log as logger | |||||
| DATASET_FILE = "../data/mindrecord/testGraphData/testdata" | DATASET_FILE = "../data/mindrecord/testGraphData/testdata" | ||||
| def graphdata_startserver(): | |||||
| def graphdata_startserver(server_port): | |||||
| """ | """ | ||||
| start graphdata server | start graphdata server | ||||
| """ | """ | ||||
| logger.info('test start server.\n') | logger.info('test start server.\n') | ||||
| ds.GraphData(DATASET_FILE, 1, 'server') | |||||
| ds.GraphData(DATASET_FILE, 1, 'server', port=server_port) | |||||
| class RandomBatchedSampler(ds.Sampler): | class RandomBatchedSampler(ds.Sampler): | ||||
| @@ -83,11 +83,13 @@ def test_graphdata_distributed(): | |||||
| """ | """ | ||||
| logger.info('test distributed.\n') | logger.info('test distributed.\n') | ||||
| p1 = Process(target=graphdata_startserver) | |||||
| server_port = random.randint(10000, 60000) | |||||
| p1 = Process(target=graphdata_startserver, args=(server_port,)) | |||||
| p1.start() | p1.start() | ||||
| time.sleep(2) | time.sleep(2) | ||||
| g = ds.GraphData(DATASET_FILE, 1, 'client') | |||||
| g = ds.GraphData(DATASET_FILE, 1, 'client', port=server_port) | |||||
| nodes = g.get_all_nodes(1) | nodes = g.get_all_nodes(1) | ||||
| assert nodes.tolist() == [101, 102, 103, 104, 105, 106, 107, 108, 109, 110] | assert nodes.tolist() == [101, 102, 103, 104, 105, 106, 107, 108, 109, 110] | ||||
| row_tensor = g.get_node_feature(nodes.tolist(), [1, 2, 3]) | row_tensor = g.get_node_feature(nodes.tolist(), [1, 2, 3]) | ||||