| @@ -36,6 +36,7 @@ class Methods: | |||||
| self.dict_barrier_counter = defaultdict(int) | self.dict_barrier_counter = defaultdict(int) | ||||
| self.dict_barrier_event = defaultdict(threading.Event) | self.dict_barrier_event = defaultdict(threading.Event) | ||||
| self.user_dict = defaultdict(partial(Future, False)) | self.user_dict = defaultdict(partial(Future, False)) | ||||
| self.bcast_dict = {} | |||||
| def connect(self): | def connect(self): | ||||
| """Method for checking connection success.""" | """Method for checking connection success.""" | ||||
| @@ -127,6 +128,23 @@ class Methods: | |||||
| future = self.user_dict[key] | future = self.user_dict[key] | ||||
| return future.get() | return future.get() | ||||
| def bcast_val(self, val, key, size): | |||||
| with self.lock: | |||||
| if key not in self.bcast_dict: | |||||
| self.bcast_dict[key] = [Future(False), size] | |||||
| arr = self.bcast_dict[key] | |||||
| if val is not None: | |||||
| arr[0].set(val) | |||||
| val = None | |||||
| else: | |||||
| val = arr[0].get() | |||||
| with self.lock: | |||||
| cnt = arr[1] - 1 | |||||
| arr[1] = cnt | |||||
| if cnt == 0: | |||||
| del self.bcast_dict[key] | |||||
| return val | |||||
| class ThreadXMLRPCServer(ThreadingMixIn, SimpleXMLRPCServer): | class ThreadXMLRPCServer(ThreadingMixIn, SimpleXMLRPCServer): | ||||
| pass | pass | ||||
| @@ -142,7 +160,9 @@ def _start_server(py_server_port, queue): | |||||
| """ | """ | ||||
| try: | try: | ||||
| mm_server_port = create_mm_server("0.0.0.0", 0) | mm_server_port = create_mm_server("0.0.0.0", 0) | ||||
| server = ThreadXMLRPCServer(("0.0.0.0", py_server_port), logRequests=False) | |||||
| server = ThreadXMLRPCServer( | |||||
| ("0.0.0.0", py_server_port), logRequests=False, allow_none=True | |||||
| ) | |||||
| server.register_instance(Methods(mm_server_port)) | server.register_instance(Methods(mm_server_port)) | ||||
| _, py_server_port = server.server_address | _, py_server_port = server.server_address | ||||
| queue.put((py_server_port, mm_server_port)) | queue.put((py_server_port, mm_server_port)) | ||||
| @@ -185,13 +205,14 @@ class Client: | |||||
| self.master_ip = master_ip | self.master_ip = master_ip | ||||
| self.port = port | self.port = port | ||||
| self.connect() | self.connect() | ||||
| self.bcast_dict = defaultdict(lambda: 0) | |||||
| def connect(self): | def connect(self): | ||||
| """Check connection success.""" | """Check connection success.""" | ||||
| while True: | while True: | ||||
| try: | try: | ||||
| self.proxy = ServerProxy( | self.proxy = ServerProxy( | ||||
| "http://{}:{}".format(self.master_ip, self.port) | |||||
| "http://{}:{}".format(self.master_ip, self.port), allow_none=True | |||||
| ) | ) | ||||
| if self.proxy.connect(): | if self.proxy.connect(): | ||||
| break | break | ||||
| @@ -247,22 +268,17 @@ class Client: | |||||
| def user_set(self, key, val): | def user_set(self, key, val): | ||||
| """Set user defined key-value pairs across processes.""" | """Set user defined key-value pairs across processes.""" | ||||
| self.proxy.user_set(key, val) | |||||
| return self.proxy.user_set(key, val) | |||||
| def user_get(self, key): | def user_get(self, key): | ||||
| """Get user defined key-value pairs across processes.""" | """Get user defined key-value pairs across processes.""" | ||||
| return self.proxy.user_get(key) | return self.proxy.user_get(key) | ||||
| def bcast_val(self, val, key, size): | def bcast_val(self, val, key, size): | ||||
| if val is not None: | |||||
| self.user_set(key + "_sync", val) | |||||
| self.group_barrier(key, size) | |||||
| self.group_barrier(key, size) | |||||
| else: | |||||
| self.group_barrier(key, size) | |||||
| val = self.user_get(key + "_sync") | |||||
| self.group_barrier(key, size) | |||||
| return val | |||||
| idx = self.bcast_dict[key] + 1 | |||||
| self.bcast_dict[key] = idx | |||||
| key = key + "_bcast_" + str(idx) | |||||
| return self.proxy.bcast_val(val, key, size) | |||||
| def main(port=0, verbose=True): | def main(port=0, verbose=True): | ||||