# -*- coding: utf-8 -*-
import multiprocessing as mp
import threading
import time
from collections import defaultdict
from functools import partial
from socketserver import ThreadingMixIn
from xmlrpc.client import ServerProxy
from xmlrpc.server import SimpleXMLRPCServer
from ..core._imperative_rt.utils import create_mm_server
from ..utils.future import Future
class Methods:
r"""Distributed Server Method.
Used for exchange information between distributed nodes.
Args:
mm_server_port: multiple machine rpc server port.
"""
def __init__(self, mm_server_port):
self.lock = threading.Lock()
self.mm_server_port = mm_server_port
self.dict_is_grad = defaultdict(partial(Future, True))
self.dict_remote_tracer = defaultdict(partial(Future, True))
self.dict_pack_list = defaultdict(partial(Future, False))
self.dict_barrier_counter = defaultdict(int)
self.dict_barrier_event = defaultdict(threading.Event)
self.user_dict = defaultdict(partial(Future, False))
self.bcast_dict = {}
def connect(self):
r"""Method for checking connection success."""
return True
def get_mm_server_port(self):
r"""Get multiple machine rpc server port."""
return self.mm_server_port
def set_is_grad(self, key, is_grad):
r"""Mark send/recv need gradiants by key.
Args:
key: key to match send/recv op.
is_grad: whether this op need grad.
"""
with self.lock:
future = self.dict_is_grad[key]
future.set(is_grad)
return True
def check_is_grad(self, key):
r"""Check whether send/recv need gradiants.
Args:
key: key to match send/recv op.
"""
with self.lock:
future = self.dict_is_grad[key]
ret = future.get()
with self.lock:
del self.dict_is_grad[key]
return ret
def set_remote_tracer(self, key, tracer_set):
r"""Set tracer dict for tracing send/recv op.
Args:
key: key to match send/recv op.
tracer_set: valid tracer set.
"""
with self.lock:
future = self.dict_remote_tracer[key]
future.set(tracer_set)
return True
def check_remote_tracer(self, key):
r"""Get tracer dict for send/recv op.
Args:
key: key to match send/recv op.
"""
with self.lock:
future = self.dict_remote_tracer[key]
ret = future.get()
with self.lock:
del self.dict_remote_tracer[key]
return ret
def group_barrier(self, key, size):
r"""A barrier wait for all group member.
Args:
key: group key to match each other.
size: group size.
"""
with self.lock:
self.dict_barrier_counter[key] += 1
counter = self.dict_barrier_counter[key]
event = self.dict_barrier_event[key]
if counter == size:
del self.dict_barrier_counter[key]
del self.dict_barrier_event[key]
event.set()
else:
event.wait()
return True
def user_set(self, key, val):
r"""Set user defined key-value pairs across processes."""
with self.lock:
future = self.user_dict[key]
future.set(val)
return True
def user_get(self, key):
r"""Get user defined key-value pairs across processes."""
with self.lock:
future = self.user_dict[key]
return future.get()
def bcast_val(self, val, key, size):
with self.lock:
if key not in self.bcast_dict:
self.bcast_dict[key] = [Future(False), size]
arr = self.bcast_dict[key]
if val is not None:
arr[0].set(val)
val = None
else:
val = arr[0].get()
with self.lock:
cnt = arr[1] - 1
arr[1] = cnt
if cnt == 0:
del self.bcast_dict[key]
return val
def _del(self, key):
with self.lock:
del self.user_dict[key]
# thread safe function
def user_pop(self, key):
ret = self.user_get(key)
self._del(key)
return ret
class ThreadXMLRPCServer(ThreadingMixIn, SimpleXMLRPCServer):
pass
def _start_server(py_server_port, queue):
r"""Start python distributed server and multiple machine server.
Args:
py_server_port: python server port.
mm_server_port: multiple machine server port.
queue: server port will put in this queue, puts exception when process fails.
"""
try:
mm_server_port = create_mm_server("0.0.0.0", 0)
server = ThreadXMLRPCServer(
("0.0.0.0", py_server_port), logRequests=False, allow_none=True
)
server.register_instance(Methods(mm_server_port))
_, py_server_port = server.server_address
queue.put((py_server_port, mm_server_port))
server.serve_forever()
except Exception as e:
queue.put(e)
[docs]class Server:
r"""Distributed Server for distributed training.
Should be running at master node.
Args:
port: python server port.
"""
def __init__(self, port=0):
q = mp.Queue()
self.proc = mp.Process(target=_start_server, args=(port, q), daemon=True)
self.proc.start()
ret = q.get()
if isinstance(ret, Exception):
raise ret
else:
self.py_server_port, self.mm_server_port = ret
def __del__(self):
self.proc.terminate()
class Client:
r"""Distributed Client for distributed training.
Args:
master_ip: ip address of master node.
port: port of server at master node.
"""
def __init__(self, master_ip, port):
self.master_ip = master_ip
self.port = port
self.connect()
self.bcast_dict = defaultdict(lambda: 0)
def connect(self):
r"""Check connection success."""
while True:
try:
self.proxy = ServerProxy(
"http://{}:{}".format(self.master_ip, self.port), allow_none=True
)
if self.proxy.connect():
break
except:
time.sleep(1)
def get_mm_server_port(self):
r"""Get multiple machine server port."""
while True:
try:
return self.proxy.get_mm_server_port()
except:
time.sleep(0.5)
def set_is_grad(self, key, is_grad):
r"""Mark send/recv need gradiants by key.
Args:
key: key to match send/recv op.
is_grad: whether this op need grad.
"""
self.proxy.set_is_grad(key, is_grad)
def check_is_grad(self, key):
r"""Check whether send/recv need gradiants.
Args:
key: key to match send/recv op.
"""
return self.proxy.check_is_grad(key)
def set_remote_tracer(self, key, tracer_set):
r"""Set tracer dict for tracing send/recv op.
Args:
key: key to match send/recv op.
tracer_set: valid tracer set.
"""
self.proxy.set_remote_tracer(key, tracer_set)
def check_remote_tracer(self, key):
r"""Get tracer dict for send/recv op.
Args:
key: key to match send/recv op.
"""
return self.proxy.check_remote_tracer(key)
def group_barrier(self, key, size):
r"""A barrier wait for all group member.
Args:
key: group key to match each other.
size: group size.
"""
# FIXME: group_barrier is not idempotent
while True:
try:
self.proxy.group_barrier(key, size)
return
except:
time.sleep(0.5)
def user_set(self, key, val):
r"""Set user defined key-value pairs across processes."""
return self.proxy.user_set(key, val)
def user_get(self, key):
r"""Get user defined key-value pairs across processes."""
return self.proxy.user_get(key)
def user_pop(self, key):
r"""Get user defined key-value pairs and delete the resources when the get is done"""
return self.proxy.user_pop(key)
def bcast_val(self, val, key, size):
idx = self.bcast_dict[key] + 1
self.bcast_dict[key] = idx
key = key + "_bcast_" + str(idx)
return self.proxy.bcast_val(val, key, size)
def main(port=0, verbose=True):
mm_server_port = create_mm_server("0.0.0.0", 0)
server = ThreadXMLRPCServer(("0.0.0.0", port), logRequests=verbose)
server.register_instance(Methods(mm_server_port))
_, port = server.server_address
print("serving on port", port)
server.serve_forever()
if __name__ == "__main__":
import argparse
ap = argparse.ArgumentParser()
ap.add_argument("-p", "--port", type=int, default=0)
ap.add_argument("-v", "--verbose", type=bool, default=True)
args = ap.parse_args()
main(port=args.port, verbose=args.verbose)