# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import multiprocessing as mp
import threading
import time
from collections import defaultdict
from functools import partial
from socketserver import ThreadingMixIn
from xmlrpc.client import ServerProxy
from xmlrpc.server import SimpleXMLRPCServer
from ..core._imperative_rt.utils import create_mm_server
from ..utils.future import Future
class Methods:
    r"""Distributed Server Method.

    Used for exchange information between distributed nodes. An instance is
    registered on an XML-RPC server, so each public method is a remote call.
    Per-key rendezvous is built on ``Future`` objects that ``defaultdict``
    creates lazily on first access; presumably ``Future.get`` blocks until the
    matching ``Future.set`` (see ``..utils.future`` to confirm).

    Args:
        mm_server_port: multiple machine rpc server port.
    """

    def __init__(self, mm_server_port):
        future_true = partial(Future, True)
        future_false = partial(Future, False)

        # Serialises access to the lookup tables below.
        self.lock = threading.Lock()
        self.mm_server_port = mm_server_port
        self.dict_is_grad = defaultdict(future_true)
        self.dict_remote_tracer = defaultdict(future_true)
        self.dict_pack_list = defaultdict(future_false)
        self.dict_barrier_counter = defaultdict(int)
        self.dict_barrier_event = defaultdict(threading.Event)
        self.user_dict = defaultdict(future_false)
        # key -> [Future, number of callers still expected]
        self.bcast_dict = {}

    def _locked_get(self, table, key):
        # Look up (and, for defaultdicts, create) ``table[key]`` under the lock.
        with self.lock:
            return table[key]

    def _take(self, table, key):
        # Wait for the future stored under ``key``, then drop the entry.
        value = self._locked_get(table, key).get()
        with self.lock:
            del table[key]
        return value

    def connect(self):
        r"""Method for checking connection success; always answers ``True``."""
        return True

    def get_mm_server_port(self):
        r"""Get multiple machine rpc server port."""
        return self.mm_server_port

    def set_is_grad(self, key, is_grad):
        r"""Mark send/recv need gradients by key.

        Args:
            key: key to match send/recv op.
            is_grad: whether this op need grad.
        """
        self._locked_get(self.dict_is_grad, key).set(is_grad)
        return True

    def check_is_grad(self, key):
        r"""Check whether send/recv need gradients.

        Waits on the future for ``key`` and removes the entry afterwards.

        Args:
            key: key to match send/recv op.
        """
        return self._take(self.dict_is_grad, key)

    def set_remote_tracer(self, key, tracer_set):
        r"""Set tracer dict for tracing send/recv op.

        Args:
            key: key to match send/recv op.
            tracer_set: valid tracer set.
        """
        self._locked_get(self.dict_remote_tracer, key).set(tracer_set)
        return True

    def check_remote_tracer(self, key):
        r"""Get tracer dict for send/recv op.

        Waits on the future for ``key`` and removes the entry afterwards.

        Args:
            key: key to match send/recv op.
        """
        return self._take(self.dict_remote_tracer, key)

    def group_barrier(self, key, size):
        r"""A barrier wait for all group member.

        The ``size``-th caller with ``key`` clears the barrier state and wakes
        every earlier caller; earlier callers block until then.

        Args:
            key: group key to match each other.
            size: group size.
        """
        with self.lock:
            self.dict_barrier_counter[key] += 1
            arrived = self.dict_barrier_counter[key]
            event = self.dict_barrier_event[key]
            is_last = arrived == size
            if is_last:
                # Drop the state so the next barrier on this key starts fresh.
                del self.dict_barrier_counter[key]
                del self.dict_barrier_event[key]
        # Block outside the lock so that later members can still register.
        if is_last:
            event.set()
        else:
            event.wait()
        return True

    def user_set(self, key, val):
        r"""Set user defined key-value pairs across processes."""
        self._locked_get(self.user_dict, key).set(val)
        return True

    def user_get(self, key):
        r"""Get user defined key-value pairs across processes."""
        return self._locked_get(self.user_dict, key).get()

    def bcast_val(self, val, key, size):
        r"""Rendezvous ``size`` callers on ``key`` to share one value.

        The caller passing a non-``None`` ``val`` publishes it and receives
        ``None``; callers passing ``None`` wait for the published value and
        receive it. The entry is removed after ``size`` calls.
        """
        with self.lock:
            entry = self.bcast_dict.get(key)
            if entry is None:
                entry = [Future(False), size]
                self.bcast_dict[key] = entry
        if val is not None:
            entry[0].set(val)
            result = None
        else:
            result = entry[0].get()
        with self.lock:
            entry[1] -= 1
            if entry[1] == 0:
                del self.bcast_dict[key]
        return result

    def _del(self, key):
        with self.lock:
            del self.user_dict[key]

    def user_pop(self, key):
        r"""Get the user value for ``key`` and then delete it.

        Each step takes the lock separately. If two callers pop the same key
        concurrently, both may read the value and the second delete raises
        ``KeyError``.
        """
        value = self.user_get(key)
        self._del(key)
        return value
class ThreadXMLRPCServer(ThreadingMixIn, SimpleXMLRPCServer):
    r"""XML-RPC server that handles each request in its own thread.

    Some registered methods block until other requests arrive (e.g.
    ``Methods.group_barrier`` waits on an event that a later request sets).
    Serving requests on a single thread would therefore deadlock.
    """

    # Handler threads may block indefinitely (barriers, futures). Making them
    # daemons keeps them from holding the interpreter open at exit. Disabling
    # block_on_close stops ThreadingMixIn from retaining every request thread
    # so that server_close() can join them; that join could hang forever.
    daemon_threads = True
    block_on_close = False
def _start_server(py_server_port, queue):
r"""Start python distributed server and multiple machine server.
Args:
py_server_port: python server port.
mm_server_port: multiple machine server port.
queue: server port will put in this queue, puts exception when process fails.
"""
try:
mm_server_port = create_mm_server("0.0.0.0", 0)
server = ThreadXMLRPCServer(
("0.0.0.0", py_server_port), logRequests=False, allow_none=True
)
server.register_instance(Methods(mm_server_port))
_, py_server_port = server.server_address
queue.put((py_server_port, mm_server_port))
server.serve_forever()
except Exception as e:
queue.put(e)
class Server:
    r"""Distributed Server for distributed training.

    Should be running at master node. The python XML-RPC server and the
    multiple machine server are started in a daemon child process that runs
    ``_start_server``. The constructor blocks until that process reports its
    bound ports, which are then stored as ``py_server_port`` and
    ``mm_server_port``.

    Args:
        port: python server port; ``0`` lets the OS choose one.

    Raises:
        Exception: the exception reported by the child if startup failed.
        RuntimeError: if the child process exits without reporting anything.
    """

    def __init__(self, port=0):
        q = mp.Queue()
        self.proc = mp.Process(target=_start_server, args=(port, q), daemon=True)
        self.proc.start()
        ret = self._wait_for_startup(q)
        if isinstance(ret, Exception):
            raise ret
        self.py_server_port, self.mm_server_port = ret

    def _wait_for_startup(self, q, poll_interval=1.0):
        # Wait for the child's single startup message, without hanging forever
        # if the child dies first (e.g. a crash inside create_mm_server).
        from queue import Empty

        while self.proc.is_alive():
            try:
                return q.get(timeout=poll_interval)
            except Empty:
                pass
        # The child has exited but may have reported just before exiting, so
        # try the queue once more before giving up.
        try:
            return q.get(timeout=poll_interval)
        except Empty:
            raise RuntimeError(
                "distributed server process exited (exitcode={}) before "
                "reporting its ports".format(self.proc.exitcode)
            ) from None

    def __del__(self):
        # ``proc`` is missing (or was never started) if __init__ failed early.
        proc = getattr(self, "proc", None)
        if proc is not None and proc.pid is not None:
            proc.terminate()
class Client:
    r"""Distributed Client for distributed training.

    Forwards each call over XML-RPC to the server on the master node. The
    constructor blocks until that server is reachable.

    Args:
        master_ip: ip address of master node.
        port: port of server at master node.
    """

    def __init__(self, master_ip, port):
        self.master_ip = master_ip
        self.port = port
        self.connect()
        # Number of bcast_val calls issued so far per key (see bcast_val).
        self.bcast_dict = defaultdict(int)

    def connect(self):
        r"""Block until the server answers ``connect``, retrying once per second.

        The resulting proxy is stored in ``self.proxy``.
        """
        address = "http://{}:{}".format(self.master_ip, self.port)
        while True:
            try:
                self.proxy = ServerProxy(address, allow_none=True)
                if self.proxy.connect():
                    return
            except Exception:
                # Server not reachable yet (e.g. still starting up): retry.
                # KeyboardInterrupt/SystemExit are deliberately not caught.
                pass
            time.sleep(1)

    def get_mm_server_port(self):
        r"""Get multiple machine server port."""
        return self.proxy.get_mm_server_port()

    def set_is_grad(self, key, is_grad):
        r"""Mark send/recv need gradients by key.

        Args:
            key: key to match send/recv op.
            is_grad: whether this op need grad.
        """
        self.proxy.set_is_grad(key, is_grad)

    def check_is_grad(self, key):
        r"""Check whether send/recv need gradients.

        Args:
            key: key to match send/recv op.
        """
        return self.proxy.check_is_grad(key)

    def set_remote_tracer(self, key, tracer_set):
        r"""Set tracer dict for tracing send/recv op.

        Args:
            key: key to match send/recv op.
            tracer_set: valid tracer set.
        """
        self.proxy.set_remote_tracer(key, tracer_set)

    def check_remote_tracer(self, key):
        r"""Get tracer dict for send/recv op.

        Args:
            key: key to match send/recv op.
        """
        return self.proxy.check_remote_tracer(key)

    def group_barrier(self, key, size):
        r"""A barrier wait for all group member.

        Args:
            key: group key to match each other.
            size: group size.
        """
        self.proxy.group_barrier(key, size)

    def user_set(self, key, val):
        r"""Set user defined key-value pairs across processes."""
        return self.proxy.user_set(key, val)

    def user_get(self, key):
        r"""Get user defined key-value pairs across processes."""
        return self.proxy.user_get(key)

    def user_pop(self, key):
        r"""Get user defined key-value pairs and delete the resources when the get is done."""
        return self.proxy.user_pop(key)

    def bcast_val(self, val, key, size):
        r"""Share one value among ``size`` callers using ``key``.

        The caller passing a non-``None`` ``val`` publishes it and receives
        ``None``. Callers passing ``None`` wait for the published value and
        receive it. Each call on this client with the same ``key`` uses a fresh
        server-side key ``"<key>_bcast_<n>"``. This assumes every participant
        issues its broadcasts on ``key`` in the same order.

        Args:
            val: value to publish, or ``None`` to receive.
            key: broadcast key shared by the participants (formatted with str()).
            size: number of participants.
        """
        self.bcast_dict[key] += 1
        unique_key = "{}_bcast_{}".format(key, self.bcast_dict[key])
        return self.proxy.bcast_val(val, unique_key, size)
def main(port=0, verbose=True):
    r"""Run the distributed server in the foreground until interrupted.

    Args:
        port: python server port; ``0`` lets the OS choose one.
        verbose: whether the XML-RPC server logs each request.
    """
    mm_server_port = create_mm_server("0.0.0.0", 0)
    # allow_none matches _start_server. Some Methods return None (e.g.
    # bcast_val for the publishing caller), and xmlrpc cannot marshal None
    # unless allow_none is enabled.
    server = ThreadXMLRPCServer(
        ("0.0.0.0", port), logRequests=verbose, allow_none=True
    )
    server.register_instance(Methods(mm_server_port))
    _, port = server.server_address
    print("serving on port", port)
    server.serve_forever()
if __name__ == "__main__":
    import argparse

    def _str2bool(text):
        # ``type=bool`` would map any non-empty string (even "False") to True.
        lowered = text.strip().lower()
        if lowered in ("1", "true", "t", "yes", "y", "on"):
            return True
        if lowered in ("", "0", "false", "f", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError(
            "expected a boolean value, got {!r}".format(text)
        )

    ap = argparse.ArgumentParser()
    ap.add_argument("-p", "--port", type=int, default=0)
    ap.add_argument("-v", "--verbose", type=_str2bool, default=True)
    args = ap.parse_args()
    main(port=args.port, verbose=args.verbose)