""" library to take autodiff and execute a computation graph """ from __future__ import absolute_import from .BatchNorm import Batch_NormalizationOp import numpy as np from scipy.sparse import spmatrix, coo_matrix from .. import ndarray from .._base import DNNL_LIB from ..cpu_links import array_set as cpu_array_set from .Variable import PlaceholderOp # add for optimizer from ..dataloader import DataloaderOp, GNNDataLoaderOp from .AllReduceCommunicate import AllReduceCommunicateOp from .ParameterServerCommunicate import ParameterServerCommunicateOp, ParameterServerSparsePullOp, parameterServerSparsePull_op from .AddElewise import add_op from .DataTransfer import DataH2DOp, DataD2HOp, DataD2HSparseOp from .EmbeddingLookUp import EmbeddingLookUp, EmbeddingLookUp_Gradient from ..optimizer import OptimizerOp from . import OnesLike from ..stream import create_stream_handle, Event from ..context import get_current_context, get_launch_config_by_traverse_nodes, assign_context_by_traverse_nodes, DeviceGroup from .PipelineSend import PipelineSendOp from .PipelineReceive import PipelineReceiveOp from .Dropout import DropoutOp from .LayerNorm import Layer_NormalizationOp from operator import add from functools import reduce import ctypes import os from time import time def path_to_lib(name): curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) lib_path = os.path.join(curr_path, '../../../build/lib/') return os.path.join(lib_path, name) def wrapped_mpi_nccl_init(init_nccl=True, devices=None): from ..communicator.mpi_nccl_comm import mpi_communicator global mpi_comm global nccl_comm if 'mpi_comm' not in globals(): mpi_comm = mpi_communicator(devices=devices) if 'nccl_comm' not in globals(): nccl_comm = mpi_comm.ncclInit() if init_nccl else None return nccl_comm def new_group_comm(devices_context=None): assert 'mpi_comm' in globals() global mpi_comm if devices_context is None: comm = mpi_comm.ncclInit() else: comm = mpi_comm.ncclGroupInit(devices_context) return comm def get_nccl_communicate(): global nccl_comm return nccl_comm def get_worker_communicate(): global ps_comm return ps_comm def worker_init(): global ps_comm ll = ctypes.cdll.LoadLibrary ps_comm = ll(path_to_lib("libps.so")) ps_comm.Init() def worker_finish(): ps_comm.Finalize() def server_init(): global ps_comm ll = ctypes.cdll.LoadLibrary ps_comm = ll(path_to_lib("libps.so")) ps_comm.Init() ps_comm.StartServer() def server_finish(): ps_comm.Finalize() def scheduler_init(): global ps_comm ll = ctypes.cdll.LoadLibrary ps_comm = ll(path_to_lib("libps.so")) ps_comm.Init() def scheduler_finish(): ps_comm.Finalize() class HetuConfig(object): __slots__ = [ 'eval_node_list', 'train_name', 'val_name', 'context', 'seed', 'np_rand', 'comm_mode', 'node_strategy', 'context_launch', 'ps_comm', 'nccl_comm', 'local_rank', 'rank', 'nrank', 'p2p_stream', 'comp_stream', 'nccl_stream', 'h2d_stream', 'd2h_stream', 'h2d_ops', 'd2h_ops', 'ps_map', 'infer_ps_map', 'dataloader_ops', 'use_sparse_pull', 'cstable_policy', 'inference', 'enable_lazy', 'bsp', 'prefetch', 'cache_bound', 'log_path', 'my_eval_nodes', 'param_allreduce_group', 'placeholder_to_arr_map' ] def __init__( self, eval_node_list, train_name, val_name, ctx=None, seed=None, comm_mode=None, use_sparse_pull=True, cstable_policy=None, bsp=False, prefetch=True, enable_lazy=True, cache_bound=100, log_path=None, ): ''' context: default device context comm_mode: communication mode, should be one of the following None -> Single GPU PS -> Parameter Server AllRedeuce -> MPI AllReduce Hybrid -> Parameter Server for Sparse Parameter and MPI AllReduce for Dense Parameter ''' self.eval_node_list = eval_node_list self.train_name = train_name self.val_name = val_name # check context if ctx is None: ctx = get_current_context() assert ctx, 'Default context should be determined.' self.comm_mode = comm_mode self.node_strategy = {} local_gpu_devices = None context_launch = isinstance(ctx, DeviceGroup) self.context_launch = context_launch if context_launch: # with context usage launchMPI, launchPS, self.node_strategy, devices = get_launch_config_by_traverse_nodes( eval_node_list, ctx) local_gpu_devices = sorted( [dev.device_id for dev in devices if dev.local and ndarray.is_gpu_ctx(dev)]) if not launchMPI and not launchPS: self.comm_mode = None elif launchMPI and not launchPS: self.comm_mode = 'AllReduce' elif not launchMPI and launchPS: self.comm_mode = 'PS' else: self.comm_mode = 'Hybrid' # in pipeline or model parallel we have to initialize another p2p stream init_p2p_stream = len(devices) != len(ctx) # variables initialization self.seed = seed if seed else np.int64(time()) self.np_rand = np.random.RandomState(self.seed) # get attribute of communication mode self.ps_comm = None self.nccl_comm = None self.local_rank = None self.rank = None self.nrank = None ps_nrank = None if self.comm_mode == 'PS' or self.comm_mode == 'Hybrid': worker_init() self.ps_comm = get_worker_communicate() ps_rank = int(self.ps_comm.rank()) ps_nrank = int( os.environ['DMLC_NUM_WORKER']) if 'DMLC_NUM_WORKER' in os.environ else 1 if self.comm_mode == "Hybrid" or self.comm_mode == "AllReduce": self.nccl_comm = wrapped_mpi_nccl_init(devices=local_gpu_devices) elif context_launch: self.nccl_comm = wrapped_mpi_nccl_init( init_nccl=init_p2p_stream, devices=local_gpu_devices) if self.nccl_comm is not None: self.local_rank = self.nccl_comm.local_rank device_id = self.nccl_comm.dev_id self.rank = self.nccl_comm.rank self.nrank = self.nccl_comm.nrank if ps_nrank: assert ps_nrank == self.nrank elif self.comm_mode == 'PS': self.rank = ps_rank self.nrank = ps_nrank if context_launch: global mpi_comm self.local_rank = mpi_comm.local_rank device_id = mpi_comm.dev_id self.my_eval_nodes = eval_node_list self.p2p_stream = None self.param_allreduce_group = {} if context_launch: # comm_mode is None <=> only 1 model parallel instance self.context = ndarray.gpu(device_id) self.p2p_stream = create_stream_handle( self.context) if init_p2p_stream else None self.my_eval_nodes, trainable_params, has_send_recv = assign_context_by_traverse_nodes( eval_node_list, self.context, self.nccl_comm, self.p2p_stream) if (self.comm_mode == "Hybrid" or self.comm_mode == "AllReduce") and has_send_recv: # here we need to use group communicator to implement allreduce, # since not all processes use the same group groups = set([n.raw_ctx for n in trainable_params]) temp_group_comms = {} for group in groups: temp_group_comms[group] = new_group_comm(group) self.param_allreduce_group = { n: temp_group_comms[n.raw_ctx] for n in trainable_params} else: self.context = ctx on_gpu = ndarray.is_gpu_ctx(self.context) self.nccl_stream = None if self.comm_mode == "Hybrid" or self.comm_mode == "AllReduce": if on_gpu: self.nccl_stream = create_stream_handle(self.context) self.nccl_comm = get_nccl_communicate() # define streams self.comp_stream = create_stream_handle( self.context) if on_gpu else None self.h2d_stream = create_stream_handle( self.context) if on_gpu else None self.d2h_stream = create_stream_handle( self.context) if on_gpu else None self.use_sparse_pull = use_sparse_pull if self.comm_mode == 'PS' or self.comm_mode == "Hybrid" else False self.cstable_policy = cstable_policy if self.comm_mode == 'PS' or self.comm_mode == "Hybrid" else None self.prefetch = prefetch if self.comm_mode == 'PS' or self.comm_mode == 'Hybrid' else False if self.cstable_policy is not None: self.cstable_policy = self.cstable_policy.upper() self.use_sparse_pull = False self.h2d_ops = {} self.d2h_ops = {} self.ps_map = {} self.infer_ps_map = {} self.enable_lazy = False and enable_lazy # now we don't use lazy self.bsp = bsp self.cache_bound = int(cache_bound) self.log_path = log_path if log_path is not None and (self.comm_mode == 'PS' or self.comm_mode == "Hybrid"): assert os.path.isdir( log_path), 'Need to specify a work directory to save logs.' self.ps_comm.startRecord(ctypes.c_char_p(bytes(log_path, 'utf-8'))) self.placeholder_to_arr_map = dict() topo_sort_with_hook(self.my_eval_nodes, self) class Executor(object): """Executor computes values for given set of nodes in computation graph.""" def __init__(self, eval_node_dict, config=None, **kargs): """ Parameters ---------- eval_node_dict: dict of list of nodes whose values need to be computed. """ if not isinstance(eval_node_dict, dict): eval_node_dict = {'default': eval_node_dict} train_name, val_name = None, None for k, v in eval_node_dict.items(): if any([isinstance(node, OptimizerOp) for node in v]): # get the last subexecutor containing optimizer as train for ps op train_name = k else: # get the last subexecutor not containing optimizer as val for ps op val_name = k all_eval_nodes = list(set(reduce(add, eval_node_dict.values()))) if config is None: config = HetuConfig(eval_node_list=all_eval_nodes, train_name=train_name, val_name=val_name, **kargs) assert isinstance( config, HetuConfig), 'Config type %s invalid.' % str(type(config)) self.eval_node_dict = eval_node_dict self.config = config self.subexecutor = {k: SubExecutor( k, v, config) for k, v in eval_node_dict.items()} self.topo_order = find_topo_sort(config.my_eval_nodes) self.param_nodes = [node for node in self.topo_order if isinstance( node, PlaceholderOp) and node.trainable] self.comm_mode = self.config.comm_mode self.ps_comm = self.config.ps_comm self.local_rank = self.config.local_rank self.rank = self.config.rank def run(self, name='default', eval_node_list={}, feed_dict={}, convert_to_numpy_ret_vals=False): return self.subexecutor[name].run(eval_node_list, feed_dict, convert_to_numpy_ret_vals) @property def batch_num(self): assert len( self.subexecutor) == 1, 'Batch num should be used with only 1 subexecutor.' return list(self.subexecutor.values())[0].batch_num def get_batch_num(self, name='default'): return self.subexecutor[name].batch_num def save(self, file_path): assert os.path.isdir( file_path), 'Need to specify a work directory to save parameters.' if self.comm_mode in (None, 'AllReduce'): # when using allreduce, users need to specify the worker whose rank equals 0 to save for node in self.param_nodes: np.save(os.path.join(file_path, node.name + '.npy'), self.config.placeholder_to_arr_map[node].asnumpy()) else: self.ps_comm.BarrierWorker() if self.config.rank == 0: for node in self.param_nodes: if node.is_embed or self.comm_mode == 'PS': node.event.sync() nodeid = ctypes.c_int(node.id) self.ps_comm.SaveParam( nodeid, ctypes.c_char_p(bytes(file_path, 'utf-8'))) self.ps_comm.Wait(nodeid) else: np.save(os.path.join(file_path, node.name + '.npy'), self.config.placeholder_to_arr_map[node].asnumpy()) self.ps_comm.BarrierWorker() def load(self, file_path): assert os.path.isdir( file_path), 'Need to specify a work directory to load parameters.' if self.comm_mode in (None, 'AllReduce'): for node in self.param_nodes: self.config.placeholder_to_arr_map[node][:] = np.load( os.path.join(file_path, node.name + '.npy')) else: self.ps_comm.BarrierWorker() if self.config.rank == 0: for node in self.param_nodes: if node.is_embed or self.comm_mode == 'PS': node.event.sync() nodeid = ctypes.c_int(node.id) self.ps_comm.LoadParam( nodeid, ctypes.c_char_p(bytes(file_path, 'utf-8'))) node.event.update() self.ps_comm.BarrierWorker() for node in self.topo_order: if isinstance(node, PlaceholderOp) and node.trainable and not node.is_embed: if self.comm_mode == 'PS': node.event.sync() nodeid = ctypes.c_int(node.id) self.ps_comm.Pull( nodeid, self.config.ps_map[node].handle) node.event.update() else: self.config.placeholder_to_arr_map[node][:] = np.load( os.path.join(file_path, node.name + '.npy')) elif isinstance(node, EmbeddingLookUp) and self.config.prefetch: node.event.sync() nodeid = ctypes.c_int(node.inputs[0].id) self.ps_comm.SparsePull(nodeid, node.inputs[1].get_next_arr( self.name).handle, self.config.ps_map[node.inputs[0]].handle) node.event.update() self.ps_comm.BarrierWorker() def recordLoads(self): for node in self.config.ps_map: node.event.sync() self.ps_comm.getLoads() def __del__(self): if self.config.comp_stream is not None: self.config.comp_stream.sync() if self.config.h2d_stream is not None: self.config.h2d_stream.sync() if self.config.d2h_stream is not None: self.config.d2h_stream.sync() if self.config.nccl_stream is not None: self.config.nccl_stream.sync() for node in self.param_nodes: if node.event: node.event.sync() if self.comm_mode in ('PS', 'Hybrid'): worker_finish() class SubExecutor(object): def __init__(self, name, eval_node_list, config): """ Parameters ---------- eval_node_list: list of nodes whose values need to be computed. topo_order: list of nodes in topological order node_to_shape_map: dict from node to shape of the node node_to_arr_map: dict from node to ndarray.NDArray allocated for node feed_shapes: shapes of feed_dict from last run(...) """ self.name = name self.eval_node_list = eval_node_list self.config = config inference = not any([isinstance(node, OptimizerOp) for node in eval_node_list]) self.inference = inference if config.p2p_stream: self.run_results_indices = [eval_node_list.index( node) if node in eval_node_list else -1 for node in config.my_eval_nodes] self.eval_node_list = config.my_eval_nodes self.global_eval_nodes = eval_node_list if inference == False: self.topo_order = find_topo_sort(self.eval_node_list) else: # in inference phase if self.config.use_sparse_pull == True or self.config.cstable_policy is not None: # insert ps_sparse_pull_op self.topo_order = find_topo_sort_inference(self.eval_node_list) # fetch sparse parameter fetch_sparse_parameter_value(self.topo_order, self.config) else: self.topo_order = find_topo_sort(self.eval_node_list) # main structures, nodes' shapes and arrays self.node_to_shape_map = {} self.node_to_arr_map = {} # inherit from configurations self.comm_mode = self.config.comm_mode self.ps_comm = self.config.ps_comm self.nccl_comm = self.config.nccl_comm self.comp_stream = self.config.comp_stream self.h2d_stream = self.config.h2d_stream self.d2h_stream = self.config.d2h_stream self.nccl_stream = self.config.nccl_stream self.param_psval_map = self.config.infer_ps_map if self.inference else self.config.ps_map self.use_sparse_pull = self.config.use_sparse_pull self.cstable_policy = self.config.cstable_policy self.use_p2p = self.config.p2p_stream is not None # assisting structures, improve performance self.need_feed_nodes = [] self.param_nodes = [] self.dataloader_nodes = [] self.computing_nodes = [] for node in self.topo_order: if isinstance(node, DataloaderOp) or isinstance(node, GNNDataLoaderOp): self.dataloader_nodes.append(node) elif isinstance(node, PlaceholderOp): if node.shape is None: self.need_feed_nodes.append(node) elif node.trainable: self.param_nodes.append(node) elif not ((self.use_sparse_pull or self.cstable_policy) and isinstance(node, EmbeddingLookUp) and self.config.prefetch): self.computing_nodes.append(node) self.batch_num = set([node.get_batch_num(self.name) for node in self.dataloader_nodes]) assert len(self.batch_num) <= 1, 'Batch num not conform.' self.batch_num = None if len( self.batch_num) == 0 else self.batch_num.pop() self.init_need_allocation = (self.need_feed_nodes == []) and ( self.dataloader_nodes == []) def update_executor(self, eval_node_list): self.eval_node_list = eval_node_list inference = not any([isinstance(node, OptimizerOp) for node in eval_node_list]) self.inference = inference if self.config.p2p_stream and self.inference == True: raise NotImplementedError if inference == False: self.topo_order = find_topo_sort(self.eval_node_list) else: # in inference phase if self.config.use_sparse_pull == True or self.config.cstable_policy is not None: # insert ps_sparse_pull_op self.topo_order = find_topo_sort_inference(self.eval_node_list) # fetch sparse parameter fetch_sparse_parameter_value(self.topo_order, self.config) else: self.topo_order = find_topo_sort(self.eval_node_list) # main structures, nodes' shapes and arrays self.node_to_shape_map = {} self.node_to_arr_map = {} # assisting structures, improve performance self.need_feed_nodes = [] self.param_nodes = [] self.dataloader_nodes = [] self.computing_nodes = [] for node in self.topo_order: if isinstance(node, DataloaderOp) or isinstance(node, GNNDataLoaderOp): self.dataloader_nodes.append(node) elif isinstance(node, PlaceholderOp): if node.shape is None: self.need_feed_nodes.append(node) elif node.trainable: self.param_nodes.append(node) elif not ((self.use_sparse_pull or self.cstable_policy) and isinstance(node, EmbeddingLookUp) and self.config.prefetch): self.computing_nodes.append(node) self.batch_num = set([node.get_batch_num(self.name) for node in self.dataloader_nodes]) assert len(self.batch_num) <= 1, 'Batch num not conform.' self.batch_num = None if len( self.batch_num) == 0 else self.batch_num.pop() self.init_need_allocation = (self.need_feed_nodes == []) and ( self.dataloader_nodes == []) def infer_shape(self, feed_shapes): """Given shapes of feed_dict nodes, infer shape for all nodes in graph. Implementation note: Iteratively calls node.infer_shape to infer shapes. Node shapes stored in self.node_to_shape_map. Parameters ---------- feed_shapes: node->shapes mapping for feed_dict nodes. """ self.node_to_shape_map = {} for node in self.topo_order: if node in feed_shapes: self.node_to_shape_map[node] = tuple(feed_shapes[node]) else: input_shapes = [self.node_to_shape_map[n] for n in node.inputs] cur_shape = node.infer_shape(input_shapes) self.node_to_shape_map[node] = cur_shape if cur_shape is None else tuple( cur_shape) def memory_plan(self): """Allocates ndarray.NDArray for every node except feed_dict nodes. Parameters ---------- """ for node, shape in self.node_to_shape_map.items(): if isinstance(node, PlaceholderOp): if self.config.placeholder_to_arr_map[node] is not None: self.node_to_arr_map[node] = self.config.placeholder_to_arr_map[node] elif node not in self.node_to_arr_map: self.node_to_arr_map[node] = None elif not isinstance(node, DataloaderOp) and not isinstance(node, GNNDataLoaderOp): # add for OptimizerOp and ParameterServerOp if shape is None: self.node_to_arr_map[node] = None continue if isinstance(node, (EmbeddingLookUp_Gradient, DataD2HSparseOp)): self.node_to_arr_map[node] = ndarray.IndexedSlices( dense_shape=shape) continue if isinstance(node, EmbeddingLookUp) and (self.use_sparse_pull or self.cstable_policy) and self.config.prefetch: self.node_to_arr_map[node] = self.param_psval_map[node.inputs[0]] continue if node.on_gpu: if node.inplace: self.node_to_arr_map[node] = ndarray.NDArray(None) elif self.inference and isinstance(node, DropoutOp): self.node_to_arr_map[node] = self.node_to_arr_map[node.inputs[0]] else: self.node_to_arr_map[node] = ndarray.empty( shape, ctx=node.ctx) else: self.node_to_arr_map[node] = ndarray.empty( shape, ctx=node.ctx) def run(self, eval_node_list={}, feed_dict={}, convert_to_numpy_ret_vals=False): """ Parameters ---------- feed_dict: a dictionary of node->np.ndarray supplied by user. convert_to_numpy_ret_vals: whether to convert ret vals to np.array Returns ------- A list of values for nodes in eval_node_list. NDArray or np.ndarray. """ assert len(feed_dict) == len( self.need_feed_nodes) or self.use_p2p, 'Feed dict invalid.' if eval_node_list != {} and eval_node_list != self.eval_node_list: self.update_executor(eval_node_list) feed_shapes = {} need_reallocation = self.init_need_allocation # get feed in values for node, value in feed_dict.items(): if self.use_p2p and node not in self.need_feed_nodes: continue assert node in self.need_feed_nodes, 'Only allow feed in PlaceholderOp with no values, here got %s:%s.' % ( str(type(node)), node.name) local_shape = tuple(value.shape) local_realloc = local_shape != self.node_to_shape_map.get( node, None) need_reallocation = need_reallocation or local_realloc if node.on_cpu: assert isinstance(value, (np.ndarray, spmatrix, ndarray.NDArray)), \ "feed_dict value type not supported" if isinstance(value, np.ndarray): if local_realloc: self.node_to_arr_map[node] = ndarray.empty( local_shape, ctx=node.ctx) self.node_to_arr_map[node][:] = value else: self.node_to_arr_map[node] = value else: if isinstance(value, np.ndarray): if local_realloc: self.node_to_arr_map[node] = ndarray.array( value, ctx=node.ctx) else: self.node_to_arr_map[node][:] = value elif isinstance(value, spmatrix): value = coo_matrix(value) value = ndarray.sparse_array(value.data, (value.row, value.col), shape=local_shape, ctx=node.ctx) self.node_to_arr_map[node] = value elif isinstance(value, ndarray.NDArray): if value.ctx == node.ctx: self.node_to_arr_map[node] = value else: if local_realloc: self.node_to_arr_map[node] = ndarray.empty( local_shape, ctx=node.ctx) else: self.node_to_arr_map[node][:] = value elif isinstance(value, ndarray.ND_Sparse_Array): self.node_to_arr_map[node] = value else: assert False, "feed_dict value type not supported" feed_shapes[node] = local_shape # get dataloader values for node in self.dataloader_nodes: local_shape = node.get_cur_shape(self.name) local_realloc = local_shape != self.node_to_shape_map.get( node, None) need_reallocation = need_reallocation or local_realloc self.node_to_arr_map[node] = node.get_arr(self.name) feed_shapes[node] = local_shape # reallocation, infer shapes and allocate memory if need_reallocation: self.init_need_allocation = False self.infer_shape(feed_shapes) self.memory_plan() # computing for node in self.computing_nodes: if node.on_cpu and isinstance(self.node_to_arr_map[node], ndarray.NDArray): if DNNL_LIB['cpu_ArraySet'] and not isinstance(node, DataD2HOp): cpu_array_set(self.node_to_arr_map[node], 0.0) else: # here we suppose not using DNNL_LIB # self.node_to_arr_map[node][:] = np.zeros(self.node_to_shape_map[node]).astype(np.float32) pass input_vals = [self.node_to_arr_map[n] for n in node.inputs] node_val = self.node_to_arr_map[node] for n in node.inputs: if n.event: n.event.sync() if isinstance(node, (ParameterServerCommunicateOp, ParameterServerSparsePullOp)): # Here we use d2h stream in ps op, since the stream is used for d2h data transfer. # Please take care at this part. node.compute(input_vals, node_val, self.d2h_stream) elif isinstance(node, AllReduceCommunicateOp): node.compute(input_vals, node_val, self.nccl_stream) elif isinstance(node, DataH2DOp): node.compute(input_vals, node_val, self.h2d_stream) elif isinstance(node, (DataD2HOp, DataD2HSparseOp)): node.compute(input_vals, node_val, self.d2h_stream) elif isinstance(node, (PipelineSendOp, PipelineReceiveOp)): node.compute(input_vals, node_val) elif isinstance(node, (DropoutOp, Batch_NormalizationOp, Layer_NormalizationOp)): node.compute(input_vals, node_val, self.comp_stream, inference=self.inference) if isinstance(node.event, Event): # for d2h op / eval nodes / nodes before [allreduce or ps nodes or pipelinesend nodes] node.event.record(self.comp_stream) else: node.compute(input_vals, node_val, self.comp_stream) if isinstance(node.event, Event): # for d2h op / eval nodes / nodes before [allreduce or ps nodes or pipelinesend nodes] node.event.record(self.comp_stream) for n in self.eval_node_list: # every node in eval_node_list should have an event (except dataloader/optimizer...) if n.event: n.event.sync() # get results results = [self.node_to_arr_map[n] for n in self.eval_node_list] if convert_to_numpy_ret_vals: for i in range(len(results)): if results[i] is not None: results[i] = results[i].asnumpy() # remap to original order in model parallel if self.use_p2p: new_results = [None for _ in self.global_eval_nodes] for i, j in enumerate(self.run_results_indices): new_results[j] = results[i] results = new_results return results def gradients(output_node, node_list, insert_grad=None): """Take gradient of output node with respect to each node in node_list. Parameters ---------- output_node: output node that we are taking derivative of. node_list: list of nodes that we are taking derivative wrt. insert_grad: used to assign gradient to output_node in model parallel. Returns ------- A list of gradient values, one for each node in node_list respectively. """ if isinstance(output_node, list): node_to_output_grads_list = { output_node[i]: [OnesLike.oneslike_op(output_node[i])] if insert_grad is None else [insert_grad[i]] for i in range(len(output_node)) } else: node_to_output_grads_list = { output_node: [OnesLike.oneslike_op(output_node)] if insert_grad is None else [ insert_grad] } output_node = [output_node] node_to_output_grad = {} # Traverse forward graph in reverse topological order reverse_topo_order = reversed(find_topo_sort(output_node)) for node in reverse_topo_order: # here the ctx for embedding lookup is a workaround # TODO: when implement PS strategy for context semantics, modify here if isinstance(node, EmbeddingLookUp): output_grad = sum_node_list( node_to_output_grads_list[node], node_to_output_grads_list[node][0].raw_ctx) else: output_grad = sum_node_list( node_to_output_grads_list[node], node.raw_ctx) if output_grad is None: for n in node.inputs: if n not in node_to_output_grads_list: node_to_output_grads_list[n] = [] continue node_to_output_grad[node] = output_grad input_grads_list = node.gradient(output_grad) for i in range(len(node.inputs)): if node.inputs[i] not in node_to_output_grads_list: node_to_output_grads_list[node.inputs[i]] = [] # Calculate partial adjoint for input nodes. node_to_output_grads_list[node.inputs[i]].append( input_grads_list[i]) grad_node_list = [node_to_output_grad[node] for node in node_list] return grad_node_list ################## # Helper Methods # ################## def topo_sort_with_hook(node_list, config): visited = set() for node in node_list: topo_sort_dfs_with_hook(node, visited, config) def topo_sort_dfs_with_hook(node, visited, config): if node in visited: return visited.add(node) node.backward_hook(config) # move param from node to config if isinstance(node, PlaceholderOp): config.placeholder_to_arr_map[node] = node.tensor_value node.tensor_value = None for n in node.inputs: topo_sort_dfs_with_hook(n, visited, config) node.forward_hook(config) def find_topo_sort(node_list): """Given a list of nodes, return a topo ordering of nodes ending in them. A simple algorithm is to do a post-order DFS traversal on the given nodes, going backwards based on input edges. Since a node is added to the ordering after all its predecessors are traversed due to post-order DFS, we get a topological sort. """ visited = set() topo_order = [] for node in node_list: topo_sort_dfs(node, visited, topo_order) return topo_order def topo_sort_dfs(node, visited, topo_order): """Post-order DFS""" if node in visited: return visited.add(node) for n in node.inputs: topo_sort_dfs(n, visited, topo_order) topo_order.append(node) def find_topo_sort_inference(node_list): topo_order = find_topo_sort(node_list) embedding_list = list() embedding_outputs = dict() embedding_cnt = dict() for node in topo_order: if isinstance(node, EmbeddingLookUp): embedding_outputs[node] = list() embedding_cnt[node] = 0 embedding_list.append(node) else: for input_node in node.inputs: if isinstance(input_node, EmbeddingLookUp): embedding_outputs[input_node].append(node) embedding_cnt[input_node] += 1 topo_order_inference = list() for node in topo_order: topo_order_inference.append(node) for embedding in embedding_list: if node in embedding_outputs[embedding]: embedding_cnt[embedding] -= 1 if embedding_cnt[embedding] == 0: topo_order_inference.append(parameterServerSparsePull_op( embedding, embedding_outputs[embedding])) embedding_list.remove(embedding) return topo_order_inference def fetch_sparse_parameter_value(node_list, config): for node in node_list: if isinstance(node, ParameterServerSparsePullOp): node.forward_hook(config) def fetch_dense_parameter_value(node_list, config): assert config.comm_mode in ('PS', 'Hybrid') topo_order = find_topo_sort(node_list) val_list = [] # get var list for node in topo_order: if isinstance(node, PlaceholderOp) and node.trainable: val_list.append(node) for node in val_list: if config.use_sparse_pull and node.is_embed: continue else: pull_val = ndarray.empty(node.shape, ctx=ndarray.cpu(0)) config.ps_comm.Pull(node.id, pull_val.handle) config.infer_ps_map[node] = pull_val config.placeholder_to_arr_map[node] = pull_val node.event.update() def sum_node_list(node_list, ctx): """Custom sum func to avoid creating redundant nodes in Python sum func.""" node_list = [n for n in node_list if n is not None] if node_list == []: return None sum_node = node_list[0] for n in node_list[1:]: sum_node = add_op(sum_node, n, ctx=ctx) return sum_node