You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

op_def.cpp 2.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. /**
  2. * \file imperative/src/impl/op_def.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "megbrain/imperative/op_def.h"
  12. #include "megbrain/imperative/ops/opr_attr.h"
  13. #include "./op_trait.h"
  14. namespace mgb {
  15. namespace imperative {
  16. std::shared_ptr<OpDef> OpDef::make_from_op_node(
  17. cg::OperatorNodeBase* node) {
  18. OpTrait* trait;
  19. trait = OpTrait::find_by_typeinfo(node->dyn_typeinfo());
  20. if (!trait) {
  21. // TODO: register `make_from_op_node` for each OperatorNode
  22. // instead of forwarding to OprAttr
  23. trait = OpTrait::find_by_typeinfo(OprAttr::typeinfo());
  24. }
  25. mgb_assert(trait);
  26. return trait->make_from_op_node(node);
  27. }
  28. DispatchMode OpDef::decide_dispatch_mode(
  29. const OpDef& def,
  30. const SmallVector<LogicalTensorDesc>& inputs) {
  31. return def.trait()->decide_dispatch_mode(def, inputs);
  32. }
  33. SmallVector<TensorPtr> OpDef::apply_on_physical_tensor(
  34. const OpDef& def,
  35. SmallVector<TensorPtr> inputs) {
  36. return def.trait()->apply_on_physical_tensor(def, std::move(inputs));
  37. }
  38. void OpDef::apply_on_device_tensornd(
  39. const OpDef& def,
  40. const SmallVector<DeviceTensorND>& inputs,
  41. SmallVector<DeviceTensorND>* outputs) {
  42. def.trait()->apply_on_device_tensornd(def, inputs, outputs);
  43. return;
  44. }
  45. VarNodeArray OpDef::apply_on_var_node(
  46. const OpDef& def,
  47. const VarNodeArray& inputs) {
  48. return def.trait()->apply_on_var_node(def, inputs);
  49. }
  50. std::tuple<SmallVector<LogicalTensorDesc>, bool> OpDef::infer_output_attrs_fallible(
  51. const OpDef& def,
  52. const SmallVector<LogicalTensorDesc>& inputs) {
  53. return def.trait()->infer_output_attrs_fallible(def, inputs);
  54. }
  55. BackwardGraphResult OpDef::make_backward_graph(
  56. const OpDef& def,
  57. const SmallVector<LogicalTensorDesc>& inputs,
  58. const SmallVector<bool>& input_requires_grad,
  59. const SmallVector<bool>& output_has_grad) {
  60. return def.trait()->make_backward_graph(def, inputs, input_requires_grad, output_has_grad);
  61. }
  62. size_t OpDef::hash() const {
  63. return trait()->hash(*this);
  64. }
  65. bool OpDef::is_same_st(const Hashable& rhs) const {
  66. return trait()->is_same_st(*this, static_cast<const OpDef&>(rhs));
  67. }
  68. const OpTrait* OpDef::trait() const {
  69. if (!m_trait) {
  70. m_trait = OpTrait::find_by_typeinfo(dyn_typeinfo());
  71. mgb_throw_if(!m_trait, MegBrainError,
  72. "can not find op_trait by %s", dyn_typeinfo()->name);
  73. }
  74. return m_trait;
  75. }
  76. } // namespace imperative
  77. } // namespace mgb
  78. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。