You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

op_trait.cpp 4.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
  1. /**
  2. * \file imperative/src/impl/op_trait.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include <sstream>
  12. #include "megbrain/imperative/ops/opr_attr.h"
  13. #include "./op_trait.h"
  14. #include "./proxy_graph_detail.h"
  15. namespace mgb {
  16. namespace imperative {
  17. namespace detail {
  18. struct StaticData {
  19. std::list<OpTrait> registries;
  20. std::unordered_map<const char*, OpTrait*> name2reg;
  21. std::unordered_map<Typeinfo*, OpTrait*> type2reg;
  22. };
  23. // use "Construct On First Use" to prevent "static initialization order fiasco"
  24. // (i.e., ensure global registry was initialized before calling opr registration)
  25. StaticData& static_data() {
  26. static StaticData data;
  27. return data;
  28. }
  29. template<typename T>
  30. struct __not_implementation__;
  31. template<typename RType, typename ...Args>
  32. struct __not_implementation__<RType(Args...)> {
  33. static RType raise(Args ...) {
  34. mgb_throw(MegBrainError, "Not Implemented");
  35. }
  36. };
  37. } // detail
  38. OpTrait::OpTrait(const char* name_): name(name_) {}
  39. OpTrait* OpTrait::find_by_typeinfo(Typeinfo* type) {
  40. auto&& type2reg = detail::static_data().type2reg;
  41. auto iter = type2reg.find(type);
  42. if (iter == type2reg.end()) {
  43. return nullptr;
  44. }
  45. return iter->second;
  46. }
  47. OpTrait* OpTrait::find_by_name(const char* name) {
  48. auto&& name2reg = detail::static_data().name2reg;
  49. auto iter = name2reg.find(name);
  50. if (iter == name2reg.find(name)) {
  51. return nullptr;
  52. }
  53. return iter->second;
  54. }
  55. void OpTrait::for_each_trait(thin_function<void(OpTrait&)> visitor){
  56. for(auto& trait: detail::static_data().registries){
  57. visitor(trait);
  58. }
  59. }
  60. OpTraitRegistry& OpTraitRegistry::finalize() {
  61. std::ostringstream msg;
  62. #define CHECK(field) if (!trait->field) { \
  63. msg << ", " #field; \
  64. trait->field = \
  65. detail::__not_implementation__<decltype(OpDef::field)>::raise; \
  66. }
  67. CHECK(make_from_op_node);
  68. CHECK(apply_on_physical_tensor);
  69. CHECK(exec);
  70. CHECK(apply_on_var_node);
  71. CHECK(infer_output_attrs_fallible);
  72. CHECK(infer_output_attrs);
  73. CHECK(make_backward_graph);
  74. #undef CHECK
  75. if (msg.tellp() > 0) {
  76. mgb_log_warn(
  77. "%s op trait missing: %s",
  78. trait->name ? trait->name : "(anonymous)",
  79. msg.str().c_str() + 2 /* skip first ", " */);
  80. }
  81. return *this;
  82. }
  83. SmallVector<TensorPtr> fallback_apply_on_physical_tensor(
  84. const OpDef& def,
  85. const SmallVector<TensorPtr>& inputs) {
  86. auto desc = OpDef::infer_output_attrs(def, inputs);
  87. SmallVector<TensorPtr> outputs;
  88. for (auto&& i : desc) {
  89. outputs.push_back(Tensor::make(i.layout, i.comp_node));
  90. }
  91. OpDef::exec(def, inputs, outputs);
  92. return outputs;
  93. }
  94. SmallVector<LogicalTensorDesc> fallback_infer_output_attrs(const OpDef& def,
  95. const SmallVector<TensorPtr>& inputs){
  96. SmallVector<LogicalTensorDesc> input_descs;
  97. for(auto&& input: inputs){
  98. input_descs.push_back({input->layout(), input->comp_node()});
  99. }
  100. return input_descs;
  101. }
  102. OpTraitRegistry& OpTraitRegistry::fallback() {
  103. if (!trait->exec && trait->apply_on_var_node) {
  104. trait->exec = proxy_graph_detail::exec;
  105. }
  106. if (!trait->infer_output_attrs && trait->apply_on_var_node) {
  107. trait->infer_output_attrs = proxy_graph_detail::infer_output_attrs;
  108. }
  109. if (!trait->infer_output_attrs_fallible && trait->apply_on_var_node) {
  110. trait->infer_output_attrs_fallible = proxy_graph_detail::infer_output_attrs_fallible;
  111. }
  112. if (!trait->make_backward_graph && trait->apply_on_var_node) {
  113. trait->make_backward_graph = proxy_graph_detail::make_backward_graph;
  114. }
  115. if (!trait->apply_on_physical_tensor && trait->infer_output_attrs && trait->exec) {
  116. trait->apply_on_physical_tensor = fallback_apply_on_physical_tensor;
  117. }
  118. if(!trait->infer_output_attrs && trait->infer_output_attrs_fallible){
  119. trait->infer_output_attrs = fallback_infer_output_attrs;
  120. }
  121. return *this;
  122. }
  123. void OpTraitRegistry::do_insert(Typeinfo* type) {
  124. auto&& sd = detail::static_data();
  125. mgb_assert(sd.type2reg.emplace(type, trait).second);
  126. }
  127. OpTraitRegistry OpTraitRegistry::do_insert(const char* name) {
  128. auto&& sd = detail::static_data();
  129. if (name) {
  130. mgb_assert(!sd.name2reg.count(name),
  131. "duplicated opr trait %s", name);
  132. }
  133. sd.registries.emplace_back(name);
  134. auto ret = &sd.registries.back();
  135. sd.name2reg.emplace(name, ret);
  136. return {ret};
  137. }
  138. } // namespace imperative
  139. } // namespace mgb
  140. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台