You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

kernel2ms.cc 21 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "predict/converter/kernel2ms.h"
  17. #include <algorithm>
  18. #include "ir/anf.h"
  19. #include "predict/converter/lite_model/op_attr_packer.h"
  20. #include "mindspore/ccsrc/operator/ops.h"
  21. namespace mindspore {
  22. namespace executor {
Kernel2Ms &Kernel2Ms::GetInstance() {
  // Meyers singleton: the converter accumulates per-conversion state in its
  // members (node index maps, tensor cache), so one shared instance is used.
  static Kernel2Ms instance;
  return instance;
}
bool Kernel2Ms::SetMemResue() const {
  // Memory-reuse hook: currently only logs and reports success.
  // (The "MemResue" spelling matches the declaration in the header.)
  MS_LOG(INFO) << "MemResue start";
  return true;
}
  31. bool Kernel2Ms::SetAllTensors(const TensorCachePtr &tensor_cache, SubGraphDefT *ms_graph) {
  32. if (tensor_cache == nullptr || ms_graph == nullptr) {
  33. return false;
  34. }
  35. const std::unordered_map<int, std::vector<ExTensorPtr>> &cachedTensors = tensor_cache->GetCachedTensor();
  36. size_t total_size = 0;
  37. if (cachedTensors.empty()) {
  38. return false;
  39. }
  40. for (auto &iter : cachedTensors) {
  41. auto ex_tensors = iter.second;
  42. total_size += ex_tensors.size();
  43. }
  44. ms_graph->allTensors.resize(total_size);
  45. for (auto &iter : cachedTensors) {
  46. for (auto &ex_tensor : iter.second) {
  47. std::unique_ptr<TensorDefT> ms_tensor(new TensorDefT());
  48. auto device_tensor_tmp = ex_tensor->device_tensor_ptr_;
  49. auto device_d_type = device_tensor_tmp->data_type();
  50. ms_tensor->dataType = predict::utils::GetMSDataType(device_d_type);
  51. auto device_shape = device_tensor_tmp->shape();
  52. ms_tensor->dims.clear();
  53. if (device_shape.empty()) {
  54. ms_tensor->dims.push_back(1);
  55. } else {
  56. ms_tensor->dims.assign(device_shape.begin(), device_shape.end());
  57. }
  58. std::string format_str = device_tensor_tmp->device_info().format_;
  59. ms_tensor->format = predict::utils::GetMsFormat(format_str);
  60. ms_tensor->offset = 0;
  61. auto stable = ex_tensor->stable_;
  62. if (stable == INPUTDATA || stable == CONSTANT || stable == WEIGHTS) {
  63. ms_tensor->refCount = MS_MAX_REFCOUNT;
  64. } else {
  65. ms_tensor->refCount = 0;
  66. }
  67. ms_graph->allTensors[IntToSize(ex_tensor->index_)] = std::move(ms_tensor);
  68. }
  69. }
  70. return true;
  71. }
  72. bool Kernel2Ms::SetGraphOutputIdx(const KernelGraphPtr &kernel_graph_ptr, const TensorCachePtr &tensor_cache,
  73. SubGraphDefT *ms_graph, AllOutputTensors *all_output_tensors) {
  74. MS_EXCEPTION_IF_NULL(tensor_cache);
  75. MS_EXCEPTION_IF_NULL(ms_graph);
  76. MS_EXCEPTION_IF_NULL(all_output_tensors);
  77. auto out_nodes = kernel_graph_ptr->outputs();
  78. if (out_nodes.empty()) {
  79. return false;
  80. }
  81. // maybe need to judge out_nodes is real && output must be CNode
  82. for (size_t i = 0; i < out_nodes.size(); ++i) {
  83. std::vector<AnfNodePtr> real_inputs_link;
  84. std::vector<size_t> real_output_idx_link;
  85. GetRealInpoutsPtr(out_nodes[i], &real_inputs_link, &real_output_idx_link);
  86. if (real_inputs_link.empty()) {
  87. MS_LOG(INFO) << "this graph output node is vitural node, has no real input";
  88. continue;
  89. }
  90. for (size_t k = 0; k < real_inputs_link.size(); ++k) {
  91. int key = node_indexs_[out_nodes[i].get()];
  92. auto ex_tensor_list = tensor_cache->findTensor(key);
  93. if (ex_tensor_list.empty()) {
  94. MS_LOG(INFO) << "SetGraphOutputIdx do not add Extensor ";
  95. continue;
  96. }
  97. auto ex_tensor = ex_tensor_list[real_output_idx_link[k]];
  98. ex_tensor_list.clear();
  99. ms_graph->outputIndex.push_back(ex_tensor->index_);
  100. }
  101. }
  102. return true;
  103. }
  104. bool Kernel2Ms::SetOpOutputIdx(const CNodePtr &c_node_ptr, const TensorPtr &output_tensor,
  105. const TensorCachePtr &tensor_cache, int ref_count, size_t order_index,
  106. NodeDef *ms_node) {
  107. MS_EXCEPTION_IF_NULL(c_node_ptr);
  108. MS_EXCEPTION_IF_NULL(output_tensor);
  109. MS_EXCEPTION_IF_NULL(ms_node);
  110. MS_EXCEPTION_IF_NULL(tensor_cache);
  111. if (!predict::utils::FindNodeInMap(node_indexs_, c_node_ptr)) {
  112. MS_LOG(ERROR) << "can not find any pk_key in inited node_indexs map";
  113. return false;
  114. }
  115. int tensor_key = node_indexs_[c_node_ptr.get()];
  116. auto host_shape = AnfAlgo::GetOutputInferShape(c_node_ptr, order_index);
  117. std::vector<int> tensor_shape;
  118. (void)std::transform(host_shape.begin(), host_shape.end(), std::back_inserter(tensor_shape), SizeToInt);
  119. int outputIndex = tensor_cache->addExTensor(tensor_key, output_tensor, ref_count, tensor_shape, KERNEL);
  120. ms_node->opDef->outputIndex.push_back(outputIndex);
  121. return true;
  122. }
  123. void Kernel2Ms::GetRealInpoutsPtr(const AnfNodePtr &node, std::vector<AnfNodePtr> *real_inputs,
  124. std::vector<size_t> *real_output_idx) {
  125. MS_EXCEPTION_IF_NULL(real_inputs);
  126. MS_EXCEPTION_IF_NULL(real_output_idx);
  127. size_t default_idx = 0;
  128. if (node->isa<CNode>()) {
  129. auto c_node = node->cast<CNodePtr>();
  130. MS_EXCEPTION_IF_NULL(c_node);
  131. std::string c_node_name = GetCNodeFuncName(c_node);
  132. if (c_node_name == prim::kPrimTupleGetItem->name()) {
  133. auto v_node = c_node->inputs()[kTupleGetItemIndex]->cast<ValueNodePtr>();
  134. MS_EXCEPTION_IF_NULL(v_node);
  135. default_idx = IntToSize(GetValue<int>(v_node->value()));
  136. real_inputs->push_back(c_node->inputs()[1]);
  137. real_output_idx->push_back(default_idx);
  138. return;
  139. } else if (c_node_name == prim::kPrimDepend->name()) {
  140. GetRealInpoutsPtr(c_node->inputs()[1], real_inputs, real_output_idx);
  141. return;
  142. } else if (c_node_name == prim::kPrimMakeTuple->name()) {
  143. for (auto &in : c_node->inputs()) {
  144. GetRealInpoutsPtr(in, real_inputs, real_output_idx);
  145. }
  146. return;
  147. } else {
  148. real_inputs->push_back(node);
  149. real_output_idx->push_back(default_idx);
  150. }
  151. } else if (node->isa<Parameter>()) {
  152. real_inputs->push_back(node);
  153. real_output_idx->push_back(default_idx);
  154. } else if (node->isa<ValueNode>()) {
  155. real_inputs->push_back(node);
  156. real_output_idx->push_back(default_idx);
  157. }
  158. }
  159. bool Kernel2Ms::SetOpInputIdx(const CNodePtr &c_node_ptr, const TensorCachePtr &tensor_cache, NodeDef *ms_node) {
  160. MS_EXCEPTION_IF_NULL(c_node_ptr);
  161. MS_EXCEPTION_IF_NULL(tensor_cache);
  162. MS_EXCEPTION_IF_NULL(ms_node);
  163. for (size_t i = 1; i < c_node_ptr->inputs().size(); ++i) {
  164. std::vector<AnfNodePtr> real_inputs;
  165. std::vector<size_t> real_output_idx;
  166. GetRealInpoutsPtr(c_node_ptr->inputs()[i], &real_inputs, &real_output_idx);
  167. if (real_inputs.empty()) {
  168. MS_LOG(INFO) << "kernel has no inputs: " << c_node_ptr.get() << " input size[%lu]" << c_node_ptr->inputs().size();
  169. continue;
  170. }
  171. for (size_t j = 0; j < real_inputs.size(); ++j) {
  172. int key = node_indexs_[real_inputs[j].get()];
  173. std::vector<ExTensorPtr> ex_tensor_list = tensor_cache->findTensor(key);
  174. if (ex_tensor_list.empty()) {
  175. continue;
  176. }
  177. ExTensorPtr ex_tensor_ptr = ex_tensor_list[real_output_idx[j]];
  178. ex_tensor_list.clear();
  179. ms_node->opDef->inputIndex.push_back(ex_tensor_ptr->index_);
  180. }
  181. }
  182. return true;
  183. }
  184. void Kernel2Ms::TransformGraphIndx() {
  185. // transform index && anfnodeptr
  186. if (node_indexs_.empty()) {
  187. MS_LOG(EXCEPTION) << "node_indexs_ not ininted";
  188. }
  189. for (auto &item : node_indexs_) {
  190. index_nodes_[item.second] = item.first;
  191. }
  192. }
  193. bool Kernel2Ms::InitGraphInputsIndx(const KernelGraphPtr &kernel_graph_ptr) {
  194. MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
  195. auto input_nodes = kernel_graph_ptr->inputs();
  196. if (input_nodes.empty()) {
  197. return false;
  198. }
  199. for (const auto &input_node : input_nodes) {
  200. if (input_node->isa<Parameter>()) {
  201. if (!predict::utils::FindNodeInMap(node_indexs_, input_node)) {
  202. // init every parameter node
  203. node_indexs_[input_node.get()] = graph_index_;
  204. graph_index_++;
  205. }
  206. } else {
  207. MS_LOG(INFO) << "This node is anfnode, no need to handle, continue. node info: " << input_node->ToString();
  208. continue;
  209. }
  210. }
  211. MS_LOG(DEBUG) << "inputs GraphIndex: " << graph_index_;
  212. return true;
  213. }
  214. bool Kernel2Ms::InitGraphValueNodesIndx(const KernelGraphPtr &kernel_graph_ptr) {
  215. MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
  216. if (kernel_graph_ptr->value_nodes().empty()) {
  217. return false;
  218. }
  219. for (auto &item : kernel_graph_ptr->value_nodes()) {
  220. if (item.first->isa<ValueNode>()) {
  221. auto value_node = item.first->cast<ValueNodePtr>();
  222. MS_EXCEPTION_IF_NULL(value_node);
  223. if (value_node == nullptr) {
  224. MS_LOG(WARNING) << "value_node is nullptr";
  225. return false;
  226. }
  227. if (value_node->value() == nullptr) {
  228. MS_LOG(ERROR) << "Constant value is null.";
  229. return false;
  230. }
  231. if (!value_node->value()->isa<tensor::Tensor>()) {
  232. continue;
  233. }
  234. if (!predict::utils::FindNodeInMap(node_indexs_, item.first)) {
  235. // init node
  236. auto node_ptr = item.first;
  237. node_indexs_[node_ptr.get()] = graph_index_;
  238. graph_index_++;
  239. }
  240. }
  241. }
  242. return true;
  243. }
  244. bool Kernel2Ms::InitGraphOpsIndx(const KernelGraphPtr &kernel_graph_ptr) {
  245. MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
  246. auto kernels = kernel_graph_ptr->execution_order();
  247. if (kernels.empty()) {
  248. MS_LOG(WARNING) << "this graph has no kernel";
  249. return false;
  250. }
  251. for (size_t i = 0; i < kernels.size(); ++i) {
  252. // for each kernel's inputs foreach real_input
  253. if (kernels[i]->isa<CNode>()) {
  254. if (!predict::utils::FindNodeInMap(node_indexs_, kernels[i])) {
  255. // init node
  256. node_indexs_[kernels[i].get()] = graph_index_;
  257. graph_index_++;
  258. }
  259. }
  260. }
  261. return true;
  262. }
  263. bool Kernel2Ms::InitGraphOutputsIndx(const KernelGraphPtr &kernel_graph_ptr) {
  264. MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
  265. // graph output && their inputs should link together
  266. auto out_nodes = kernel_graph_ptr->outputs();
  267. if (out_nodes.empty()) {
  268. MS_LOG(ERROR) << "this graph has no outputs";
  269. return false;
  270. }
  271. for (auto &item : out_nodes) {
  272. if (!predict::utils::FindNodeInMap(node_indexs_, item)) {
  273. node_indexs_[item.get()] = graph_index_;
  274. graph_index_++;
  275. }
  276. }
  277. return true;
  278. }
  279. bool Kernel2Ms::InitGraphIndx(const KernelGraphPtr &kernel_graph_ptr) {
  280. MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
  281. // only parameter
  282. if (!InitGraphInputsIndx(kernel_graph_ptr)) {
  283. return false;
  284. }
  285. // init value node
  286. if (!InitGraphValueNodesIndx(kernel_graph_ptr)) {
  287. return false;
  288. }
  289. // init op
  290. if (!InitGraphOpsIndx(kernel_graph_ptr)) {
  291. return false;
  292. }
  293. // init Graphoutput attention: out_put nodes have inputs
  294. return InitGraphOutputsIndx(kernel_graph_ptr);
  295. }
  296. bool Kernel2Ms::SetGraphInputTensors(const KernelGraphPtr &kernel_graph_ptr, const TensorCachePtr &tensor_cache,
  297. SubGraphDefT *ms_graph) {
  298. MS_EXCEPTION_IF_NULL(tensor_cache);
  299. MS_EXCEPTION_IF_NULL(ms_graph);
  300. MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
  301. if (convert_mode_ == kConvertUnused) {
  302. return false;
  303. }
  304. if (kernel_graph_ptr->inputs().empty()) {
  305. return false;
  306. }
  307. for (const auto &input_node : kernel_graph_ptr->inputs()) {
  308. if (input_node->isa<Parameter>()) {
  309. ParameterPtr pk_node = std::dynamic_pointer_cast<Parameter>(input_node);
  310. TensorPtr device_tensor;
  311. if (convert_mode_ == kConvertCpuMode) {
  312. device_tensor = predict::utils::GetParaCpuTensor(input_node);
  313. } else {
  314. device_tensor = predict::utils::GetParaAscendTensor(input_node);
  315. }
  316. if (device_tensor == nullptr) {
  317. return false;
  318. }
  319. ExTensorType node_type;
  320. if (AnfAlgo::IsParameterWeight(pk_node)) {
  321. node_type = WEIGHTS;
  322. } else {
  323. node_type = INPUTDATA;
  324. }
  325. if (!predict::utils::FindNodeInMap(node_indexs_, input_node)) {
  326. MS_LOG(WARNING) << "can not find any pk_key in inited node_indexs map";
  327. return false;
  328. }
  329. auto pk_key = node_indexs_[input_node.get()];
  330. all_output_tensors_[pk_key].push_back(device_tensor);
  331. int nodeRefCount = SizeToInt(AnfAlgo::GetOutputTensorNum(input_node));
  332. int nodeInputIdx =
  333. tensor_cache->addExTensor(pk_key, device_tensor, nodeRefCount, device_tensor->shape(), node_type);
  334. if (!AnfAlgo::IsParameterWeight(pk_node)) {
  335. ms_graph->inputIndex.push_back(nodeInputIdx);
  336. all_input_idxs_.push_back(nodeInputIdx);
  337. } else {
  338. input_weight_idxs_.push_back(nodeInputIdx);
  339. all_input_idxs_.push_back(nodeInputIdx);
  340. }
  341. }
  342. }
  343. return true;
  344. }
  345. bool Kernel2Ms::SetGraphValueTensors(const KernelGraphPtr &kernel_graph_ptr, const TensorCachePtr &tensor_cache) {
  346. MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
  347. MS_EXCEPTION_IF_NULL(tensor_cache);
  348. for (auto &item : kernel_graph_ptr->value_nodes()) {
  349. if (item.first->isa<ValueNode>()) {
  350. auto const_node = item.first->cast<ValueNodePtr>();
  351. auto tensor_constant = predict::utils::GetValueTensor(const_node);
  352. if (tensor_constant == nullptr) {
  353. continue;
  354. }
  355. if (!predict::utils::FindNodeInMap(node_indexs_, item.first)) {
  356. MS_LOG(WARNING) << "can not find any pk_key in inited node_indexs map";
  357. return false;
  358. }
  359. int constant_key = node_indexs_[(item.first).get()];
  360. all_output_tensors_[constant_key].push_back(tensor_constant);
  361. auto shape = tensor_constant->shape();
  362. (void)tensor_cache->addExTensor(constant_key, tensor_constant, 0, shape, CONSTANT);
  363. }
  364. }
  365. return true;
  366. }
bool Kernel2Ms::SetGraphOpTensors(const KernelGraphPtr &kernel_graph_ptr, const TensorCachePtr &tensor_cache,
                                  SubGraphDefT *ms_graph) {
  // Builds one NodeDef per kernel (in execution order) and caches every kernel
  // output tensor. The finished NodeDefs are parked as raw pointers in
  // tmp_op_nodes_; KernelGraph2MsGraph later re-wraps them in unique_ptrs.
  MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
  MS_EXCEPTION_IF_NULL(tensor_cache);
  MS_EXCEPTION_IF_NULL(ms_graph);
  auto kernels = kernel_graph_ptr->execution_order();
  if (kernels.empty()) {
    MS_LOG(ERROR) << "this graph has no kernels";
    return false;
  }
  for (auto &kernel : kernels) {
    if (!predict::utils::FindNodeInMap(node_indexs_, kernel)) {
      MS_LOG(ERROR) << "can not find any pk_key in inited node_indexs map";
      return false;
    }
    auto kernel_key = node_indexs_[kernel.get()];
    std::unique_ptr<NodeDef> ms_node(new NodeDef);
    // NOTE(review): fmkType is hard-coded to CAFFE for every node — confirm
    // this is intentional for the lite model format.
    ms_node->fmkType = mindspore::predict::FmkType_CAFFE;
    std::unique_ptr<OpDefT> ms_op(new OpDefT());
    auto c_name = AnfAlgo::GetCNodeName(kernel);
    // Look up the attribute-packing function registered for this op type.
    auto fun = predict::convert::OpAttrFactory::GetInstance()->GetPackFun(c_name);
    if (fun == nullptr) {
      MS_LOG(ERROR) << "get node [" << kernel->fullname_with_scope() << "] attr failed.";
      return false;
    } else if (!fun(kernel, ms_op.get())) {
      MS_LOG(ERROR) << "set node [" << kernel->fullname_with_scope() << "] attr failed.";
      return false;
    }
    ms_node->opDef = std::move(ms_op);
    auto output_size = AnfAlgo::GetOutputTensorNum(kernel);
    // Every output of this kernel gets a refcount equal to the output count.
    int nodeRefCount = SizeToInt(output_size);
    for (size_t j = 0; j < output_size; ++j) {
      TensorPtr device_tensor;
      if (convert_mode_ == kConvertCpuMode) {
        device_tensor = predict::utils::GetKernelCpuTensor(kernel, j);
      } else if (convert_mode_ == kConvertAscendMode) {
        device_tensor = predict::utils::GetKernelAscendTensor(kernel, j);
      }
      // device_tensor stays null (and we fail) when convert_mode_ matches neither mode.
      if (device_tensor == nullptr) {
        return false;
      }
      all_output_tensors_[kernel_key].push_back(device_tensor);
      if (!SetOpOutputIdx(kernel, device_tensor, tensor_cache, nodeRefCount, j, ms_node.get())) {
        return false;
      }
    }
    // Ownership is released into tmp_op_nodes_ as a raw pointer; the caller
    // (KernelGraph2MsGraph) reclaims it.
    tmp_op_nodes_.emplace_back(ms_node.release());
  }
  return true;
}
bool Kernel2Ms::KernelGraph2MsGraph(const KernelGraphPtr &kernel_graph_ptr) {
  // Converts a KernelGraph into sub_ms_graph_ (SubGraphDefT). Pipeline:
  // index all nodes, cache tensors for inputs / constants / kernel outputs,
  // wire per-op input indices and graph output indices, then materialize
  // allTensors. Returns false as soon as any stage fails.
  MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
  // Reset per-conversion state. NOTE(review): tmp_op_nodes_ is NOT cleared
  // here — it is only emptied by ReleaseContextRes (end of KernelInput2MS);
  // confirm this sequencing if KernelGraph2MsGraph can be re-entered.
  graph_index_ = 0;
  all_output_tensors_.clear();
  node_indexs_.clear();
  index_nodes_.clear();
  std::unique_ptr<SubGraphDefT> sub_ms_graph(new SubGraphDefT());
  if (!InitGraphIndx(kernel_graph_ptr)) {
    return false;
  }
  TransformGraphIndx();
  tensor_cache_ptr_ = std::make_shared<TensorCache>();
  // foreach node to init it's real output tensor
  if (!SetGraphInputTensors(kernel_graph_ptr, tensor_cache_ptr_, sub_ms_graph.get())) {
    return false;
  }
  // Get KernelGraph value node
  if (!SetGraphValueTensors(kernel_graph_ptr, tensor_cache_ptr_)) {
    return false;
  }
  // Get KernelGraph apply_kernel && add opNode (fills tmp_op_nodes_ with raw NodeDef*)
  if (!SetGraphOpTensors(kernel_graph_ptr, tensor_cache_ptr_, sub_ms_graph.get())) {
    return false;
  }
  // Get KernelGraph outputs
  if (!SetGraphOutputIdx(kernel_graph_ptr, tensor_cache_ptr_, sub_ms_graph.get(), &all_output_tensors_)) {
    return false;
  }
  auto kernels = kernel_graph_ptr->execution_order();
  for (size_t i = 0; i < kernels.size(); ++i) {
    // tmp_op_nodes_ was filled by SetGraphOpTensors iterating this same
    // execution order, so index i corresponds to kernels[i].
    auto ms_node = tmp_op_nodes_[i];
    if (!SetOpInputIdx(kernels[i], tensor_cache_ptr_, ms_node)) {
      // NOTE(review): the raw NodeDefs remaining in tmp_op_nodes_ leak on this
      // early return — they are only reclaimed by the move below.
      return false;
    }
    // Re-take ownership of the raw pointer released by SetGraphOpTensors.
    std::unique_ptr<NodeDef> ms_node_tmp(ms_node);
    sub_ms_graph->nodes.emplace_back(std::move(ms_node_tmp));
  }
  if (!SetAllTensors(tensor_cache_ptr_, sub_ms_graph.get())) {
    return false;
  }
  if (!SetMemResue()) {
    return false;
  }
  // Publish the finished graph and give it its default name.
  sub_ms_graph_ = std::move(sub_ms_graph);
  sub_ms_graph_->name = "default_sub_graph";
  return true;
}
  464. bool Kernel2Ms::CheckInputSizes(const std::vector<TensorPtr> &input_tensors,
  465. const std::vector<uint32_t> &all_input_idxs) {
  466. if (input_tensors.size() != all_input_idxs.size()) {
  467. MS_LOG(EXCEPTION) << "real input tensors size:" << input_tensors.size()
  468. << "not equal converted tesnors size:" << all_input_idxs.size() << "the graph has changed";
  469. }
  470. for (auto in : all_input_idxs) {
  471. if (in < sub_ms_graph_->allTensors.size()) {
  472. auto real_tensor = input_tensors[in];
  473. auto convert_dims = sub_ms_graph_->allTensors[in]->dims;
  474. auto real_dims = real_tensor->shape();
  475. if (real_dims.size() != convert_dims.size()) {
  476. return false;
  477. } else {
  478. for (size_t i = 0; i < convert_dims.size(); ++i) {
  479. if (convert_dims[i] != real_dims[i]) {
  480. return false;
  481. }
  482. }
  483. }
  484. } else {
  485. MS_LOG(EXCEPTION) << "index: " << in << "in all_input_idxs is valid";
  486. }
  487. }
  488. return true;
  489. }
  490. void Kernel2Ms::ReleaseContextRes() {
  491. tmp_op_nodes_.clear();
  492. node_indexs_.clear();
  493. index_nodes_.clear();
  494. tensor_cache_ptr_ = nullptr;
  495. all_output_tensors_.clear();
  496. }
  497. bool Kernel2Ms::KernelInput2MS(const std::vector<TensorPtr> &input_tensors) {
  498. const std::unordered_map<int, std::vector<ExTensorPtr>> &cache_tensors = tensor_cache_ptr_->GetCachedTensor();
  499. if (cache_tensors.empty()) {
  500. return false;
  501. }
  502. auto all_weights_idxs = GetAllInputWeightIdxs();
  503. auto all_input_idxs = GetAllInputIdxs();
  504. auto real_input_size = input_tensors.size();
  505. // check tensor size
  506. bool ret = CheckInputSizes(input_tensors, all_input_idxs);
  507. std::vector<uint32_t> match_to_rel_idxs;
  508. // indx order not matched,macth to it
  509. if (!ret) {
  510. for (auto idx : all_weights_idxs) {
  511. auto macth_idx = real_input_size - idx;
  512. match_to_rel_idxs.push_back(macth_idx);
  513. }
  514. } else {
  515. match_to_rel_idxs = all_weights_idxs;
  516. }
  517. if (match_to_rel_idxs.size() == all_weights_idxs.size()) {
  518. for (size_t j = 0; j < all_weights_idxs.size(); ++j) {
  519. auto cache_idx = all_weights_idxs[j];
  520. auto match_idx = match_to_rel_idxs[j];
  521. auto real_tensor = input_tensors[match_idx];
  522. auto real_size = LongToSize(real_tensor->data().nbytes());
  523. auto real_data = real_tensor->data_c(false);
  524. MS_EXCEPTION_IF_NULL(real_data);
  525. if (sub_ms_graph_->allTensors[cache_idx] != nullptr) {
  526. sub_ms_graph_->allTensors[cache_idx]->data.resize(real_size);
  527. }
  528. if (memcpy_s(sub_ms_graph_->allTensors[cache_idx]->data.data(), real_size, real_data, real_size) != 0) {
  529. MS_LOG(ERROR) << "KernelInput2MS memcpy_s failed";
  530. return false;
  531. }
  532. }
  533. }
  534. ReleaseContextRes();
  535. return true;
  536. }
bool Kernel2Ms::SaveDeviceModel(const std::shared_ptr<GraphDefT> &new_ms_graph_ptr, const std::string &save_path_name) {
  // Hands the converted sub-graph to SaveDeviceModelUtil for serialization to
  // save_path_name. Note: sub_ms_graph_.release() transfers ownership of the
  // sub-graph and leaves sub_ms_graph_ null afterwards.
  MS_EXCEPTION_IF_NULL(new_ms_graph_ptr);
  return predict::utils::SaveDeviceModelUtil(new_ms_graph_ptr, save_path_name, sub_ms_graph_.release());
}
  541. } // namespace executor
  542. } // namespace mindspore