You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

node.cc 2.8 kB

(line numbers 1–92)
  1. /**
  2. * Copyright 2021-2022 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "common/graph_kernel/model/node.h"
  17. #include <sstream>
  18. #include <utility>
  19. namespace mindspore::graphkernel::inner {
  20. void Node::SetBaseInfo(NodeBase baseinfo) {
  21. this->shape = std::move(baseinfo.shape);
  22. this->type = std::move(baseinfo.type);
  23. this->format = std::move(baseinfo.format);
  24. }
  25. std::string Node::ToString() const {
  26. std::ostringstream oss;
  27. oss << debug_name() << "[";
  28. for (size_t i = 0; i < shape.size(); i++) {
  29. oss << shape[i];
  30. if (i + 1 < shape.size()) oss << ",";
  31. }
  32. oss << "]{" << TypeIdToString(type) << "x" << format << "}";
  33. return oss.str();
  34. }
  35. void Node::AddInput(const NodePtr &new_input) {
  36. MS_EXCEPTION_IF_NULL(new_input);
  37. new_input->AddUser(this, inputs_.size());
  38. (void)inputs_.emplace_back(new_input);
  39. }
  40. void Node::SetInput(size_t i, const NodePtr &new_input) {
  41. MS_EXCEPTION_IF_NULL(new_input);
  42. if (i >= inputs_.size()) {
  43. MS_LOG(EXCEPTION) << "The index " << i << " is out of the inputs range [0, " << inputs_.size() << ")";
  44. }
  45. auto &old_input = inputs_[i];
  46. old_input->RemoveUser(this, i);
  47. new_input->AddUser(this, i);
  48. inputs_[i] = new_input;
  49. }
  50. void Node::SetInputs(const NodePtrList &inputs) {
  51. ClearInputs();
  52. inputs_.reserve(inputs.size());
  53. for (const auto &inp : inputs) {
  54. AddInput(inp);
  55. }
  56. }
  57. void Node::ClearInputs() {
  58. if (!inputs_.empty()) {
  59. // remove the original inputs
  60. for (size_t i = 0; i < inputs_.size(); i++) {
  61. inputs_[i]->RemoveUser(this, i);
  62. }
  63. inputs_.clear();
  64. }
  65. }
  66. void Node::ReplaceWith(const NodePtr &other_node) {
  67. if (this->users_.empty()) return;
  68. // the users_ will be changed, so we copy the users before traversal
  69. auto users = this->users_;
  70. for (auto &user : users) {
  71. for (auto idx : user.second) {
  72. user.first->SetInput(idx, other_node);
  73. }
  74. }
  75. }
  76. void Node::RemoveUser(Node *const user, size_t index) {
  77. if (auto iter = users_.find(user); iter != users_.end()) {
  78. (void)iter->second.erase(index);
  79. if (iter->second.empty()) {
  80. (void)users_.erase(iter);
  81. }
  82. }
  83. }
  84. } // namespace mindspore::graphkernel::inner