You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tensor_sanity_check.cpp 4.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. /**
  2. * \file src/core/impl/imperative/tensor_sanity_check.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "megbrain/imperative/tensor_sanity_check.h"
  12. #include "./op_trait.h"
  13. namespace mgb {
  14. namespace imperative {
  15. TensorChecksumCalc::ChecksumResult TensorChecksumCalc::calc(TensorPtr ptr) {
  16. auto&& dt = ptr->dev_tensor();
  17. if (!dt.layout().total_nr_elems()) {
  18. static ChecksumResult empty_checksum;
  19. return empty_checksum;
  20. }
  21. auto span = dt.layout().span();
  22. megdnn::TensorND tensor;
  23. tensor.reset_ptr(dt.raw_ptr() + span.low_byte);
  24. tensor.layout.init_contiguous_stride({span.dist_byte()});
  25. tensor.layout.dtype = dtype::Byte();
  26. DeviceTensorStorage* workspace;
  27. {
  28. MGB_LOCK_GUARD(m_workspace_mtx);
  29. workspace = &m_workspace[std::this_thread::get_id()].storage[ptr->comp_node()];
  30. }
  31. auto comp_node = ptr->comp_node();
  32. comp_node.activate();
  33. auto opr = opr::intl::get_megdnn_global_opr<megdnn::Checksum>(comp_node);
  34. auto workspace_reqsize = opr->get_workspace_in_bytes(tensor.layout);
  35. workspace->comp_node(ptr->comp_node()).ensure_size(workspace_reqsize);
  36. megdnn::Workspace mwk;
  37. if (workspace_reqsize)
  38. mwk = {workspace->ptr(), workspace_reqsize};
  39. return opr->exec(tensor, mwk);
  40. }
  41. class TensorSanityCheckImpl {
  42. public:
  43. std::vector<std::tuple<OpTrait*, std::unique_ptr<ApplyOnPhysicalTensor>>> hook_list;
  44. std::unordered_map<TensorPtr, TensorChecksumCalc::ChecksumResult>
  45. tensor2chksum; // TODO: may increase device memory overhead
  46. TensorSanityCheckImpl() { m_calc = std::make_unique<TensorChecksumCalc>(); }
  47. bool check(TensorPtr p);
  48. private:
  49. std::unique_ptr<TensorChecksumCalc> m_calc;
  50. };
  51. bool TensorSanityCheckImpl::check(TensorPtr p) {
  52. auto&& it = tensor2chksum.find(p);
  53. auto&& chksum = m_calc->calc(p);
  54. if (it == tensor2chksum.end()) {
  55. tensor2chksum[p] = chksum;
  56. return true;
  57. }
  58. return it->second == chksum;
  59. }
  60. void TensorSanityCheck::enable() {
  61. CompNode::sync_all();
  62. OpTrait::for_each_trait([this](OpTrait& trait) {
  63. auto backup = std::make_unique<ApplyOnPhysicalTensor>(
  64. std::move(trait.apply_on_physical_tensor));
  65. trait.apply_on_physical_tensor = ApplyOnPhysicalTensor(
  66. [this, backup = backup.get()](
  67. const OpDef& def, const SmallVector<TensorPtr>& inputs) {
  68. for (auto&& i : inputs) {
  69. if (!m_checker->check(i)) {
  70. mgb_throw(
  71. TensorChecksumCalc::Error,
  72. "tensor modified before exec %s",
  73. print_op(def).c_str());
  74. }
  75. }
  76. auto output = (*backup)(def, inputs);
  77. for (auto&& i : output) {
  78. mgb_assert(m_checker->check(i));
  79. }
  80. for (auto&& i : inputs) {
  81. if (!m_checker->check(i)) {
  82. mgb_throw(
  83. TensorChecksumCalc::Error,
  84. "tensor modified after exec %s",
  85. print_op(def).c_str());
  86. }
  87. }
  88. return output;
  89. });
  90. m_checker->hook_list.push_back({&trait, std::move(backup)});
  91. });
  92. }
  93. void TensorSanityCheck::disable() {
  94. for (auto&& hook : m_checker->hook_list) {
  95. std::get<0>(hook)->apply_on_physical_tensor = std::move(*std::get<1>(hook));
  96. }
  97. m_checker->tensor2chksum.clear();
  98. m_checker->hook_list.clear();
  99. }
  100. TensorSanityCheck::TensorSanityCheck() {
  101. m_checker = std::make_unique<TensorSanityCheckImpl>();
  102. }
  103. TensorSanityCheck::~TensorSanityCheck() {}
  104. std::string TensorSanityCheck::print_op(const OpDef& def) {
  105. auto* opr_attr = def.try_cast_final<const OprAttr>();
  106. if (opr_attr) {
  107. return std::string("OprAttr:") + opr_attr->type;
  108. }
  109. return def.dyn_typeinfo()->name;
  110. }
  111. } // namespace imperative
  112. } // namespace mgb