|
|
|
@@ -67,9 +67,7 @@ std::vector<TaskInfoPtr> LabelSwitchKernel::GenTask(const std::vector<AddressPtr |
|
|
|
MS_LOG(INFO) << "LabelSwitchKernel GenTask label size:" << label_size_ << ", stream id:" << stream_id; |
|
|
|
std::vector<TaskInfoPtr> task_info_list; |
|
|
|
cond_ = inputs[0]->addr; |
|
|
|
// todo: need update ge task info define |
|
|
|
auto task_info_ptr = std::make_shared<LabelSwitchTaskInfo>(stream_id, 0); |
|
|
|
// auto task_info_ptr = std::make_shared<LabelSwitchTaskInfo>(stream_id, label_size_, label_list_, cond_); |
|
|
|
auto task_info_ptr = std::make_shared<LabelSwitchTaskInfo>(stream_id, label_size_, label_list_, cond_); |
|
|
|
MS_EXCEPTION_IF_NULL(task_info_ptr); |
|
|
|
task_info_list.emplace_back(task_info_ptr); |
|
|
|
return task_info_list; |
|
|
|
|