|
- from hetu.gpu_ops.executor import path_to_lib
- from hetu import ndarray
- from hetu import get_worker_communicate
-
- import numpy as np
- import sys
- import os
- from functools import partial
-
- """
- CacheSparseTable:
- length, width: the length and width of the whole embedding table
- limit: the max number of embedding lines stored in cache
- node_id: the unique node_id in the model
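- policy: cache policy, one of LRU, LFU or LFUOpt (case-insensitive)
- bound: value assigned to both the pull bound and the push bound of the cache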
- """
-
-
class CacheSparseTable:
    """Python wrapper around the native ``hetu_cache`` embedding cache classes.

    Args:
        limit: the max number of embedding lines stored in cache.
        length, width: the length and width of the whole embedding table.
        node_id: the unique node_id in the model.
        policy: cache policy, one of "LRU", "LFU" or "LFUOpt"
            (case-insensitive).
        bound: value assigned to both ``pull_bound`` and ``push_bound`` of
            the underlying cache.

    Raises:
        NotImplementedError: if ``policy`` is not one of the known policies.
    """

    # Maps the lower-cased policy name to the class name inside hetu_cache.
    # Names (not classes) are stored so that only the requested class is
    # looked up on the native module.
    _POLICY_CLASS_NAMES = {
        "lru": "LRUCache",
        "lfu": "LFUCache",
        "lfuopt": "LFUOptCache",
    }

    def __init__(self, limit, length, width, node_id, policy="LRU", bound=100):
        # make sure we open libps.so first
        comm = get_worker_communicate()
        lib_dir = os.path.dirname(__file__) + "/../../build/lib"
        # Avoid growing sys.path by one duplicate entry per table instance.
        if lib_dir not in sys.path:
            sys.path.append(lib_dir)
        import hetu_cache
        policy = policy.lower()
        class_name = self._POLICY_CLASS_NAMES.get(policy)
        if class_name is None:
            raise NotImplementedError(policy)
        cache_cls = getattr(hetu_cache, class_name)
        self.cache = cache_cls(limit, length, width, node_id)
        self.cache.pull_bound = bound
        self.cache.push_bound = bound
        comm.BarrierWorker()

    @staticmethod
    def _finish(wait, sync):
        """Block on ``wait`` when ``sync`` is set, otherwise hand it back."""
        if sync:
            wait.wait()
            return None
        return wait

    def embedding_lookup(self, keys, dest, sync=False):
        """Look up the embeddings of ``keys`` and write them into ``dest``.

        Args:
            keys: keys to look up; either a ``numpy.ndarray`` of dtype uint64
                or a CPU ``NDArray``.
            dest: target memory space to write to; same container kind as
                ``keys``, with shape ``keys`` shape + ``(width,)`` (flattened
                to ``(keys.size, width)`` for numpy input).
            sync: async call or sync call.

        Returns:
            ``None`` if ``sync``; otherwise the wait object returned by the
            cache, use ``wait.wait()`` to wait until it finishes. If async,
            keys and dest must stay alive throughout the call.

        Raises:
            TypeError: if keys/dest are not both numpy arrays or both NDArrays.
        """
        if type(keys) is np.ndarray and type(dest) is np.ndarray:
            assert dest.shape == (keys.size, self.width)
            assert keys.dtype == np.uint64
            assert dest.dtype == np.float32
            wait = self.cache.embedding_lookup(keys, dest)
        elif type(keys) is ndarray.NDArray and type(dest) is ndarray.NDArray:
            assert dest.shape == (*keys.shape, self.width)
            assert not ndarray.is_gpu_ctx(keys.ctx)
            assert not ndarray.is_gpu_ctx(dest.ctx)
            # int(): np.prod yields a numpy scalar (a float for an empty
            # shape); pass a plain Python int as the element count.
            wait = self.cache.embedding_lookup_raw(
                keys.handle.contents.data, dest.handle.contents.data,
                int(np.prod(keys.shape)))
        else:
            raise TypeError(
                "keys and dest must both be numpy.ndarray or both be NDArray, "
                "got %s and %s" % (type(keys).__name__, type(dest).__name__))
        return self._finish(wait, sync)

    def embedding_update(self, keys, grads, sync=False):
        """Send the gradients ``grads`` for the embeddings of ``keys``.

        Args:
            keys: keys to update; either a ``numpy.ndarray`` of dtype uint64
                or a CPU ``NDArray``.
            grads: gradients to send; same container kind as ``keys``, with
                shape ``keys`` shape + ``(width,)`` (flattened to
                ``(keys.size, width)`` for numpy input).
            sync: async call or sync call.

        Returns:
            ``None`` if ``sync``; otherwise the wait object returned by the
            cache, use ``wait.wait()`` to wait until it finishes. If async,
            keys and grads must stay alive throughout the call.

        Raises:
            TypeError: if keys/grads are not both numpy arrays or both
                NDArrays.
        """
        if type(keys) is np.ndarray and type(grads) is np.ndarray:
            assert grads.shape == (keys.size, self.width)
            assert keys.dtype == np.uint64
            assert grads.dtype == np.float32
            wait = self.cache.embedding_update(keys, grads)
        elif type(keys) is ndarray.NDArray and type(grads) is ndarray.NDArray:
            assert grads.shape == (*keys.shape, self.width)
            assert not ndarray.is_gpu_ctx(keys.ctx)
            assert not ndarray.is_gpu_ctx(grads.ctx)
            wait = self.cache.embedding_update_raw(
                keys.handle.contents.data, grads.handle.contents.data,
                int(np.prod(keys.shape)))
        else:
            raise TypeError(
                "keys and grads must both be numpy.ndarray or both be NDArray, "
                "got %s and %s" % (type(keys).__name__, type(grads).__name__))
        return self._finish(wait, sync)

    def embedding_push_pull(self, pullkeys, dest, pushkeys, grads, sync=False):
        """Pull ``pullkeys`` into ``dest`` and push ``grads`` for ``pushkeys``.

        Only CPU ``NDArray`` arguments are accepted.

        Args:
            pullkeys: keys to look up.
            dest: target memory space, shape ``pullkeys`` shape + ``(width,)``.
            pushkeys: keys to update.
            grads: gradients to send, shape ``pushkeys`` shape + ``(width,)``.
            sync: async call or sync call.

        Returns:
            ``None`` if ``sync``; otherwise the wait object returned by the
            cache. If async, all four arrays must stay alive throughout the
            call.

        Raises:
            TypeError: if any argument is not an ``NDArray``.
        """
        args = (pullkeys, dest, pushkeys, grads)
        if not all(type(arg) is ndarray.NDArray for arg in args):
            raise TypeError(
                "pullkeys, dest, pushkeys and grads must all be NDArray, got "
                + ", ".join(type(arg).__name__ for arg in args))
        assert grads.shape == (*pushkeys.shape, self.width)
        assert dest.shape == (*pullkeys.shape, self.width)
        assert not ndarray.is_gpu_ctx(pullkeys.ctx)
        assert not ndarray.is_gpu_ctx(pushkeys.ctx)
        assert not ndarray.is_gpu_ctx(grads.ctx)
        assert not ndarray.is_gpu_ctx(dest.ctx)
        wait = self.cache.embedding_push_pull_raw(
            pullkeys.handle.contents.data, dest.handle.contents.data,
            int(np.prod(pullkeys.shape)),
            pushkeys.handle.contents.data, grads.handle.contents.data,
            int(np.prod(pushkeys.shape)),
        )
        return self._finish(wait, sync)

    @property
    def width(self):
        """Width reported by the underlying cache."""
        return self.cache.width

    @property
    def limit(self):
        """Limit reported by the underlying cache."""
        return self.cache.limit

    def perf_enabled(self, enable=True):
        """Set the ``perf_enabled`` flag of the underlying cache."""
        self.cache.perf_enabled = enable

    @property
    def perf(self):
        """Perf records collected by the underlying cache."""
        # perf data example [item1, item2...]
        # item = "type": pull_or_push, "is_full": is_cache_full, "num_all", num_of_key
        # "num_unique": num_of_unique_key, "num_miss": num_of_missed_unique_key,
        # "num_evict": num_push_of_eviction, "num_transfered"(if push): miss+outofpushbound+evict
        # "num_transfered"(if pull): miss+outofpullbound, "time": last_time_in_ms
        return self.cache.perf

    # if bypass, directly pull and push the server
    def bypass(self):
        """Switch the underlying cache into bypass mode."""
        self.cache.bypass()

    def undobypass(self):
        """Undo a previous :meth:`bypass`."""
        self.cache.undo_bypass()

    def __repr__(self):
        return self.cache.__repr__()

    # the following calls are single key call
    # for debugging
    def lookup(self, key):
        """Forward a single-key ``lookup`` to the cache (debugging aid)."""
        return self.cache.lookup(key)

    def count(self, key):
        """Forward a single-key ``count`` to the cache (debugging aid)."""
        return self.cache.count(key)

    def insert(self, key, embedding):
        """Forward a single-key ``insert`` to the cache (debugging aid)."""
        return self.cache.insert(key, embedding)

    def keys(self):
        """Return the keys reported by the cache (debugging aid)."""
        return self.cache.keys()

    # PerfHelperFunction

    def _perf_records(self, include_cold_start):
        """Return perf records, dropping non-full-cache ones unless asked."""
        if include_cold_start:
            return self.perf
        return [record for record in self.perf if record["is_full"]]

    # miss rate for pull
    def overall_miss_rate(self, include_cold_start=False):
        """Return missed unique keys / unique keys over all "Pull" records.

        Args:
            include_cold_start: also count records taken while the cache was
                not yet full (``is_full`` false).

        Returns:
            The miss ratio, or ``-1`` when there is nothing to measure (no
            matching record, no "Pull" record, or zero unique keys).
        """
        records = self._perf_records(include_cold_start)
        if not records:
            return -1
        pulls = [record for record in records if record["type"] == "Pull"]
        unique_total = np.sum([record["num_unique"] for record in pulls])
        # Without this guard an all-"Push" history divides 0 by 0 and
        # returns nan with a RuntimeWarning.
        if unique_total == 0:
            return -1
        return np.sum([record["num_miss"] for record in pulls]) / unique_total

    # data rate compared with vanilla sparse pull (ignore cost for idx&version)
    def overall_data_rate(self, include_cold_start=False):
        """Return transferred keys / requested keys over all perf records.

        Args:
            include_cold_start: also count records taken while the cache was
                not yet full (``is_full`` false).

        Returns:
            The data ratio, or ``-1`` when there is nothing to measure (no
            matching record or zero requested keys).
        """
        records = self._perf_records(include_cold_start)
        if not records:
            return -1
        requested_total = np.sum([record["num_all"] for record in records])
        if requested_total == 0:
            return -1
        transferred = np.sum([record["num_transfered"] for record in records])
        return transferred / requested_total

    def debug_keys(self):
        """Dump every worker's cached keys and compute their pairwise overlap.

        Workers take turns (one per barrier round) writing one line of keys
        to ``_keys.log``: rank 0 truncates the file, the others append.

        Returns:
            On rank 0, an ``nrank x nrank`` array whose entry ``[i][j]`` is
            the fraction of worker i's keys that worker j also holds (rows of
            workers with no keys stay zero). ``None`` on every other rank.
        """
        comm = get_worker_communicate()
        nrank = comm.nrank()
        rank = comm.rank()
        log_path = "_keys.log"
        mode = "w" if rank == 0 else "a"
        for turn in range(nrank):
            if turn == rank:
                with open(log_path, mode) as f:
                    print(*self.keys(), file=f, flush=True)
            # Barrier after every turn keeps the lines in rank order.
            comm.BarrierWorker()

        if rank != 0:
            return
        with open(log_path, "r") as f:
            key_sets = [set(map(int, f.readline().split()))
                        for _ in range(nrank)]
        overlap = np.zeros([nrank, nrank])
        for i, own in enumerate(key_sets):
            if not own:
                continue
            for j, other in enumerate(key_sets):
                overlap[i][j] = len(own.intersection(other)) / len(own)
        return overlap
|