|
|
|
@@ -304,15 +304,18 @@ int64_t WorkerProxy<T>::AddLookupCB(const ::ps::SArray<::ps::Key> &keys, const : |
|
|
|
auto &kvs = lookup_results_[ts]; |
|
|
|
mutex_.unlock(); |
|
|
|
|
|
|
|
if (lookup_ids.empty()) { |
|
|
|
MS_LOG(EXCEPTION) << "Lookup id is empty."; |
|
|
|
} |
|
|
|
int64_t single_id_len = SizeToLong(lookup_result->size() / lookup_ids.size()); |
|
|
|
std::unordered_map<Key, std::shared_ptr<std::pair<T *, int64_t>>> id_addr_map; |
|
|
|
for (const auto &s : kvs) { |
|
|
|
int64_t offset = 0; |
|
|
|
int64_t len = s.vals.size() / s.keys.size(); |
|
|
|
for (size_t i = 0; i < s.keys.size(); i++) { |
|
|
|
const Key &key = s.keys[i]; |
|
|
|
T *addr = s.vals.data() + offset; |
|
|
|
offset += len; |
|
|
|
id_addr_map[key] = std::make_shared<std::pair<T *, int64_t>>(std::make_pair(addr, len)); |
|
|
|
offset += single_id_len; |
|
|
|
id_addr_map[key] = std::make_shared<std::pair<T *, int64_t>>(std::make_pair(addr, single_id_len)); |
|
|
|
MS_EXCEPTION_IF_NULL(id_addr_map[key]); |
|
|
|
} |
|
|
|
} |
|
|
|
@@ -325,8 +328,12 @@ int64_t WorkerProxy<T>::AddLookupCB(const ::ps::SArray<::ps::Key> &keys, const : |
|
|
|
void *dst_data = nullptr; |
|
|
|
void *src_data = nullptr; |
|
|
|
for (size_t i = 0; i < lookup_ids.size(); i++) { |
|
|
|
if (id_addr_map.count(lookup_ids[i]) == 0) { |
|
|
|
offset += single_id_len; |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto &pair = id_addr_map[static_cast<Key>(lookup_ids[i])]; |
|
|
|
int64_t size = pair->second * sizeof(T); |
|
|
|
int64_t size = single_id_len * sizeof(T); |
|
|
|
dst_size = size; |
|
|
|
src_size = size; |
|
|
|
dst_data = result_addr + offset; |
|
|
|
@@ -338,7 +345,7 @@ int64_t WorkerProxy<T>::AddLookupCB(const ::ps::SArray<::ps::Key> &keys, const : |
|
|
|
MS_LOG(EXCEPTION) << "memcpy_s error, errorno(" << ret << ")"; |
|
|
|
return; |
|
|
|
} |
|
|
|
offset += pair->second; |
|
|
|
offset += single_id_len; |
|
|
|
} |
|
|
|
|
|
|
|
mutex_.lock(); |
|
|
|
@@ -406,6 +413,8 @@ void WorkerProxy<T>::LookupIdSlicer(int64_t timestamp, const ::ps::KVPairs<T> &s |
|
|
|
|
|
|
|
for (size_t j = 0; j < id_size; j++) { |
|
|
|
auto lookup_id = static_cast<uint64_t>(lookup_ids[j]); |
|
|
|
// If lookup_id is out of range, like negative number, unique_ids will not contain it. |
|
|
|
// Servers always get lookup_ids in its embedding table range. |
|
|
|
if (lookup_id >= begin && lookup_id <= end) { |
|
|
|
unique_ids.insert(lookup_id); |
|
|
|
} |
|
|
|
|