You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

basic_values.cpp 2.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. #include "megbrain/imperative/basic_values.h"
  2. namespace mgb {
  3. namespace imperative {
  4. std::string ShapeValue::to_string() const {
  5. return ssprintf("ValueShape%s", ValueShape::to_string().c_str());
  6. }
  7. std::string CompNodeValue::to_string() const {
  8. return CompNode::to_string();
  9. }
  10. std::string BoolValue::to_string() const {
  11. return (*this) ? "true" : "false";
  12. }
  13. std::string HostStorage::to_string() const {
  14. return ssprintf("HostStorage{device=%s}", comp_node().to_string().c_str());
  15. }
  16. std::string DeviceStorage::to_string() const {
  17. return ssprintf("DeviceStorage{device=%s}", comp_node().to_string().c_str());
  18. }
  19. std::string HostValue::to_string() const {
  20. return ssprintf(
  21. "HostValue{device=%s, dtype=%s, shape=%s}", device().to_string().c_str(),
  22. dtype().name(), shape().to_string().c_str());
  23. }
  24. HostTensorND HostTensor::as_nd(bool allow_scalar) const {
  25. HostTensorND nd;
  26. TensorShape tensor_shape;
  27. if (m_shape.is_scalar()) {
  28. mgb_assert(allow_scalar);
  29. tensor_shape = TensorShape{1};
  30. } else {
  31. tensor_shape = m_shape.as_tensor_shape();
  32. }
  33. nd.reset(m_storage, {tensor_shape, dtype()});
  34. return nd;
  35. }
  36. std::string DeviceValue::to_string() const {
  37. return ssprintf(
  38. "DeviceValue{device=%s, dtype=%s, shape=%s}", device().to_string().c_str(),
  39. dtype().name(), shape().to_string().c_str());
  40. }
  41. DeviceTensorND DeviceTensor::as_nd(bool allow_scalar) const {
  42. DeviceTensorND nd;
  43. TensorShape tensor_shape;
  44. if (m_shape.is_scalar()) {
  45. mgb_assert(allow_scalar);
  46. tensor_shape = TensorShape{1};
  47. } else {
  48. tensor_shape = m_shape.as_tensor_shape();
  49. }
  50. nd.reset(m_storage, {tensor_shape, dtype()});
  51. return nd;
  52. }
  53. std::string FunctionValue::to_string() const {
  54. return ssprintf("FunctionValue{type=%s}", target_type().name());
  55. }
  56. std::string DTypeValue::to_string() const {
  57. return DType::name();
  58. }
  59. std::string StringValue::to_string() const {
  60. return imperative::quoted((std::string&)*this);
  61. }
  62. std::string ErrorValue::to_string() const {
  63. return ssprintf("ErrorValue{message=%s}", message().c_str());
  64. }
  65. } // namespace imperative
  66. } // namespace mgb