|
|
|
@@ -44,7 +44,7 @@ bool LabelSwitchKernel::Init(const AnfNodePtr &anf_node) { |
|
|
|
auto primitive = AnfAlgo::GetCNodePrimitive(anf_node); |
|
|
|
MS_EXCEPTION_IF_NULL(primitive); |
|
|
|
label_list_ = GetValue<std::vector<uint32_t>>(primitive->GetAttr(kAttrLabelSwitchList)); |
|
|
|
label_size_ = label_list_.size(); |
|
|
|
label_size_ = SizeToUint(label_list_.size()); |
|
|
|
MS_LOG(INFO) << "LabelSwitchKernel get attr label size:" << label_size_; |
|
|
|
for (auto label : label_list_) { |
|
|
|
MS_LOG(INFO) << "label: " << label; |
|
|
|
@@ -52,16 +52,15 @@ bool LabelSwitchKernel::Init(const AnfNodePtr &anf_node) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
bool LabelSwitchKernel::Launch(const std::vector<AddressPtr> & /*inputs*/, |
|
|
|
const std::vector<AddressPtr> & /*workspace*/, |
|
|
|
const std::vector<AddressPtr> & /*outputs*/, void * /*stream_ptr*/) { |
|
|
|
bool LabelSwitchKernel::Launch(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &, |
|
|
|
const std::vector<AddressPtr> &, void *) { |
|
|
|
MS_LOG(INFO) << "LabelSwitchKernel launch"; |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<TaskInfoPtr> LabelSwitchKernel::GenTask(const std::vector<AddressPtr> &inputs, |
|
|
|
const std::vector<AddressPtr> &workspace, |
|
|
|
const std::vector<AddressPtr> &outputs, uint32_t stream_id) { |
|
|
|
const std::vector<AddressPtr> &, const std::vector<AddressPtr> &, |
|
|
|
uint32_t stream_id) { |
|
|
|
MS_LOG(INFO) << "LabelSwitchKernel GenTask label size:" << label_size_ << ", stream id:" << stream_id; |
|
|
|
std::vector<TaskInfoPtr> task_info_list; |
|
|
|
cond_ = inputs[0]->addr; |
|
|
|
|