You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

stream_switch.cc 3.7 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "kernel/mng/stream_switch.h"
  17. #include <memory>
  18. #include <vector>
  19. #include "runtime/stream.h"
  20. #include "framework/ge_runtime/task_info.h"
  21. #include "session/anf_runtime_algorithm.h"
  22. #include "common/utils.h"
  23. using ge::model_runner::StreamSwitchTaskInfo;
  24. using StreamSwitchTaskInfoPtr = std::shared_ptr<StreamSwitchTaskInfo>;
  25. namespace mindspore {
  26. namespace kernel {
  27. StreamSwitchKernel::StreamSwitchKernel() {
  28. cond_ = RT_EQUAL;
  29. true_stream_index_ = 0;
  30. data_type_ = RT_SWITCH_INT32;
  31. }
  32. StreamSwitchKernel::~StreamSwitchKernel() {}
  33. bool StreamSwitchKernel::Init(const AnfNodePtr &anf_node) {
  34. MS_EXCEPTION_IF_NULL(anf_node);
  35. MS_LOG(INFO) << "stream switch op init start";
  36. auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
  37. MS_EXCEPTION_IF_NULL(primitive);
  38. cond_ = tagRtCondition(GetValue<int>(primitive->GetAttr(kAttrSwitchCondition)));
  39. true_stream_index_ = GetValue<uint32_t>(primitive->GetAttr(kAttrTrueBranchStream));
  40. data_type_ = tagRtSwitchDataType(GetValue<int>(primitive->GetAttr(kAttrDataType)));
  41. MS_LOG(INFO) << "cond_:" << static_cast<int>(cond_) << ", true_stream_index_:" << true_stream_index_
  42. << ", data_type_:" << static_cast<int>(data_type_);
  43. return true;
  44. }
  45. bool StreamSwitchKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
  46. const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) {
  47. MS_LOG(INFO) << "stream switch op launch start";
  48. if (inputs.size() != 2) {
  49. MS_LOG(ERROR) << "Stream switch inputs size is " << inputs.size() << ", only support 2";
  50. }
  51. void *loop_cnt = inputs[0]->addr;
  52. void *ites_per_loop = inputs[1]->addr;
  53. auto stream = reinterpret_cast<rtStream_t>(stream_ptr);
  54. rtStream_t true_stream_ = kernel::TaskStream::GetInstance()->gen_stream_list()[true_stream_index_];
  55. rtError_t status = rtStreamSwitchEx(loop_cnt, cond_, ites_per_loop, true_stream_, stream, data_type_);
  56. if (status != RT_ERROR_NONE) {
  57. MS_LOG(ERROR) << "Stream switch failed!";
  58. return false;
  59. }
  60. return true;
  61. }
  62. std::vector<TaskInfoPtr> StreamSwitchKernel::GenTask(const std::vector<AddressPtr> &inputs,
  63. const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
  64. uint32_t stream_id) {
  65. MS_LOG(INFO) << "StreamSwitchKernel GenTask start";
  66. if (inputs.size() != 2) {
  67. MS_LOG(ERROR) << "stream switch inputs size is " << inputs.size() << ", is not two";
  68. }
  69. MS_EXCEPTION_IF_NULL(inputs[0]);
  70. MS_EXCEPTION_IF_NULL(inputs[1]);
  71. auto loop_cnt = inputs[0]->addr;
  72. auto ites_per_loop = inputs[1]->addr;
  73. MS_LOG(INFO) << "cond_:" << static_cast<int>(cond_) << ", true_stream_index_:" << true_stream_index_
  74. << ", stream_id:" << stream_id;
  75. std::shared_ptr<StreamSwitchTaskInfo> task_info_ptr =
  76. std::make_shared<StreamSwitchTaskInfo>(stream_id, true_stream_index_, loop_cnt, ites_per_loop, cond_, data_type_);
  77. MS_EXCEPTION_IF_NULL(task_info_ptr);
  78. return {task_info_ptr};
  79. }
  80. } // namespace kernel
  81. } // namespace mindspore