| @@ -36,6 +36,7 @@ class Methods: | |||
| self.dict_barrier_counter = defaultdict(int) | |||
| self.dict_barrier_event = defaultdict(threading.Event) | |||
| self.user_dict = defaultdict(partial(Future, False)) | |||
| self.bcast_dict = {} | |||
| def connect(self): | |||
| """Method for checking connection success.""" | |||
| @@ -127,6 +128,23 @@ class Methods: | |||
| future = self.user_dict[key] | |||
| return future.get() | |||
| def bcast_val(self, val, key, size): | |||
| with self.lock: | |||
| if key not in self.bcast_dict: | |||
| self.bcast_dict[key] = [Future(False), size] | |||
| arr = self.bcast_dict[key] | |||
| if val is not None: | |||
| arr[0].set(val) | |||
| val = None | |||
| else: | |||
| val = arr[0].get() | |||
| with self.lock: | |||
| cnt = arr[1] - 1 | |||
| arr[1] = cnt | |||
| if cnt == 0: | |||
| del self.bcast_dict[key] | |||
| return val | |||
| class ThreadXMLRPCServer(ThreadingMixIn, SimpleXMLRPCServer): | |||
| pass | |||
| @@ -142,7 +160,9 @@ def _start_server(py_server_port, queue): | |||
| """ | |||
| try: | |||
| mm_server_port = create_mm_server("0.0.0.0", 0) | |||
| server = ThreadXMLRPCServer(("0.0.0.0", py_server_port), logRequests=False) | |||
| server = ThreadXMLRPCServer( | |||
| ("0.0.0.0", py_server_port), logRequests=False, allow_none=True | |||
| ) | |||
| server.register_instance(Methods(mm_server_port)) | |||
| _, py_server_port = server.server_address | |||
| queue.put((py_server_port, mm_server_port)) | |||
| @@ -185,13 +205,14 @@ class Client: | |||
| self.master_ip = master_ip | |||
| self.port = port | |||
| self.connect() | |||
| self.bcast_dict = defaultdict(lambda: 0) | |||
| def connect(self): | |||
| """Check connection success.""" | |||
| while True: | |||
| try: | |||
| self.proxy = ServerProxy( | |||
| "http://{}:{}".format(self.master_ip, self.port) | |||
| "http://{}:{}".format(self.master_ip, self.port), allow_none=True | |||
| ) | |||
| if self.proxy.connect(): | |||
| break | |||
| @@ -247,22 +268,17 @@ class Client: | |||
| def user_set(self, key, val): | |||
| """Set user defined key-value pairs across processes.""" | |||
| self.proxy.user_set(key, val) | |||
| return self.proxy.user_set(key, val) | |||
| def user_get(self, key): | |||
| """Get user defined key-value pairs across processes.""" | |||
| return self.proxy.user_get(key) | |||
| def bcast_val(self, val, key, size): | |||
| if val is not None: | |||
| self.user_set(key + "_sync", val) | |||
| self.group_barrier(key, size) | |||
| self.group_barrier(key, size) | |||
| else: | |||
| self.group_barrier(key, size) | |||
| val = self.user_get(key + "_sync") | |||
| self.group_barrier(key, size) | |||
| return val | |||
| idx = self.bcast_dict[key] + 1 | |||
| self.bcast_dict[key] = idx | |||
| key = key + "_bcast_" + str(idx) | |||
| return self.proxy.bcast_val(val, key, size) | |||
| def main(port=0, verbose=True): | |||