|
|
|
@@ -45,8 +45,14 @@ DeviceQueueOp::DeviceQueueOp(std::string channel_name, DeviceType device_type, i |
|
|
|
total_batch_(total_batch), |
|
|
|
create_data_info_queue_(create_data_info_queue) { |
|
|
|
#ifdef ENABLE_GPUQUE |
|
|
|
// Get the total device num of current machine |
|
|
|
int32_t device_count = 0; |
|
|
|
cudaGetDeviceCount(&device_count); |
|
|
|
std::shared_ptr<ConfigManager> cfg = GlobalContext::config_manager(); |
|
|
|
rank_id_ = cfg->rank_id(); // Get the current rank_id |
|
|
|
if (device_count > 0) { |
|
|
|
rank_id_ = rank_id_ % device_count; |
|
|
|
} |
|
|
|
// Be careful when modifying num_workers_ and queue_capacity_, |
|
|
|
// and we suggest keeping num_workers_ * queue_capacity_ no greater than 16, because |
|
|
|
// each worker has one circular_pool with 1G of pinned memory, so num_workers_ * queue_capacity_ |
|
|
|
|