You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tensor_impl_base.h 3.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  /**
   * \file src/tensor_impl_base.h
   * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
   *
   * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
   *
   * Unless required by applicable law or agreed to in writing,
   * software distributed under the License is distributed on an
   * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   */
  #pragma once
  #include "lite/tensor.h"
  #include "misc.h"
  #include "type_info.h"

  #include <memory>
  #include <unordered_map>
  #include <utility>
  #include <vector>
  16. namespace lite {
  17. /*!
  18. * \brief implement the Tensor
  19. */
  20. class Tensor::TensorImplBase : public DynTypeObj {
  21. public:
  22. virtual ~TensorImplBase() = default;
  23. virtual LiteDeviceType get_device_type() const = 0;
  24. virtual int get_device_id() const = 0;
  25. virtual LiteBackend get_backend_type() const = 0;
  26. virtual Layout get_layout() const = 0;
  27. virtual bool is_pinned_host() const = 0;
  28. virtual void* get_memory_ptr() const = 0;
  29. virtual void* get_memory_ptr(const std::vector<size_t>& idx) const = 0;
  30. virtual void set_layout(const Layout& layout) = 0;
  31. //! use the user allocated data to reset the memory of the tensor, the
  32. //! memory will not be managed by the lite, later, the user should delete
  33. //! it.
  34. virtual void reset(void* prepared_data) = 0;
  35. //! use the user allocated data and corresponding layout to reset the data
  36. //! and layout of the tensor, the memory will not be managed by lite, later,
  37. //! the user should delete it.
  38. virtual void reset(void* prepared_data, const Layout& layout) = 0;
  39. //! reshape the tensor with new shape, keep the data_type the same
  40. virtual void reshape(const Layout& layout) = 0;
  41. //! get a new tensor slice from the origin tensor
  42. virtual std::shared_ptr<Tensor> slice(
  43. const std::vector<size_t>& start, const std::vector<size_t>& end,
  44. const std::vector<size_t>& step = {}) = 0;
  45. //! set the tensor memory with zero
  46. virtual void fill_zero() = 0;
  47. //! copy tensor form other tensor
  48. //! Note: the best way for tensor copy is just set the dst device, left
  49. //! layout empty, when copying the dst layout will be set the same with
  50. //! src
  51. virtual void copy_from(const TensorImplBase* src_impl) = 0;
  52. //! share memory with other tensor
  53. virtual void share_memory_with(const TensorImplBase* src_impl) = 0;
  54. //! whether the memory of tensor is continue
  55. virtual bool is_continue_memory() const = 0;
  56. };
  57. /*!
  58. * \brief friend class of Tensor, for convenient accessing the Network members
  59. */
  60. class TensorHelper {
  61. public:
  62. static inline std::shared_ptr<Tensor::TensorImplBase> implement(
  63. const std::shared_ptr<Tensor> tensor) {
  64. LITE_ASSERT(tensor);
  65. return tensor->m_tensor_impl;
  66. }
  67. static inline std::shared_ptr<Tensor::TensorImplBase> implement(
  68. const Tensor* tensor) {
  69. LITE_ASSERT(tensor);
  70. return tensor->m_tensor_impl;
  71. }
  72. static inline void implement(
  73. const std::shared_ptr<Tensor> tensor,
  74. std::shared_ptr<Tensor::TensorImplBase> impl) {
  75. LITE_ASSERT(tensor);
  76. tensor->m_tensor_impl = impl;
  77. }
  78. };
  79. } // namespace lite
  // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}