You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

cell.cc 5.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  #include "include/api/cell.h"
  #include <memory>
  #include <utility>
  #include <vector>
  #include "include/api/context.h"
  #include "cxx_api/factory.h"
  #include "cxx_api/graph/graph_impl.h"
  20. namespace mindspore {
  21. std::vector<Output> CellBase::operator()(const std::vector<Input> &inputs) const { return Clone()->Construct(inputs); }
  22. ParameterCell::ParameterCell(const ParameterCell &cell) {
  23. auto tmp_ptr = cell.tensor_.Clone();
  24. tensor_ = *tmp_ptr;
  25. MSTensor::DestroyTensorPtr(tmp_ptr);
  26. }
  27. ParameterCell &ParameterCell::operator=(const ParameterCell &cell) {
  28. if (&cell == this) {
  29. return *this;
  30. }
  31. auto tmp_ptr = cell.tensor_.Clone();
  32. tensor_ = *tmp_ptr;
  33. MSTensor::DestroyTensorPtr(tmp_ptr);
  34. return *this;
  35. }
  36. ParameterCell::ParameterCell(ParameterCell &&cell) : tensor_(cell.tensor_) {}
  37. ParameterCell &ParameterCell::operator=(ParameterCell &&cell) {
  38. if (&cell == this) {
  39. return *this;
  40. }
  41. tensor_ = cell.tensor_;
  42. return *this;
  43. }
  44. ParameterCell::ParameterCell(const MSTensor &tensor) {
  45. auto tmp_ptr = tensor.Clone();
  46. tensor_ = *tmp_ptr;
  47. MSTensor::DestroyTensorPtr(tmp_ptr);
  48. }
  49. ParameterCell &ParameterCell::operator=(const MSTensor &tensor) {
  50. auto tmp_ptr = tensor.Clone();
  51. tensor_ = *tmp_ptr;
  52. MSTensor::DestroyTensorPtr(tmp_ptr);
  53. return *this;
  54. }
  55. ParameterCell::ParameterCell(MSTensor &&tensor) : tensor_(tensor) {}
  56. ParameterCell &ParameterCell::operator=(MSTensor &&tensor) {
  57. tensor_ = tensor;
  58. return *this;
  59. }
  60. GraphCell::GraphCell(const Graph &graph) : graph_(std::make_shared<Graph>(graph)) { MS_EXCEPTION_IF_NULL(graph_); }
  61. GraphCell::GraphCell(const std::shared_ptr<Graph> &graph) : graph_(graph) { MS_EXCEPTION_IF_NULL(graph_); }
  62. GraphCell::GraphCell(Graph &&graph) : graph_(std::make_shared<Graph>(graph)) { MS_EXCEPTION_IF_NULL(graph_); }
  63. void GraphCell::SetContext(const std::shared_ptr<Context> &context) {
  64. if (executor_ == nullptr) {
  65. executor_ = Factory<GraphCell::GraphImpl>::Instance().Create(g_device_target);
  66. if (executor_ == nullptr) {
  67. MS_LOG(ERROR) << "Create graph impl for device target " << g_device_target << " failed.";
  68. return;
  69. }
  70. executor_->SetGraph(graph_);
  71. }
  72. executor_->SetContext(context);
  73. }
  74. Status GraphCell::Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
  75. if (executor_ == nullptr) {
  76. executor_ = Factory<GraphCell::GraphImpl>::Instance().Create(g_device_target);
  77. if (executor_ == nullptr) {
  78. MS_LOG(ERROR) << "Create graph impl for device target " << g_device_target << " failed.";
  79. return kMEFailed;
  80. }
  81. executor_->SetGraph(graph_);
  82. }
  83. return executor_->Run(inputs, outputs);
  84. }
  85. Status GraphCell::Load(uint32_t device_id) {
  86. if (executor_ == nullptr) {
  87. executor_ = Factory<GraphCell::GraphImpl>::Instance().Create(g_device_target);
  88. if (executor_ == nullptr) {
  89. MS_LOG(ERROR) << "Create graph impl for device target " << g_device_target << " failed.";
  90. return kMEFailed;
  91. }
  92. executor_->SetGraph(graph_);
  93. }
  94. return executor_->Load(device_id);
  95. }
  96. std::vector<MSTensor> GraphCell::GetInputs() {
  97. if (executor_ == nullptr) {
  98. executor_ = Factory<GraphCell::GraphImpl>::Instance().Create(g_device_target);
  99. if (executor_ == nullptr) {
  100. MS_LOG(ERROR) << "Create graph impl for device target " << g_device_target << " failed.";
  101. return {};
  102. }
  103. executor_->SetGraph(graph_);
  104. }
  105. return executor_->GetInputs();
  106. }
  107. std::vector<MSTensor> GraphCell::GetOutputs() {
  108. if (executor_ == nullptr) {
  109. executor_ = Factory<GraphCell::GraphImpl>::Instance().Create(g_device_target);
  110. if (executor_ == nullptr) {
  111. MS_LOG(ERROR) << "Create graph impl for device target " << g_device_target << " failed.";
  112. return {};
  113. }
  114. executor_->SetGraph(graph_);
  115. }
  116. return executor_->GetOutputs();
  117. }
  118. InputAndOutput::InputAndOutput() : cell_(nullptr), prev_(), index_(-1) {}
  119. InputAndOutput::InputAndOutput(const MSTensor &tensor) : prev_(), index_(-1) {
  120. auto tmp_ptr = tensor.Clone();
  121. cell_ = std::make_shared<ParameterCell>(*tmp_ptr);
  122. MSTensor::DestroyTensorPtr(tmp_ptr);
  123. }
  124. InputAndOutput::InputAndOutput(MSTensor &&tensor)
  125. : cell_(std::make_shared<ParameterCell>(tensor)), prev_(), index_(-1) {}
  126. InputAndOutput::InputAndOutput(const std::shared_ptr<CellBase> &cell, const std::vector<InputAndOutput> &prev,
  127. int32_t index)
  128. : cell_(cell), prev_(prev), index_(index) {}
  129. } // namespace mindspore