You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

op_trait.cpp 3.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. /**
  2. * \file imperative/src/impl/op_trait.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include <exception>
  12. #include <sstream>
  13. #include <stdexcept>
  14. #include "megbrain/imperative/op_def.h"
  15. #include "megbrain/imperative/ops/opr_attr.h"
  16. #include "megbrain/imperative/proxy_graph_detail.h"
  17. #include "megbrain/tensor.h"
  18. #include "./op_trait.h"
  19. namespace mgb {
  20. namespace imperative {
  21. namespace detail {
  22. struct StaticData {
  23. std::list<OpTrait> registries;
  24. std::unordered_map<std::string, OpTrait*> name2reg;
  25. std::unordered_map<Typeinfo*, OpTrait*> type2reg;
  26. };
  27. // use "Construct On First Use" to prevent "static initialization order fiasco"
  28. // (i.e., ensure global registry was initialized before calling opr registration)
  29. StaticData& static_data() {
  30. static StaticData data;
  31. return data;
  32. }
  33. } // detail
  34. OpTrait::OpTrait(const char* name_): name(name_) {}
  35. OpTrait* OpTrait::find_by_typeinfo(Typeinfo* type) {
  36. auto&& type2reg = detail::static_data().type2reg;
  37. auto iter = type2reg.find(type);
  38. if (iter == type2reg.end()) {
  39. return nullptr;
  40. }
  41. return iter->second;
  42. }
  43. OpTrait* OpTrait::find_by_name(const char* name) {
  44. auto&& name2reg = detail::static_data().name2reg;
  45. auto iter = name2reg.find(name);
  46. if (iter == name2reg.find(name)) {
  47. return nullptr;
  48. }
  49. return iter->second;
  50. }
  51. void OpTrait::for_each_trait(thin_function<void(OpTrait&)> visitor){
  52. for(auto& trait: detail::static_data().registries){
  53. visitor(trait);
  54. }
  55. }
  56. DispatchMode fallback_decide_dispatch_mode(
  57. const OpDef& def,
  58. const SmallVector<LogicalTensorDesc>& inputs) {
  59. return KERNEL;
  60. }
  61. OpTraitRegistry& OpTraitRegistry::fallback() {
  62. if (trait->apply_on_var_node) {
  63. // fallback to proxy graph impl
  64. if (!trait->apply_on_physical_tensor) {
  65. trait->apply_on_physical_tensor =
  66. proxy_graph_detail::apply_on_physical_tensor;
  67. }
  68. if (!trait->execute) {
  69. trait->execute = proxy_graph_detail::execute;
  70. }
  71. if (!trait->infer_output_mem_desc) {
  72. trait->infer_output_mem_desc =
  73. proxy_graph_detail::infer_output_mem_desc;
  74. }
  75. if (!trait->infer_output_attrs_fallible) {
  76. trait->infer_output_attrs_fallible =
  77. proxy_graph_detail::infer_output_attrs_fallible;
  78. }
  79. if (!trait->make_backward_graph) {
  80. trait->make_backward_graph =
  81. proxy_graph_detail::make_backward_graph;
  82. }
  83. }
  84. if (!trait->decide_dispatch_mode) {
  85. trait->decide_dispatch_mode = fallback_decide_dispatch_mode;
  86. }
  87. return *this;
  88. }
  89. void OpTraitRegistry::do_insert(Typeinfo* type) {
  90. auto&& sd = detail::static_data();
  91. auto ret = sd.type2reg.emplace(type, trait);
  92. mgb_assert(ret.second || ret.first->second == trait,
  93. "OpTrait for %s has already been registered", type->name);
  94. }
  95. OpTraitRegistry OpTraitRegistry::do_insert(const char* name) {
  96. auto&& sd = detail::static_data();
  97. if (name) {
  98. auto iter = sd.name2reg.find(name);
  99. if (iter != sd.name2reg.end()) {
  100. return {iter->second};
  101. }
  102. }
  103. sd.registries.emplace_back(name);
  104. auto ret = &sd.registries.back();
  105. if (name) {
  106. sd.name2reg.emplace(name, ret);
  107. }
  108. return {ret};
  109. }
  110. } // namespace imperative
  111. } // namespace mgb
  112. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台