You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

symbol_var.cpp 4.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. /**
  2. * \file src/core/impl/graph/symbol_var.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. */
  11. #include "./cg_impl.h"
  12. #include "megbrain/graph/symbol_var.h"
  13. #include "megbrain/opr/basic_arith.h"
  14. #include "megbrain/opr/basic_arith_wrapper.h"
  15. #include "megbrain/opr/tensor_manip.h"
  16. #include "megbrain/opr/io.h"
  17. using namespace mgb;
  18. using namespace cg;
  19. SymbolVar SymbolVar::rename(const std::string &name) const {
  20. m_node->name(name);
  21. return *this;
  22. }
  23. SymbolVar SymbolVar::symshape() const {
  24. return opr::GetVarShape::make(*this);
  25. }
  26. SymbolVar SymbolVar::reshape(const TensorShape &tshape) const {
  27. return opr::Reshape::make(*this, tshape);
  28. }
  29. SymbolVar SymbolVar::reshape(SymbolVar tshape) const {
  30. return opr::Reshape::make(*this, tshape);
  31. }
  32. SymbolVar SymbolVar::broadcast(const TensorShape &tshape) const {
  33. return opr::Broadcast::make(*this, tshape);
  34. }
  35. SymbolVar SymbolVar::broadcast(SymbolVar tshape) const {
  36. return opr::Broadcast::make(*this, tshape);
  37. }
  38. SymbolVar SymbolVar::flatten() const {
  39. return opr::Reshape::make(*this, make_scalar(1), 0);
  40. }
  41. SymbolVar SymbolVar::add_axis(size_t idx) const {
  42. return opr::AxisAddRemove::make(*this,
  43. {opr::AxisAddRemove::AxisDesc::make_add(idx)});
  44. }
/*!
 * \brief try to interpret this var as a statically inferable scalar constant
 *
 * Broadcast and Reshape oprs are first peeled off, since they do not change
 * the (single) element value of a scalar. The remaining var must be
 * const-inferable by the static infer manager and its value must be a
 * scalar after dropping broadcasted (stride-0) axes.
 *
 * \return the scalar value, or None if the var is not an immutable scalar
 */
Maybe<DTypeScalar> SymbolVar::as_immutable_scalar() const {
    using IT = static_infer::InferType;
    auto &&mgr = node()->owner_graph()->static_infer_manager();
    auto ivar = node();
    // walk back through Broadcast/Reshape wrappers to the producing var
    for (; ; ) {
        auto ivar_type = ivar->owner_opr()->dyn_typeinfo();
        if (ivar_type == opr::Broadcast::typeinfo() ||
        ivar_type == opr::Reshape::typeinfo()) {
        ivar = ivar->owner_opr()->input(0);
        } else {
        break;
        }
    }
    auto it = mgr.get_infer_type(ivar);
    if (it.value & IT::CONST) {
        DeviceTensorND ival = mgr.infer_value(ivar);
        // remove broadcasted (stride-0) axes; keep at least one axis so the
        // final stride check below stays valid
        auto layout = ival.layout();
        for (int i = layout.ndim - 1; i >= 0; -- i) {
        if (!layout.stride[i] && layout.ndim >= 2)
        layout.remove_axis_inplace(i);
        }
        // a true scalar, or a 1-d tensor whose only axis is broadcasted,
        // both contain exactly one distinct element
        if (layout.is_scalar() || (layout.ndim == 1 && !layout.stride[0])) {
        return DTypeScalar::make_from_raw(ival.dtype(), ival.raw_ptr());
        }
    }
    return None;
}
  73. Maybe<DTypeScalar> SymbolVar::as_immutable_scalar_require_shape() const {
  74. if (!shape().is_scalar())
  75. return None;
  76. return as_immutable_scalar();
  77. }
  78. SymbolVar SymbolVar::operator + (const SymbolVar &rhs) const {
  79. return opr::add(*this, rhs);
  80. }
  81. SymbolVar SymbolVar::operator - (const SymbolVar &rhs) const {
  82. return opr::sub(*this, rhs);
  83. }
  84. SymbolVar SymbolVar::operator * (const SymbolVar &rhs) const {
  85. return opr::mul(*this, rhs);
  86. }
  87. SymbolVar SymbolVar::operator / (const SymbolVar &rhs) const {
  88. if (dtype().category() == DTypeCategory::INT &&
  89. rhs.dtype().category() == DTypeCategory::INT) {
  90. return opr::floor_div(*this, rhs);
  91. }
  92. return opr::div(*this, rhs);
  93. }
  94. SymbolVar SymbolVar::operator < (const SymbolVar &rhs) const {
  95. return opr::less_than(*this, rhs);
  96. }
  97. SymbolVar SymbolVar::operator <= (const SymbolVar &rhs) const {
  98. return opr::less_equal(*this, rhs);
  99. }
  100. SymbolVar SymbolVar::operator - () const {
  101. return opr::negate(*this);
  102. }
  103. SymbolVar SymbolVar::make_scalar(
  104. DTypeScalar value, ComputingGraph &cg, CompNode cn) {
  105. return opr::ImmutableTensor::make(cg, value, {cn});
  106. }
/*!
 * \brief get the computed value of this var under eager evaluation mode
 *
 * Only available in non-slim builds; the owner graph must have the
 * eager_evaluation option enabled.
 *
 * \return reference to the device tensor holding the var's value
 */
const DeviceTensorND& SymbolVar::eager_eval_get_value() const {
#if MGB_BUILD_SLIM_SERVING
    // slim-serving builds compile out eager evaluation entirely
    mgb_throw(MegBrainError, "eager eval disabled at compile time");
#else
    auto og = ComputingGraphImpl::downcast(node()->owner_graph());
    // the value is only guaranteed to be computed in eager mode
    mgb_assert(og->options().eager_evaluation);
    return node()->dev_tensor();
#endif
}
  116. void VarNodeArrayView::check_idx(size_t idx) const {
  117. mgb_assert(m_begin + idx < m_end, "idx out of range: %zu/%td", idx,
  118. m_end - m_begin);
  119. }
  120. void SymbolVarArrayView::check_idx(size_t idx) const {
  121. mgb_assert(m_begin + idx < m_end, "idx out of range: %zu/%td", idx,
  122. m_end - m_begin);
  123. }
  124. // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台