You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tensor_sanity_check.cpp 4.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. /**
  2. * \file src/core/impl/imperative/tensor_sanity_check.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "megbrain/imperative/tensor_sanity_check.h"
  12. #include "./op_trait.h"
  13. namespace mgb {
  14. namespace imperative {
  15. TensorChecksumCalc::ChecksumResult TensorChecksumCalc::calc(TensorPtr ptr) {
  16. auto&& dt = ptr->dev_tensor();
  17. if (!dt.layout().total_nr_elems()) {
  18. static ChecksumResult empty_checksum;
  19. return empty_checksum;
  20. }
  21. auto span = dt.layout().span();
  22. megdnn::TensorND tensor;
  23. tensor.reset_ptr(dt.raw_ptr() + span.low_byte);
  24. tensor.layout.init_contiguous_stride({span.dist_byte()});
  25. tensor.layout.dtype = dtype::Byte();
  26. DeviceTensorStorage* workspace;
  27. {
  28. MGB_LOCK_GUARD(m_workspace_mtx);
  29. workspace = &m_workspace[std::this_thread::get_id()].storage[ptr->comp_node()];
  30. }
  31. auto comp_node = ptr->comp_node();
  32. comp_node.activate();
  33. auto opr = opr::intl::get_megdnn_global_opr<megdnn::Checksum>(comp_node);
  34. auto workspace_reqsize = opr->get_workspace_in_bytes(tensor.layout);
  35. workspace->comp_node(ptr->comp_node()).ensure_size(workspace_reqsize);
  36. megdnn::Workspace mwk;
  37. if (workspace_reqsize)
  38. mwk = {workspace->ptr(), workspace_reqsize};
  39. return opr->exec(tensor, mwk);
  40. }
  41. class TensorSanityCheckImpl {
  42. public:
  43. std::vector<std::tuple<OpTrait*, std::unique_ptr<ApplyOnPhysicalTensor>>> hook_list;
  44. std::unordered_map<TensorPtr, TensorChecksumCalc::ChecksumResult>
  45. tensor2chksum; // TODO: may increase device memory overhead
  46. TensorSanityCheckImpl() { m_calc = std::make_unique<TensorChecksumCalc>(); }
  47. bool check(TensorPtr p);
  48. private:
  49. std::unique_ptr<TensorChecksumCalc> m_calc;
  50. };
  51. bool TensorSanityCheckImpl::check(TensorPtr p) {
  52. auto&& it = tensor2chksum.find(p);
  53. auto&& chksum = m_calc->calc(p);
  54. if (it == tensor2chksum.end()) {
  55. tensor2chksum[p] = chksum;
  56. return true;
  57. }
  58. return it->second == chksum;
  59. }
  60. void TensorSanityCheck::enable() {
  61. CompNode::sync_all();
  62. OpTrait::for_each_trait([this](OpTrait& trait) {
  63. auto backup = std::make_unique<ApplyOnPhysicalTensor>(
  64. std::move(trait.apply_on_physical_tensor));
  65. trait.apply_on_physical_tensor = ApplyOnPhysicalTensor(
  66. [this, backup = backup.get()](
  67. const OpDef& def, const SmallVector<TensorPtr>& inputs,
  68. SmallVector<LogicalTensorDesc>& output_descs,
  69. const bool& validated) {
  70. for (auto&& i : inputs) {
  71. if (!m_checker->check(i)) {
  72. mgb_throw(
  73. TensorChecksumCalc::Error,
  74. "tensor modified before exec %s",
  75. print_op(def).c_str());
  76. }
  77. }
  78. auto output = (*backup)(def, inputs, output_descs, validated);
  79. for (auto&& i : output) {
  80. mgb_assert(m_checker->check(i));
  81. }
  82. for (auto&& i : inputs) {
  83. if (!m_checker->check(i)) {
  84. mgb_throw(
  85. TensorChecksumCalc::Error,
  86. "tensor modified after exec %s",
  87. print_op(def).c_str());
  88. }
  89. }
  90. return output;
  91. });
  92. m_checker->hook_list.push_back({&trait, std::move(backup)});
  93. });
  94. }
  95. void TensorSanityCheck::disable() {
  96. for (auto&& hook : m_checker->hook_list) {
  97. std::get<0>(hook)->apply_on_physical_tensor = std::move(*std::get<1>(hook));
  98. }
  99. m_checker->tensor2chksum.clear();
  100. m_checker->hook_list.clear();
  101. }
  102. TensorSanityCheck::TensorSanityCheck() {
  103. m_checker = std::make_unique<TensorSanityCheckImpl>();
  104. }
  105. TensorSanityCheck::~TensorSanityCheck() {}
  106. std::string TensorSanityCheck::print_op(const OpDef& def) {
  107. auto* opr_attr = def.try_cast_final<const OprAttr>();
  108. if (opr_attr) {
  109. return std::string("OprAttr:") + opr_attr->type;
  110. }
  111. return def.dyn_typeinfo()->name;
  112. }
  113. } // namespace imperative
  114. } // namespace mgb