
tensor.cc 13 kB

/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "include/tensor.h"

#include "common/mslog.h"
#include "src/op_common.h"
#include "include/errorcode.h"
#include "securec/include/securec.h"
#include "common/common.h"
#include "src/runtime/allocator.h"

namespace mindspore {
namespace predict {
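// Builds a heap-allocated Tensor from a serialized TensorDef. For weight tensors
// (refCount == MSConst_WEIGHT_REFCOUNT) the data buffer is allocated and copied;
// otherwise only the metadata (dtype, dims, format) is populated. The TensorDef's
// reference count is recorded on the new tensor in either case.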
Tensor *Tensor::CopyFromTensorDef(const TensorDef &tensorDef) {
  std::vector<int64_t> dims;
  if (tensorDef.dims() == nullptr) {
    MS_LOGD("tensorDef->dims is nullptr");
  } else {
    MS_ASSERT(tensorDef.dims()->data() != nullptr);
    for (uint32_t j = 0; j < tensorDef.dims()->size(); j++) {
      dims.push_back(tensorDef.dims()->data()[j]);
    }
  }
  auto tensor =
      std::unique_ptr<Tensor>(new (std::nothrow) Tensor(tensorDef.dataType(), dims, tensorDef.format(), nullptr));
  if (tensor == nullptr) {
    MS_LOGE("new Tensor failed");
    return nullptr;
  }
  if (tensorDef.refCount() == MSConst_WEIGHT_REFCOUNT && tensorDef.data() != nullptr && tensorDef.data()->size() > 0) {
    if (dims.empty()) {
      tensor->SetDims({1});
    }
    auto ret = tensor->MallocData();
    if (ret != RET_OK) {
      MS_LOGE("malloc data fail,datasize %zu", tensor->GetDataSize());
      return nullptr;
    }
    auto tensorData = tensorDef.data()->data();
    ret = memcpy_sp(tensor->GetData(), tensor->GetDataSize(), tensorData, tensorDef.data()->size());
    if (ret != RET_OK) {
      MS_LOGE("copy data fail,dst size %zu, src size %u", tensor->GetDataSize(), tensorDef.data()->size());
      return nullptr;
    }
  }
  tensor->refCount = tensorDef.refCount();
  return tensor.release();
}
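
// Copy constructor: duplicates dtype, format, device context, and shape metadata.
// When copyData is true it also allocates a new buffer and copies the source
// tensor's contents.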
Tensor::Tensor(const Tensor &tensor, bool copyData) {
  format = tensor.format;
  dlTensor.data = nullptr;
  dlTensor.ctx.device_type = tensor.dlTensor.ctx.device_type;
  dlTensor.ctx.device_id = tensor.dlTensor.ctx.device_id;
  dlTensor.strides = nullptr;
  dlTensor.byte_offset = tensor.dlTensor.byte_offset;
  dlTensor.dtype.code = tensor.dlTensor.dtype.code;
  dlTensor.dtype.bits = tensor.dlTensor.dtype.bits;
  dlTensor.dtype.lanes = tensor.dlTensor.dtype.lanes;
  dlTensor.ndim = tensor.dlTensor.ndim;
  if (dlTensor.ndim > 0) {
    dlTensor.shape = new (std::nothrow) int64_t[dlTensor.ndim];
    if (dlTensor.shape != nullptr) {
      for (int i = 0; i < dlTensor.ndim; i++) {
        dlTensor.shape[i] = tensor.dlTensor.shape[i];
      }
    } else {
      MS_LOGW("new shape fail,ndim %d", dlTensor.ndim);
    }
  } else {
    dlTensor.shape = nullptr;
  }
  if (copyData) {
    allocator = tensor.allocator;
    refCount = tensor.refCount;
    auto ret = MallocData();
    if (ret != RET_OK) {
      return;
    }
    size_t datasize = GetDataSize();
    ret = memcpy_sp(dlTensor.data, datasize, tensor.dlTensor.data, datasize);
    if (ret != RET_OK) {
      return;
    }
  }
}
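
// Constructs a tensor with the given data type, shape, and layout. The data
// pointer may be null; storage can be attached later via SetData() or MallocData().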
Tensor::Tensor(DataType dt, const std::vector<int64_t> &dims, Format format, void *data) {
  this->format = format;
  dlTensor.data = data;
  dlTensor.ctx.device_type = DLDeviceType::kDLCPU;
  dlTensor.ctx.device_id = 0;
  dlTensor.strides = nullptr;
  dlTensor.byte_offset = 0;
  dlTensor.ndim = static_cast<int>(dims.size());
  if (dlTensor.ndim > 0) {
    dlTensor.shape = new (std::nothrow) int64_t[dlTensor.ndim];
    if (dlTensor.shape != nullptr) {
      for (int i = 0; i < dlTensor.ndim; i++) {
        dlTensor.shape[i] = dims[i];
      }
    } else {
      MS_LOGW("new shape fail,ndim %d", dlTensor.ndim);
    }
  } else {
    dlTensor.shape = nullptr;
  }
  SetDataType(dt);
}

Tensor::~Tensor() { FreeTensor(); }

DLDataType Tensor::GetTensorDtype() const { return dlTensor.dtype; }

void *Tensor::GetData() const { return dlTensor.data; }

void Tensor::SetData(void *data) { dlTensor.data = data; }
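
// Maps the DLPack dtype (code + bit width) back to the framework DataType enum.
// Unrecognized combinations fall through to DataType_DT_UNDEFINED.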
DataType Tensor::GetDataType() const {
  DataType dataType = DataType_DT_UNDEFINED;
  switch (dlTensor.dtype.code) {
    case kDLFloat:
      if (dlTensor.dtype.bits == 32) {
        dataType = DataType_DT_FLOAT;
      } else if (dlTensor.dtype.bits == 16) {
        dataType = DataType_DT_FLOAT16;
      }
      break;
    case kDLInt:
      if (dlTensor.dtype.bits == 32) {
        dataType = DataType_DT_INT32;
      } else if (dlTensor.dtype.bits == 8) {
        dataType = DataType_DT_INT8;
      }
      break;
    case kDLUInt:
      if (dlTensor.dtype.bits == 32) {
        dataType = DataType_DT_UINT32;
      } else if (dlTensor.dtype.bits == 8) {
        dataType = DataType_DT_UINT8;
      }
      break;
    default:
      break;
  }
  return dataType;
}
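
// Translates a DataType enum into the equivalent DLPack dtype fields.
// Unsupported types are logged and fall back to 32-bit float.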
void Tensor::SetDataType(DataType dt) {
  switch (dt) {
    case DataType_DT_FLOAT:
      dlTensor.dtype.code = kDLFloat;
      dlTensor.dtype.bits = 32;
      dlTensor.dtype.lanes = 1;
      break;
    case DataType_DT_FLOAT16:
      dlTensor.dtype.code = kDLFloat;
      dlTensor.dtype.bits = 16;
      dlTensor.dtype.lanes = 1;
      break;
    case DataType_DT_INT8:
      dlTensor.dtype.code = kDLInt;
      dlTensor.dtype.bits = 8;
      dlTensor.dtype.lanes = 1;
      break;
    case DataType_DT_UINT8:
      dlTensor.dtype.code = kDLUInt;
      dlTensor.dtype.bits = 8;
      dlTensor.dtype.lanes = 1;
      break;
    case DataType_DT_INT32:
      dlTensor.dtype.code = kDLInt;
      dlTensor.dtype.bits = 32;
      dlTensor.dtype.lanes = 1;
      break;
    case DataType_DT_UINT32:
      dlTensor.dtype.code = kDLUInt;
      dlTensor.dtype.bits = 32;
      dlTensor.dtype.lanes = 1;
      break;
    default:
      MS_LOGW(" DataType %d is not implemented.", dt);
      MS_LOGW(" DataType DT_FLOAT is used.");
      dlTensor.dtype.code = kDLFloat;
      dlTensor.dtype.bits = 32;
      dlTensor.dtype.lanes = 1;
      return;
  }
}

int Tensor::GetNDim() const { return dlTensor.ndim; }

std::vector<int64_t> Tensor::GetDims() const {
  std::vector<int64_t> dims;
  for (int i = 0; i < dlTensor.ndim; i++) {
    dims.push_back(dlTensor.shape[i]);
  }
  return dims;
}
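
// Number of elements in the tensor. For the NC4HW4 layout the channel dimension
// (index 1) is padded up to the next multiple of 4 before multiplying.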
size_t Tensor::GetElementSize() const {
  const int tile = 4;
  if (format == Format_NC4HW4) {
    size_t size = 1;
    for (int i = 0; i < dlTensor.ndim; i++) {
      auto var = static_cast<size_t>(dlTensor.shape[i]);
      if (i == 1) {
        var = UP_DIV(var, tile) * tile;
      }
      size *= var;
    }
    return size;
  } else {
    size_t size = 1;
    for (int i = 0; i < dlTensor.ndim; i++) {
      size *= static_cast<size_t>(dlTensor.shape[i]);
    }
    return size;
  }
}
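
// Buffer size in bytes: element count times the per-element byte width,
// rounding the bit width up to a whole byte.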
size_t Tensor::GetDataSize() const {
  size_t size = GetElementSize();
  const int BYTES = 8;
  const int GAP = 7;
  size *= (dlTensor.dtype.bits * dlTensor.dtype.lanes + GAP) / BYTES;
  return size;
}
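
// Allocates the data buffer through the given allocator (or plain malloc when none
// is supplied) and initializes the reference count. If a buffer already exists,
// only the reference count is increased.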
int Tensor::MallocData(std::shared_ptr<Allocator> allocator, int refCount) {
  if (dlTensor.data != nullptr) {
    this->refCount += refCount;
    return RET_OK;
  }
  this->refCount = refCount;
  size_t size = GetDataSize();
  if (allocator) {
    this->allocator = allocator;
    dlTensor.data = allocator->Malloc(size);
  } else {
    if (size > MAX_MALLOC_SIZE) {
      return RET_ERROR;
    }
    dlTensor.data = malloc(size);
  }
  if (dlTensor.data == nullptr) {
    return RET_ERROR;
  }
  return RET_OK;
}
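
// ForceFreeData releases the data buffer immediately; FreeData only releases it
// once the reference count drops to zero.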
void Tensor::ForceFreeData() {
  if (allocator) {
    allocator->Free(dlTensor.data);
  } else {
    free(dlTensor.data);
  }
  dlTensor.data = nullptr;
}

void Tensor::FreeData() {
  --refCount;
  if (refCount <= 0) {
    ForceFreeData();
  }
}
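
// Shape comparison helpers: each overload returns true only when the rank and every
// dimension match the other tensor (or the given dimension vector).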
bool Tensor::CompareShape(const Tensor &dst) {
  if (dlTensor.ndim != dst.dlTensor.ndim || dlTensor.shape == nullptr || dst.dlTensor.shape == nullptr) {
    MS_LOGE("param error, one.ndim: %d, other.ndim: %d, one shape %p,other shape %p", dlTensor.ndim, dst.dlTensor.ndim,
            dlTensor.shape, dst.dlTensor.shape);
    return false;
  }
  for (int i = 0; i < dlTensor.ndim; i++) {
    if (dlTensor.shape[i] != dst.dlTensor.shape[i]) {
      MS_LOGE("one.shape[%d]: %ld, other.shape[%d]: %ld", i, dlTensor.shape[i], i, dst.dlTensor.shape[i]);
      return false;
    }
  }
  return true;
}

bool Tensor::CompareShape(const std::vector<int64_t> &other) {
  // Cast avoids a signed/unsigned comparison between ndim and vector::size().
  if (static_cast<size_t>(dlTensor.ndim) != other.size() || dlTensor.shape == nullptr) {
    return false;
  }
  for (int i = 0; i < dlTensor.ndim; i++) {
    if (dlTensor.shape[i] != other[i]) {
      return false;
    }
  }
  return true;
}
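
// Dimension accessors for 4-D tensors. Each returns the requested extent according
// to the current layout (NCHW/NC4HW4 vs. NHWC), or -1 when the shape is missing,
// the tensor is not 4-D, or the format is unsupported.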
int64_t Tensor::Height() const {
  if (dlTensor.shape == nullptr) {
    MS_LOGE("shape is null");
    return -1;
  }
  if (dlTensor.ndim != DIM_DEFAULT_SIZE) {
    MS_LOGE("Tensor should be 4 dimensional.");
    return -1;
  }
  switch (this->format) {
    case Format_NCHW:
    case Format_NC4HW4:
      return dlTensor.shape[NCHW_H];
    case Format_NHWC:
      return dlTensor.shape[NHWC_H];
    default:
      MS_LOGE("Unsupported format: %d", this->format);
      return -1;
  }
}

int64_t Tensor::Width() const {
  if (dlTensor.shape == nullptr) {
    MS_LOGE("shape is null");
    return -1;
  }
  if (dlTensor.ndim != DIM_DEFAULT_SIZE) {
    MS_LOGE("Tensor should be 4 dimensional.");
    return -1;
  }
  switch (this->format) {
    case Format_NCHW:
    case Format_NC4HW4:
      return dlTensor.shape[NCHW_W];
    case Format_NHWC:
      return dlTensor.shape[NHWC_W];
    default:
      MS_LOGE("Unsupported format: %d", this->format);
      return -1;
  }
}

int64_t Tensor::Channel() const {
  if (dlTensor.shape == nullptr) {
    MS_LOGE("shape is null");
    return -1;
  }
  if (dlTensor.ndim != DIM_DEFAULT_SIZE) {
    MS_LOGE("Tensor should be 4 dimensional.");
    return -1;
  }
  switch (this->format) {
    case Format_NCHW:
    case Format_NC4HW4:
      return dlTensor.shape[NCHW_C];
    case Format_NHWC:
      return dlTensor.shape[NHWC_C];
    default:
      MS_LOGE("Unsupported format: %d", this->format);
      return -1;
  }
}

int64_t Tensor::Batch() const {
  if (dlTensor.shape == nullptr) {
    MS_LOGE("shape is null");
    return -1;
  }
  if (dlTensor.ndim != DIM_DEFAULT_SIZE) {
    MS_LOGE("Tensor should be 4 dimensional.");
    return -1;
  }
  switch (this->format) {
    case Format_NCHW:
    case Format_NC4HW4:
    case Format_NHWC:
      return dlTensor.shape[NCHW_N];
    default:
      MS_LOGE("Unsupported format: %d", this->format);
      return -1;
  }
}
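
// Returns the stride (in elements) for the given axis: the explicit strides array if
// one has been set, otherwise the contiguous row-major stride computed from the shape.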
int64_t Tensor::Stride(int index) const {
  if (dlTensor.strides) {
    return dlTensor.strides[index];
  }
  if (dlTensor.shape == nullptr) {
    MS_LOGE("shape is null");
    return -1;
  }
  int64_t stride = 1;
  for (int i = index + 1; i < dlTensor.ndim; i++) {
    stride *= dlTensor.shape[i];
  }
  return stride;
}
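
// Fills the strides array (ndim - 1 entries, one per non-innermost axis) with
// contiguous row-major strides derived from the current shape.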
void Tensor::SetStride() {
  if (dlTensor.strides == nullptr) {
    if (dlTensor.ndim < 1) {
      MS_LOGE("dims of dlTensor is empty.");
      return;
    }
    dlTensor.strides = new (std::nothrow) int64_t[dlTensor.ndim - 1];
    if (dlTensor.strides == nullptr) {
      MS_LOGW("new stride fail, ndim %d.", dlTensor.ndim);
      return;
    }
  }
  for (int idx = 0; idx < dlTensor.ndim - 1; idx++) {
    int64_t stride = 1;
    if (dlTensor.ndim <= idx + 1) {
      MS_LOGE("out of for loop upper limit.");
      return;
    }
    for (int i = idx + 1; i < dlTensor.ndim; i++) {
      stride *= dlTensor.shape[i];
    }
    dlTensor.strides[idx] = stride;
  }
}

void Tensor::SetScale(bool isScale) { this->isScale = isScale; }

void Tensor::SetStride(int index, int64_t stride) {
  // The strides array holds ndim - 1 entries, so index must stay below ndim - 1.
  if (index >= dlTensor.ndim - 1) {
    return;
  }
  if (dlTensor.strides == nullptr) {
    SetStride();
    if (dlTensor.strides == nullptr) {
      return;
    }
  }
  dlTensor.strides[index] = stride;
}
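
// Replaces the shape with the given dimensions, releasing any previously owned
// shape array.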
void Tensor::SetDims(const std::vector<int64_t> &dims) {
  if (dlTensor.shape != nullptr) {
    delete[] dlTensor.shape;
  }
  dlTensor.ndim = static_cast<int>(dims.size());
  if (dlTensor.ndim > 0) {
    dlTensor.shape = new (std::nothrow) int64_t[dlTensor.ndim];
    if (dlTensor.shape != nullptr) {
      for (int i = 0; i < dlTensor.ndim; i++) {
        dlTensor.shape[i] = dims[i];
      }
    } else {
      MS_LOGW("new shape fail,ndim %d", dlTensor.ndim);
    }
  } else {
    dlTensor.shape = nullptr;
  }
}
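
// Releases everything the tensor owns: shape, strides, and the data buffer
// (returned to the allocator when one is attached, freed otherwise).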
void Tensor::FreeTensor() {
  if (dlTensor.shape != nullptr) {
    delete[] dlTensor.shape;
    dlTensor.shape = nullptr;
  }
  if (dlTensor.strides != nullptr) {
    delete[] dlTensor.strides;
    dlTensor.strides = nullptr;
  }
  dlTensor.ndim = 0;
  if (allocator != nullptr) {
    allocator->Free(dlTensor.data);
  } else {
    free(dlTensor.data);
  }
  dlTensor.data = nullptr;
}
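
// NC4HW4 size helpers: the channel dimension (axis 3 for NHWC input, axis 1 otherwise)
// is aligned up to a multiple of 4 before computing the element count or byte size.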
size_t Tensor::GetNC4HW4ElementSize(bool isNhwc) {
  int alignIndex = 1;
  if (isNhwc) {
    alignIndex = 3;
  }
  size_t size = 1;
  for (int i = 0; i < dlTensor.ndim; i++) {
    auto var = static_cast<size_t>(dlTensor.shape[i]);
    if (i == alignIndex) {
      var = ALIGN_UP4(var);
    }
    size *= var;
  }
  return size;
}

size_t Tensor::GetNC4HW4DataSize(bool isNhwc) {
  size_t size = GetNC4HW4ElementSize(isNhwc);
  const int BYTES = 8;
  const int GAP = 7;
  size *= (dlTensor.dtype.bits * dlTensor.dtype.lanes + GAP) / BYTES;
  return size;
}
}  // namespace predict
}  // namespace mindspore