- from __future__ import absolute_import
- from .Node import Op
- import math
- from .. import ndarray
- import numpy as np
-
-
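- # Number of adjacency rows owned by `rank` when `node_count` rows are split
- # into `size` consecutive blocks of ceil(node_count / size) rows; the last
- # rank takes whatever remains.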
- def row_num(node_count, rank, size):
-     n_per_proc = math.ceil(float(node_count) / size)
-     if node_count % size == 0:
-         return int(node_count / size)
-     if rank < size - 1:
-         return int(n_per_proc)
-     else:
-         return int(node_count % n_per_proc)
-
-
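- # Distributed sparse-dense multiply (A @ inputs) under the 1.5D partition:
- # every stage broadcasts one block of `inputs` from its owner (via the column
- # group when replication > 1, otherwise the global communicator), multiplies
- # the local adjacency slice with it, and accumulates the partial result in
- # z_loc; with replication > 1 the replicated partial results are finally
- # summed with an allreduce over the row group.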
- def broad_func(node_count, adj_matrix, inputs, rank, size, replication, row_groups, col_groups, ctx, comm=None, stream_handle=None):
-     assert size % (replication ** 2) == 0
- 
-     from ..communicator.mpi_nccl_comm import ncclDataType_t, ncclRedOp_t
-     from ..gpu_links import CuSparse_Csrmm, matrix_elementwise_add
- 
-     n_per_proc = math.ceil(float(node_count) / (size // replication))
-     proc_node_count = row_num(node_count, rank // replication, size // replication)
- 
-     # z_loc accumulates partial products across stages, so it is created
-     # zero-filled rather than with ndarray.empty.
-     z_loc = ndarray.array(
-         np.zeros((proc_node_count, inputs.shape[1]), dtype=np.float32), ctx=ctx)
-     tmp = ndarray.empty((proc_node_count, inputs.shape[1]), ctx=ctx)
-     inputs_recv = ndarray.empty((int(n_per_proc), inputs.shape[1]), ctx=ctx)
- 
-     rank_c = rank // replication   # index of this process's row block
-     rank_col = rank % replication  # index within the replication group
- 
-     # Number of broadcast stages handled by this replica and the number of
-     # input rows they cover; the last replica picks up the remainder.
-     stages = size // (replication ** 2)
-     node_count_col = stages * n_per_proc
-     if rank_col == replication - 1:
-         stages = (size // replication) - (replication - 1) * stages
-         node_count_col = node_count - (replication - 1) * node_count_col
- 
-     start_pos = list(range(0, int(node_count_col), int(n_per_proc)))
-     end_pos = start_pos[1:] + [int(node_count_col)]
- 
-     for i in range(stages):
-         # Global rank that owns the input block broadcast in this stage.
-         q = (rank_col * (size // (replication ** 2)) + i) * replication + rank_col
-         q_c = q // replication
- 
-         # The last row block may hold fewer than n_per_proc rows.
-         if q_c == size // replication - 1:
-             last_rows = row_num(node_count, size // replication - 1, size // replication)
-             inputs_recv = ndarray.empty((last_rows, inputs.shape[1]), ctx=ctx)
-         if q == rank:
-             inputs.copyto(inputs_recv)
- 
-         if replication > 1:
-             col_groups[rank_col].dlarrayBroadcast(
-                 inputs_recv, inputs_recv, ncclDataType_t.ncclFloat32, q)
-         else:
-             comm.dlarrayBroadcast(inputs_recv, inputs_recv,
-                                   ncclDataType_t.ncclFloat32, q)
- 
-         # Multiply the local adjacency slice for this stage with the received
-         # feature block and accumulate the partial result into z_loc.
-         CuSparse_Csrmm(adj_matrix, False, inputs_recv, False, tmp,
-                        stream=stream_handle, start_pos=int(start_pos[i]), end_pos=int(end_pos[i]))
-         matrix_elementwise_add(z_loc, tmp, z_loc, stream_handle)
- 
-     # Combine the partial results held by the replicas of this row block.
-     if replication > 1:
-         row_groups[rank_c].dlarrayNcclAllReduce(
-             z_loc, z_loc, ncclDataType_t.ncclFloat32, reduceop=ncclRedOp_t.ncclSum)
- 
-     return z_loc
-
-
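- # Autodiff operator for one distributed GCN layer, Z = A @ H @ W, under the
- # 1.5D partition: the node dimension is split into size // replication row
- # blocks, each shared by `replication` processes whose NCCL groups are passed
- # in via comm_groups (row groups, column groups, and optionally a third entry
- # with groups for the weight-gradient allreduce).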
- class DistGCN_15dOp(Op):
-     def __init__(self, node_A, node_B, node_C, node_Count_Self, node_Count_All, size, replication, device_id, comm, comm_groups=[None, None], need_W=True):
-         super().__init__(DistGCN_15dOp, [
-             node_A, node_B, node_C], ctx=ndarray.gpu(device_id))
-         self.need_W = need_W
-         self.node_Count_Self = node_Count_Self
-         self.node_Count_All = node_Count_All
-         self.replication = replication
-         self.size = size
-         self.comm = comm
-         self.comm_groups = comm_groups
-         self.device_id = device_id
-
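-     # Forward pass. The dense multiply is placed on whichever side gives the
-     # smaller intermediate: A @ (H @ W) when W narrows the feature dimension,
-     # (A @ H) @ W otherwise; with need_W=False only the distributed
-     # aggregation is performed.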
-     def compute(self, input_vals, output_val, stream_handle=None):
-         from ..gpu_links import matrix_multiply
-         adj_matrix = input_vals[0]
-         inputs_H = input_vals[1]
-         weight = input_vals[2]
-         node_count = self.node_Count_All
-         comm = self.comm
-         rank = comm.localRank.value
-         ctx = ndarray.gpu(self.device_id)
- 
-         if weight.shape[1] < inputs_H.shape[1]:
-             # Apply the dense multiply first, then aggregate the smaller matrix.
-             if self.need_W:
-                 HW = ndarray.empty((inputs_H.shape[0], weight.shape[1]), ctx=ctx)
-                 matrix_multiply(inputs_H, False, weight, False, HW, stream_handle)
-             else:
-                 HW = inputs_H
-             z = broad_func(node_count, adj_matrix, HW, rank, self.size, self.replication,
-                            row_groups=self.comm_groups[0], col_groups=self.comm_groups[1],
-                            ctx=ctx, comm=comm, stream_handle=stream_handle)
-             z.copyto(output_val)
-         else:
-             # Aggregate first, then apply the dense multiply.
-             AH = broad_func(node_count, adj_matrix, inputs_H, rank, self.size, self.replication,
-                             row_groups=self.comm_groups[0], col_groups=self.comm_groups[1],
-                             ctx=ctx, comm=comm, stream_handle=stream_handle)
-             if self.need_W:
-                 z = ndarray.empty((AH.shape[0], weight.shape[1]), ctx=ctx)
-                 matrix_multiply(AH, False, weight, False, z, stream_handle)
-             else:
-                 z = AH
-             z.copyto(output_val)
-
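-     # Backward pass: the output gradient is aggregated with the same
-     # distributed op (need_W=False), which equals A^T @ dZ when the normalized
-     # adjacency is symmetric; dense matmuls then give the gradients w.r.t. H
-     # and W, and the weight gradient is allreduced across replicas.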
-     def gradient(self, output_grad):
-         from . import matmul_op
-         from . import groupallreduceCommunicate_op
-         adj_matrix = self.inputs[0]
-         inputs_H = self.inputs[1]
-         weight = self.inputs[2]
-         node_Count_Self = self.node_Count_Self
-         node_Count_All = self.node_Count_All
-         comm = self.comm
-         rank = comm.localRank.value
- 
-         # ag = aggregation of the output gradient, reusing the distributed
-         # SpMM without the dense multiply.
-         ag = distgcn_15d_op(adj_matrix, output_grad, weight, node_Count_Self, node_Count_All,
-                             self.size, self.replication, self.device_id, comm, self.comm_groups,
-                             need_W=False)
- 
-         grad_A = None  # the adjacency matrix is treated as a constant
-         grad_H = matmul_op(ag, weight, trans_B=True)          # dH = ag @ W^T
-         grad_weight = matmul_op(inputs_H, ag, trans_A=True)   # dW = H^T @ ag
- 
-         # With replication, each replica only sees part of the data, so the
-         # weight gradient is summed across the corresponding group.
-         if self.replication > 1:
-             weight_groups = self.comm_groups[1]
-             if len(self.comm_groups) == 3:
-                 weight_groups = self.comm_groups[2]
-             grad_W = groupallreduceCommunicate_op(
-                 grad_weight, weight_groups[rank % self.replication])
-         else:
-             grad_W = groupallreduceCommunicate_op(grad_weight, comm)
-         return [grad_A, grad_H, grad_W]
-
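-     # The op produces this worker's node_Count_Self rows; the feature width
-     # follows W when the dense multiply is applied and H otherwise.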
-     def infer_shape(self, input_shapes):
-         assert len(input_shapes) == 3
-         shape_H = input_shapes[1][1]
-         shape_W = input_shapes[2][1]
-         if self.need_W:
-             return (self.node_Count_Self, shape_W)
-         else:
-             return (self.node_Count_Self, shape_H)
-
-
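- # Factory helper that builds a DistGCN_15dOp node, mirroring the other *_op
- # constructors in this package.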
- def distgcn_15d_op(node_A, node_B, node_C, node_Count_Self, node_Count_All, size, replication, device_id, comm, comm_groups=[None, None], need_W=True):
-     return DistGCN_15dOp(node_A, node_B, node_C, node_Count_Self, node_Count_All, size, replication, device_id, comm, comm_groups, need_W=need_W)