You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

op_trait.cpp 3.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. /**
  2. * \file imperative/src/impl/op_trait.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include <sstream>
  12. #include "megbrain/imperative/ops/opr_attr.h"
  13. #include "./op_trait.h"
  14. #include "./proxy_graph_detail.h"
  15. namespace mgb {
  16. namespace imperative {
  17. namespace detail {
  18. struct StaticData {
  19. std::list<OpTrait> registries;
  20. std::unordered_map<std::string, OpTrait*> name2reg;
  21. std::unordered_map<Typeinfo*, OpTrait*> type2reg;
  22. };
  23. // use "Construct On First Use" to prevent "static initialization order fiasco"
  24. // (i.e., ensure global registry was initialized before calling opr registration)
  25. StaticData& static_data() {
  26. static StaticData data;
  27. return data;
  28. }
  29. } // detail
  30. OpTrait::OpTrait(const char* name_): name(name_) {}
  31. OpTrait* OpTrait::find_by_typeinfo(Typeinfo* type) {
  32. auto&& type2reg = detail::static_data().type2reg;
  33. auto iter = type2reg.find(type);
  34. if (iter == type2reg.end()) {
  35. return nullptr;
  36. }
  37. return iter->second;
  38. }
  39. OpTrait* OpTrait::find_by_name(const char* name) {
  40. auto&& name2reg = detail::static_data().name2reg;
  41. auto iter = name2reg.find(name);
  42. if (iter == name2reg.find(name)) {
  43. return nullptr;
  44. }
  45. return iter->second;
  46. }
  47. void OpTrait::for_each_trait(thin_function<void(OpTrait&)> visitor){
  48. for(auto& trait: detail::static_data().registries){
  49. visitor(trait);
  50. }
  51. }
  52. OpTraitRegistry& OpTraitRegistry::fallback() {
  53. if (trait->apply_on_var_node) {
  54. // fallback to proxy graph impl
  55. if (!trait->apply_on_physical_tensor) {
  56. trait->apply_on_physical_tensor =
  57. proxy_graph_detail::apply_on_physical_tensor;
  58. }
  59. if (!trait->infer_output_attrs_fallible) {
  60. trait->infer_output_attrs_fallible =
  61. proxy_graph_detail::infer_output_attrs_fallible;
  62. }
  63. if (!trait->make_backward_graph) {
  64. trait->make_backward_graph =
  65. proxy_graph_detail::make_backward_graph;
  66. }
  67. }
  68. return *this;
  69. }
  70. void OpTraitRegistry::do_insert(Typeinfo* type) {
  71. auto&& sd = detail::static_data();
  72. auto ret = sd.type2reg.emplace(type, trait);
  73. mgb_assert(ret.second || ret.first->second == trait,
  74. "OpTrait for %s has already been registered", type->name);
  75. }
  76. OpTraitRegistry OpTraitRegistry::do_insert(const char* name) {
  77. auto&& sd = detail::static_data();
  78. if (name) {
  79. auto iter = sd.name2reg.find(name);
  80. if (iter != sd.name2reg.end()) {
  81. return {iter->second};
  82. }
  83. }
  84. sd.registries.emplace_back(name);
  85. auto ret = &sd.registries.back();
  86. if (name) {
  87. sd.name2reg.emplace(name, ret);
  88. }
  89. return {ret};
  90. }
  91. } // namespace imperative
  92. } // namespace mgb
  93. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。