You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

basic_operators.cpp 3.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. #include "megbrain/imperative/basic_operators.h"
  2. #include "megbrain/imperative/basic_values.h"
  3. namespace mgb {
  4. namespace imperative {
  5. std::string ApplyOp::to_string() const {
  6. return m_op.to_string();
  7. }
  8. std::string GetAttr::to_string() const {
  9. std::string buffer;
  10. const char* attr_name = ([&] {
  11. switch (m_attr) {
  12. case None:
  13. return "None";
  14. case DType:
  15. return "DType";
  16. case Device:
  17. return "Device";
  18. case Shape:
  19. return "Shape";
  20. case Value:
  21. return "Value";
  22. case Data:
  23. return "Data";
  24. default:
  25. buffer = std::to_string(m_attr);
  26. return buffer.c_str();
  27. }
  28. })();
  29. return ssprintf("GetAttr{attr=%s}", attr_name);
  30. }
  31. CreateTensor::CreateTensor(
  32. Kind kind, CompNode device, DType dtype, ValueShape shape, Format format)
  33. : m_kind(kind),
  34. m_device(device),
  35. m_dtype(dtype),
  36. m_shape(shape),
  37. m_format(format) {}
  38. CreateTensor::CreateTensor(Kind kind, CompNode device, TensorLayout layout)
  39. : m_kind(kind),
  40. m_device(device),
  41. m_dtype(layout.dtype),
  42. m_shape(ValueShape::from(layout)),
  43. m_format(Format::Type::DEFAULT) {
  44. mgb_assert(
  45. layout.is_contiguous() || layout.is_empty(), "layout should be contiguous");
  46. }
  47. auto CreateTensor::parse(Span<ValueRef> inputs) const -> Args {
  48. Args result;
  49. for (auto&& input : inputs) {
  50. if (auto host_storage = input.as_ref<HostStorage>()) {
  51. mgb_assert(!result.host, "duplicated host value");
  52. result.host.emplace();
  53. result.host->reset(*host_storage, {shape().as_tensor_shape(), dtype()});
  54. mgb_assert(result.host->layout().ndim, "invalid shape");
  55. } else if (auto device_storage = input.as_ref<DeviceStorage>()) {
  56. mgb_assert(!result.device, "duplicated device value");
  57. result.device.emplace(device(), shape().as_tensor_shape(), dtype());
  58. result.device->reset(*device_storage, {shape().as_tensor_shape(), dtype()});
  59. mgb_assert(result.device->layout().ndim, "invalid shape");
  60. } else {
  61. mgb_throw(
  62. MegBrainError,
  63. "unknown input type, expects HostStorage or DeviceStorage, got "
  64. "%s",
  65. input.to_string().c_str());
  66. }
  67. }
  68. mgb_assert(
  69. result.host || result.device, "require at least one of host/device value");
  70. result.kind = kind();
  71. return result;
  72. }
  73. std::string CreateTensor::to_string() const {
  74. return ssprintf(
  75. "CreateTensor{kind=%d, device=%s, dtype=%s, shape=%s, format=%s}",
  76. (int)m_kind, m_device.to_string().c_str(), m_dtype.name(),
  77. m_shape.to_string().c_str(), m_format.to_string().c_str());
  78. }
  79. std::string DTRCommand::to_string() const {
  80. return ssprintf("DTRCommandValue{kind=%d}", (int)m_kind);
  81. }
  82. std::string CreateNode::to_string() const {
  83. return "CreateNode";
  84. }
  85. std::string GetName::to_string() const {
  86. return "GetName{}";
  87. }
  88. std::string RenameValue::to_string() const {
  89. return ssprintf("RenameValue{name=%s}", imperative::quoted(m_name).c_str());
  90. }
  91. std::string IsScalar::to_string() const {
  92. return "IsScalar";
  93. }
  94. std::string GetFormat::to_string() const {
  95. return "GetFormat{}";
  96. }
  97. std::string SetFormat::to_string() const {
  98. return ssprintf("SetFormat{format=%s}", m_format.to_string().c_str());
  99. }
  100. std::string GetVarVal::to_string() const {
  101. return "GetVarVal";
  102. }
  103. } // namespace imperative
  104. } // namespace mgb