You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

stream_switch.cc 4.2 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "kernel/rts/stream_switch.h"
  17. #include <memory>
  18. #include <vector>
  19. #include "runtime/stream.h"
  20. #include "framework/ge_runtime/task_info.h"
  21. #include "session/anf_runtime_algorithm.h"
  22. #include "common/utils.h"
  23. using ge::model_runner::StreamSwitchTaskInfo;
  24. using StreamSwitchTaskInfoPtr = std::shared_ptr<StreamSwitchTaskInfo>;
  25. namespace mindspore {
  26. namespace kernel {
  27. StreamSwitchKernel::StreamSwitchKernel() {
  28. cond_ = RT_EQUAL;
  29. true_stream_index_ = 0;
  30. data_type_ = RT_SWITCH_INT32;
  31. }
  32. StreamSwitchKernel::~StreamSwitchKernel() {}
  33. bool StreamSwitchKernel::Init(const AnfNodePtr &anf_node) {
  34. MS_EXCEPTION_IF_NULL(anf_node);
  35. MS_LOG(INFO) << "stream switch op init start";
  36. auto primitive = AnfAlgo::GetCNodePrimitive(anf_node);
  37. MS_EXCEPTION_IF_NULL(primitive);
  38. if (!AnfAlgo::HasNodeAttr(kAttrSwitchCondition, anf_node->cast<CNodePtr>())) {
  39. MS_LOG(EXCEPTION) << "StreamSwitchKernel has no attr kAttrSwitchCondition";
  40. }
  41. cond_ = tagRtCondition(GetValue<int>(primitive->GetAttr(kAttrSwitchCondition)));
  42. if (!AnfAlgo::HasNodeAttr(kAttrTrueBranchStream, anf_node->cast<CNodePtr>())) {
  43. MS_LOG(EXCEPTION) << "StreamSwitchKernel has no attr kAttrTrueBranchStream";
  44. }
  45. true_stream_index_ = GetValue<uint32_t>(primitive->GetAttr(kAttrTrueBranchStream));
  46. if (!AnfAlgo::HasNodeAttr(kAttrDataType, anf_node->cast<CNodePtr>())) {
  47. MS_LOG(EXCEPTION) << "StreamSwitchKernel has no attr kAttrDataType";
  48. }
  49. data_type_ = tagRtSwitchDataType(GetValue<int>(primitive->GetAttr(kAttrDataType)));
  50. MS_LOG(INFO) << "cond_:" << static_cast<int>(cond_) << ", true_stream_index_:" << true_stream_index_
  51. << ", data_type_:" << static_cast<int>(data_type_);
  52. return true;
  53. }
  54. bool StreamSwitchKernel::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
  55. const std::vector<AddressPtr> &outputs, void *stream_ptr) {
  56. MS_LOG(INFO) << "stream switch op launch start";
  57. if (inputs.size() != 2) {
  58. MS_LOG(EXCEPTION) << "Stream switch inputs size is " << inputs.size() << ", only support 2";
  59. }
  60. void *loop_cnt = inputs[0]->addr;
  61. void *ites_per_loop = inputs[1]->addr;
  62. rtStream_t true_stream_ = kernel::TaskStream::GetInstance()->gen_stream_list()[true_stream_index_];
  63. rtError_t status = rtStreamSwitchEx(loop_cnt, cond_, ites_per_loop, true_stream_, stream_ptr, data_type_);
  64. if (status != RT_ERROR_NONE) {
  65. MS_LOG(ERROR) << "Stream switch failed!";
  66. return false;
  67. }
  68. return true;
  69. }
  70. std::vector<TaskInfoPtr> StreamSwitchKernel::GenTask(const std::vector<AddressPtr> &inputs,
  71. const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
  72. uint32_t stream_id) {
  73. MS_LOG(INFO) << "StreamSwitchKernel GenTask start";
  74. if (inputs.size() != 2) {
  75. MS_LOG(EXCEPTION) << "stream switch inputs size is " << inputs.size() << ", is not two";
  76. }
  77. stream_id_ = stream_id;
  78. MS_EXCEPTION_IF_NULL(inputs[0]);
  79. MS_EXCEPTION_IF_NULL(inputs[1]);
  80. auto loop_cnt = inputs[0]->addr;
  81. auto ites_per_loop = inputs[1]->addr;
  82. MS_LOG(INFO) << "cond_:" << static_cast<int>(cond_) << ", true_stream_index_:" << true_stream_index_
  83. << ", stream_id:" << stream_id;
  84. std::shared_ptr<StreamSwitchTaskInfo> task_info_ptr =
  85. std::make_shared<StreamSwitchTaskInfo>(stream_id, true_stream_index_, loop_cnt, ites_per_loop, cond_, data_type_);
  86. MS_EXCEPTION_IF_NULL(task_info_ptr);
  87. return {task_info_ptr};
  88. }
  89. } // namespace kernel
  90. } // namespace mindspore