|
|
|
@@ -211,22 +211,22 @@ Status GraphDataImpl::GetSampledNeighbors(const std::vector<NodeIdType> &node_li |
|
|
|
return Status::OK(); |
|
|
|
} |
|
|
|
|
|
|
|
// Draws negative samples from `data` by walking a pre-shuffled list of indices.
//
// Starting at position *start_index of `shuffled_ids`, each entry is used as an index into `data`. The node id
// found there is appended to `out_samples` unless it is present in `exclude_data`. The walk stops as soon as
// `out_samples` holds `samples_num` elements or `shuffled_ids` is exhausted, whichever happens first. On return
// *start_index is the position of the first entry that has not been examined, so the next call resumes from
// where this one stopped.
//
// @param data - candidate node ids; must not be empty
// @param shuffled_ids - indices into `data`, in visiting order (assumes every entry is a valid index into
//     `data` - no bounds check is performed here)
// @param start_index - in/out cursor into `shuffled_ids`; must not be null
// @param exclude_data - node ids that must never be emitted
// @param samples_num - desired total size of `out_samples`, counting elements that are already in it
// @param out_samples - output vector, appended to; must not be null
// @return Status::OK() once the walk finishes; the CHECK_FAIL_RETURN_UNEXPECTED guards return early otherwise
//
// NOTE(review): `shuffled_ids` is taken by value, so the whole index vector is copied on every call. Switching
// it to `const std::vector<NodeIdType> &` has to be done together with the declaration in the class header.
Status GraphDataImpl::NegativeSample(const std::vector<NodeIdType> &data, const std::vector<NodeIdType> shuffled_ids,
                                     size_t *start_index, const std::unordered_set<NodeIdType> &exclude_data,
                                     int32_t samples_num, std::vector<NodeIdType> *out_samples) {
  CHECK_FAIL_RETURN_UNEXPECTED(!data.empty(), "Input data is empty.");
  CHECK_FAIL_RETURN_UNEXPECTED(start_index != nullptr, "Input start_index is nullptr.");
  CHECK_FAIL_RETURN_UNEXPECTED(out_samples != nullptr, "Input out_samples is nullptr.");

  // Do the size comparison in size_t. Comparing size() against a negative int32_t would convert the latter to a
  // huge unsigned value and the loop would consume every remaining candidate.
  const size_t wanted = samples_num > 0 ? static_cast<size_t>(samples_num) : 0;

  size_t cursor = *start_index;
  // Checking the size before each step (instead of after an append) guarantees that an already-full output is
  // never grown past `wanted`.
  while (cursor < shuffled_ids.size() && out_samples->size() < wanted) {
    const NodeIdType candidate = data[shuffled_ids[cursor]];
    ++cursor;  // the entry counts as consumed whether or not it is excluded
    if (exclude_data.find(candidate) == exclude_data.end()) {
      out_samples->emplace_back(candidate);
    }
  }

  // Publish the resume position so that consecutive calls share one pass over the shuffle.
  *start_index = cursor;
  return Status::OK();
}
|
|
|
|
|
|
|
@@ -236,6 +236,13 @@ Status GraphDataImpl::GetNegSampledNeighbors(const std::vector<NodeIdType> &node |
|
|
|
RETURN_IF_NOT_OK(CheckSamplesNum(samples_num)); |
|
|
|
RETURN_IF_NOT_OK(CheckNeighborType(neg_neighbor_type)); |
|
|
|
|
|
|
|
const std::vector<NodeIdType> &all_nodes = node_type_map_[neg_neighbor_type]; |
|
|
|
std::vector<NodeIdType> shuffled_id(all_nodes.size()); |
|
|
|
std::iota(shuffled_id.begin(), shuffled_id.end(), 0); |
|
|
|
std::shuffle(shuffled_id.begin(), shuffled_id.end(), rnd_); |
|
|
|
size_t start_index = 0; |
|
|
|
bool need_shuffle = false; |
|
|
|
|
|
|
|
std::vector<std::vector<NodeIdType>> neg_neighbors_vec; |
|
|
|
neg_neighbors_vec.resize(node_list.size()); |
|
|
|
for (size_t node_idx = 0; node_idx < node_list.size(); ++node_idx) { |
|
|
|
@@ -247,12 +254,15 @@ Status GraphDataImpl::GetNegSampledNeighbors(const std::vector<NodeIdType> &node |
|
|
|
std::transform(neighbors.begin(), neighbors.end(), |
|
|
|
std::insert_iterator<std::unordered_set<NodeIdType>>(exclude_nodes, exclude_nodes.begin()), |
|
|
|
[](const NodeIdType node) { return node; }); |
|
|
|
const std::vector<NodeIdType> &all_nodes = node_type_map_[neg_neighbor_type]; |
|
|
|
neg_neighbors_vec[node_idx].emplace_back(node->id()); |
|
|
|
if (all_nodes.size() > exclude_nodes.size()) { |
|
|
|
while (neg_neighbors_vec[node_idx].size() < samples_num + 1) { |
|
|
|
RETURN_IF_NOT_OK(NegativeSample(all_nodes, exclude_nodes, samples_num - neg_neighbors_vec[node_idx].size(), |
|
|
|
RETURN_IF_NOT_OK(NegativeSample(all_nodes, shuffled_id, &start_index, exclude_nodes, samples_num + 1, |
|
|
|
&neg_neighbors_vec[node_idx])); |
|
|
|
if (start_index >= shuffled_id.size()) { |
|
|
|
start_index = start_index % shuffled_id.size(); |
|
|
|
need_shuffle = true; |
|
|
|
} |
|
|
|
} |
|
|
|
} else { |
|
|
|
MS_LOG(DEBUG) << "There are no negative neighbors. node_id:" << node->id() |
|
|
|
@@ -262,6 +272,11 @@ Status GraphDataImpl::GetNegSampledNeighbors(const std::vector<NodeIdType> &node |
|
|
|
neg_neighbors_vec[node_idx].emplace_back(kDefaultNodeId); |
|
|
|
} |
|
|
|
} |
|
|
|
if (need_shuffle) { |
|
|
|
std::shuffle(shuffled_id.begin(), shuffled_id.end(), rnd_); |
|
|
|
start_index = 0; |
|
|
|
need_shuffle = false; |
|
|
|
} |
|
|
|
} |
|
|
|
RETURN_IF_NOT_OK(CreateTensorByVector<NodeIdType>(neg_neighbors_vec, DataType(DataType::DE_INT32), out)); |
|
|
|
return Status::OK(); |
|
|
|
|