You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

kernel2ms.cc 20 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "predict/converter/kernel2ms.h"
  17. #include <algorithm>
  18. #include "ir/anf.h"
  19. #include "predict/converter/lite_model/op_attr_packer.h"
  20. #include "mindspore/ccsrc/operator/ops.h"
  21. namespace mindspore {
  22. namespace executor {
  23. Kernel2Ms &Kernel2Ms::GetInstance() {
  24. static Kernel2Ms instance;
  25. return instance;
  26. }
// Placeholder for memory-reuse planning on the converted sub-graph.
// Currently it only logs and reports success; no tensor offsets are
// actually computed here. (The name keeps the historical spelling
// "MemResue" because it is part of the public interface.)
bool Kernel2Ms::SetMemResue() const {
MS_LOG(INFO) << "MemResue start";
return true;
}
  31. bool Kernel2Ms::SetAllTensors(const TensorCachePtr &tensor_cache, SubGraphDefT *ms_graph) {
  32. if (tensor_cache == nullptr || ms_graph == nullptr) {
  33. return false;
  34. }
  35. const std::unordered_map<int, std::vector<ExTensorPtr>> &cachedTensors = tensor_cache->GetCachedTensor();
  36. size_t total_size = 0;
  37. if (cachedTensors.empty()) {
  38. return false;
  39. }
  40. for (auto &iter : cachedTensors) {
  41. auto ex_tensors = iter.second;
  42. total_size += ex_tensors.size();
  43. }
  44. ms_graph->allTensors.resize(total_size);
  45. for (auto &iter : cachedTensors) {
  46. for (auto &ex_tensor : iter.second) {
  47. std::unique_ptr<TensorDefT> ms_tensor(new TensorDefT());
  48. auto device_tensor_tmp = ex_tensor->device_tensor_ptr_;
  49. auto device_d_type = device_tensor_tmp->data_type();
  50. ms_tensor->dataType = predict::utils::GetMSDataType(device_d_type);
  51. auto device_shape = device_tensor_tmp->shape();
  52. ms_tensor->dims.clear();
  53. if (device_shape.empty()) {
  54. ms_tensor->dims.push_back(1);
  55. } else {
  56. ms_tensor->dims.assign(device_shape.begin(), device_shape.end());
  57. }
  58. std::string format_str = device_tensor_tmp->device_info().format_;
  59. ms_tensor->format = predict::utils::GetMsFormat(format_str);
  60. ms_tensor->offset = 0;
  61. auto stable = ex_tensor->stable_;
  62. if (stable == INPUTDATA || stable == CONSTANT || stable == WEIGHTS) {
  63. ms_tensor->refCount = MS_MAX_REFCOUNT;
  64. } else {
  65. ms_tensor->refCount = 0;
  66. }
  67. ms_graph->allTensors[IntToSize(ex_tensor->index_)] = std::move(ms_tensor);
  68. }
  69. }
  70. return true;
  71. }
  72. bool Kernel2Ms::SetGraphOutputIdx(const KernelGraphPtr &kernel_graph_ptr, const TensorCachePtr &tensor_cache,
  73. SubGraphDefT *ms_graph, AllOutputTensors *all_output_tensors) {
  74. MS_EXCEPTION_IF_NULL(tensor_cache);
  75. MS_EXCEPTION_IF_NULL(ms_graph);
  76. MS_EXCEPTION_IF_NULL(all_output_tensors);
  77. auto out_nodes = kernel_graph_ptr->outputs();
  78. if (out_nodes.empty()) {
  79. return false;
  80. }
  81. // maybe need to judge out_nodes is real && output must be CNode
  82. for (size_t i = 0; i < out_nodes.size(); ++i) {
  83. std::vector<AnfNodePtr> real_inputs_link;
  84. std::vector<size_t> real_output_idx_link;
  85. GetRealInpoutsPtr(out_nodes[i], &real_inputs_link, &real_output_idx_link);
  86. if (real_inputs_link.empty()) {
  87. MS_LOG(INFO) << "this graph output node is vitural node, has no real input";
  88. continue;
  89. }
  90. for (size_t k = 0; k < real_inputs_link.size(); ++k) {
  91. int key = node_indexs_[out_nodes[i].get()];
  92. auto ex_tensor_list = tensor_cache->findTensor(key);
  93. if (ex_tensor_list.empty()) {
  94. MS_LOG(INFO) << "SetGraphOutputIdx do not add Extensor ";
  95. continue;
  96. }
  97. auto ex_tensor = ex_tensor_list[real_output_idx_link[k]];
  98. ex_tensor_list.clear();
  99. ms_graph->outputIndex.push_back(ex_tensor->index_);
  100. }
  101. }
  102. return true;
  103. }
  104. bool Kernel2Ms::SetOpOutputIdx(const CNodePtr &c_node_ptr, const TensorPtr &output_tensor,
  105. const TensorCachePtr &tensor_cache, int ref_count, size_t order_index, OpDefT *ms_node) {
  106. MS_EXCEPTION_IF_NULL(c_node_ptr);
  107. MS_EXCEPTION_IF_NULL(output_tensor);
  108. MS_EXCEPTION_IF_NULL(ms_node);
  109. MS_EXCEPTION_IF_NULL(tensor_cache);
  110. if (!predict::utils::FindNodeInMap(node_indexs_, c_node_ptr)) {
  111. MS_LOG(ERROR) << "can not find any pk_key in inited node_indexs map";
  112. return false;
  113. }
  114. int tensor_key = node_indexs_[c_node_ptr.get()];
  115. auto host_shape = AnfAlgo::GetOutputInferShape(c_node_ptr, order_index);
  116. std::vector<int> tensor_shape;
  117. (void)std::transform(host_shape.begin(), host_shape.end(), std::back_inserter(tensor_shape), SizeToInt);
  118. int outputIndex = tensor_cache->addExTensor(tensor_key, output_tensor, ref_count, tensor_shape, KERNEL);
  119. ms_node->outputIndex.push_back(outputIndex);
  120. return true;
  121. }
  122. void Kernel2Ms::GetRealInpoutsPtr(const AnfNodePtr &node, std::vector<AnfNodePtr> *real_inputs,
  123. std::vector<size_t> *real_output_idx) {
  124. MS_EXCEPTION_IF_NULL(real_inputs);
  125. MS_EXCEPTION_IF_NULL(real_output_idx);
  126. size_t default_idx = 0;
  127. if (node->isa<CNode>()) {
  128. auto c_node = node->cast<CNodePtr>();
  129. MS_EXCEPTION_IF_NULL(c_node);
  130. std::string c_node_name = GetCNodeFuncName(c_node);
  131. if (c_node_name == prim::kPrimTupleGetItem->name()) {
  132. auto v_node = c_node->inputs()[kTupleGetItemIndex]->cast<ValueNodePtr>();
  133. MS_EXCEPTION_IF_NULL(v_node);
  134. default_idx = IntToSize(GetValue<int>(v_node->value()));
  135. real_inputs->push_back(c_node->inputs()[1]);
  136. real_output_idx->push_back(default_idx);
  137. return;
  138. } else if (c_node_name == prim::kPrimDepend->name()) {
  139. GetRealInpoutsPtr(c_node->inputs()[1], real_inputs, real_output_idx);
  140. return;
  141. } else if (c_node_name == prim::kPrimMakeTuple->name()) {
  142. for (auto &in : c_node->inputs()) {
  143. GetRealInpoutsPtr(in, real_inputs, real_output_idx);
  144. }
  145. return;
  146. } else {
  147. real_inputs->push_back(node);
  148. real_output_idx->push_back(default_idx);
  149. }
  150. } else if (node->isa<Parameter>()) {
  151. real_inputs->push_back(node);
  152. real_output_idx->push_back(default_idx);
  153. } else if (node->isa<ValueNode>()) {
  154. real_inputs->push_back(node);
  155. real_output_idx->push_back(default_idx);
  156. }
  157. }
  158. bool Kernel2Ms::SetOpInputIdx(const CNodePtr &c_node_ptr, const TensorCachePtr &tensor_cache, OpDefT *ms_node) {
  159. MS_EXCEPTION_IF_NULL(c_node_ptr);
  160. MS_EXCEPTION_IF_NULL(tensor_cache);
  161. MS_EXCEPTION_IF_NULL(ms_node);
  162. for (size_t i = 1; i < c_node_ptr->inputs().size(); ++i) {
  163. std::vector<AnfNodePtr> real_inputs;
  164. std::vector<size_t> real_output_idx;
  165. GetRealInpoutsPtr(c_node_ptr->inputs()[i], &real_inputs, &real_output_idx);
  166. if (real_inputs.empty()) {
  167. MS_LOG(INFO) << "kernel has no inputs: " << c_node_ptr.get() << " input size[%lu]" << c_node_ptr->inputs().size();
  168. continue;
  169. }
  170. for (size_t j = 0; j < real_inputs.size(); ++j) {
  171. int key = node_indexs_[real_inputs[j].get()];
  172. std::vector<ExTensorPtr> ex_tensor_list = tensor_cache->findTensor(key);
  173. if (ex_tensor_list.empty()) {
  174. continue;
  175. }
  176. ExTensorPtr ex_tensor_ptr = ex_tensor_list[real_output_idx[j]];
  177. ex_tensor_list.clear();
  178. ms_node->inputIndex.push_back(ex_tensor_ptr->index_);
  179. }
  180. }
  181. return true;
  182. }
  183. void Kernel2Ms::TransformGraphIndx() {
  184. // transform index && anfnodeptr
  185. if (node_indexs_.empty()) {
  186. MS_LOG(EXCEPTION) << "node_indexs_ not ininted";
  187. }
  188. for (auto &item : node_indexs_) {
  189. index_nodes_[item.second] = item.first;
  190. }
  191. }
  192. bool Kernel2Ms::InitGraphInputsIndx(const KernelGraphPtr &kernel_graph_ptr) {
  193. MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
  194. auto input_nodes = kernel_graph_ptr->inputs();
  195. if (input_nodes.empty()) {
  196. return false;
  197. }
  198. for (const auto &input_node : input_nodes) {
  199. if (input_node->isa<Parameter>()) {
  200. if (!predict::utils::FindNodeInMap(node_indexs_, input_node)) {
  201. // init every parameter node
  202. node_indexs_[input_node.get()] = graph_index_;
  203. graph_index_++;
  204. }
  205. } else {
  206. MS_LOG(INFO) << "This node is anfnode, no need to handle, continue. node info: " << input_node->ToString();
  207. continue;
  208. }
  209. }
  210. MS_LOG(DEBUG) << "inputs GraphIndex: " << graph_index_;
  211. return true;
  212. }
  213. bool Kernel2Ms::InitGraphValueNodesIndx(const KernelGraphPtr &kernel_graph_ptr) {
  214. MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
  215. if (kernel_graph_ptr->value_nodes().empty()) {
  216. return false;
  217. }
  218. for (auto &item : kernel_graph_ptr->value_nodes()) {
  219. if (item.first->isa<ValueNode>()) {
  220. auto value_node = item.first->cast<ValueNodePtr>();
  221. MS_EXCEPTION_IF_NULL(value_node);
  222. if (value_node == nullptr) {
  223. MS_LOG(WARNING) << "value_node is nullptr";
  224. return false;
  225. }
  226. if (value_node->value() == nullptr) {
  227. MS_LOG(ERROR) << "Constant value is null.";
  228. return false;
  229. }
  230. if (!value_node->value()->isa<tensor::Tensor>()) {
  231. continue;
  232. }
  233. if (!predict::utils::FindNodeInMap(node_indexs_, item.first)) {
  234. // init node
  235. auto node_ptr = item.first;
  236. node_indexs_[node_ptr.get()] = graph_index_;
  237. graph_index_++;
  238. }
  239. }
  240. }
  241. return true;
  242. }
  243. bool Kernel2Ms::InitGraphOpsIndx(const KernelGraphPtr &kernel_graph_ptr) {
  244. MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
  245. auto kernels = kernel_graph_ptr->execution_order();
  246. if (kernels.empty()) {
  247. MS_LOG(WARNING) << "this graph has no kernel";
  248. return false;
  249. }
  250. for (size_t i = 0; i < kernels.size(); ++i) {
  251. // for each kernel's inputs foreach real_input
  252. if (kernels[i]->isa<CNode>()) {
  253. if (!predict::utils::FindNodeInMap(node_indexs_, kernels[i])) {
  254. // init node
  255. node_indexs_[kernels[i].get()] = graph_index_;
  256. graph_index_++;
  257. }
  258. }
  259. }
  260. return true;
  261. }
  262. bool Kernel2Ms::InitGraphOutputsIndx(const KernelGraphPtr &kernel_graph_ptr) {
  263. MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
  264. // graph output && their inputs should link together
  265. auto out_nodes = kernel_graph_ptr->outputs();
  266. if (out_nodes.empty()) {
  267. MS_LOG(ERROR) << "this graph has no outputs";
  268. return false;
  269. }
  270. for (auto &item : out_nodes) {
  271. if (!predict::utils::FindNodeInMap(node_indexs_, item)) {
  272. node_indexs_[item.get()] = graph_index_;
  273. graph_index_++;
  274. }
  275. }
  276. return true;
  277. }
  278. bool Kernel2Ms::InitGraphIndx(const KernelGraphPtr &kernel_graph_ptr) {
  279. MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
  280. // only parameter
  281. if (!InitGraphInputsIndx(kernel_graph_ptr)) {
  282. return false;
  283. }
  284. // init value node
  285. if (!InitGraphValueNodesIndx(kernel_graph_ptr)) {
  286. return false;
  287. }
  288. // init op
  289. if (!InitGraphOpsIndx(kernel_graph_ptr)) {
  290. return false;
  291. }
  292. // init Graphoutput attention: out_put nodes have inputs
  293. return InitGraphOutputsIndx(kernel_graph_ptr);
  294. }
  295. bool Kernel2Ms::SetGraphInputTensors(const KernelGraphPtr &kernel_graph_ptr, const TensorCachePtr &tensor_cache,
  296. SubGraphDefT *ms_graph) {
  297. MS_EXCEPTION_IF_NULL(tensor_cache);
  298. MS_EXCEPTION_IF_NULL(ms_graph);
  299. MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
  300. if (convert_mode_ == kConvertUnused) {
  301. return false;
  302. }
  303. if (kernel_graph_ptr->inputs().empty()) {
  304. return false;
  305. }
  306. for (const auto &input_node : kernel_graph_ptr->inputs()) {
  307. if (input_node->isa<Parameter>()) {
  308. ParameterPtr pk_node = std::dynamic_pointer_cast<Parameter>(input_node);
  309. TensorPtr device_tensor;
  310. if (convert_mode_ == kConvertCpuMode) {
  311. device_tensor = predict::utils::GetParaCpuTensor(input_node);
  312. } else {
  313. device_tensor = predict::utils::GetParaAscendTensor(input_node);
  314. }
  315. if (device_tensor == nullptr) {
  316. return false;
  317. }
  318. ExTensorType node_type;
  319. if (AnfAlgo::IsParameterWeight(pk_node)) {
  320. node_type = WEIGHTS;
  321. } else {
  322. node_type = INPUTDATA;
  323. }
  324. if (!predict::utils::FindNodeInMap(node_indexs_, input_node)) {
  325. MS_LOG(WARNING) << "can not find any pk_key in inited node_indexs map";
  326. return false;
  327. }
  328. auto pk_key = node_indexs_[input_node.get()];
  329. all_output_tensors_[pk_key].push_back(device_tensor);
  330. int nodeRefCount = SizeToInt(AnfAlgo::GetOutputTensorNum(input_node));
  331. int nodeInputIdx =
  332. tensor_cache->addExTensor(pk_key, device_tensor, nodeRefCount, device_tensor->shape(), node_type);
  333. if (!AnfAlgo::IsParameterWeight(pk_node)) {
  334. ms_graph->inputIndex.push_back(nodeInputIdx);
  335. all_input_idxs_.push_back(nodeInputIdx);
  336. } else {
  337. input_weight_idxs_.push_back(nodeInputIdx);
  338. all_input_idxs_.push_back(nodeInputIdx);
  339. }
  340. }
  341. }
  342. return true;
  343. }
  344. bool Kernel2Ms::SetGraphValueTensors(const KernelGraphPtr &kernel_graph_ptr, const TensorCachePtr &tensor_cache) {
  345. MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
  346. MS_EXCEPTION_IF_NULL(tensor_cache);
  347. for (auto &item : kernel_graph_ptr->value_nodes()) {
  348. if (item.first->isa<ValueNode>()) {
  349. auto const_node = item.first->cast<ValueNodePtr>();
  350. auto tensor_constant = predict::utils::GetValueTensor(const_node);
  351. if (tensor_constant == nullptr) {
  352. continue;
  353. }
  354. if (!predict::utils::FindNodeInMap(node_indexs_, item.first)) {
  355. MS_LOG(WARNING) << "can not find any pk_key in inited node_indexs map";
  356. return false;
  357. }
  358. int constant_key = node_indexs_[(item.first).get()];
  359. all_output_tensors_[constant_key].push_back(tensor_constant);
  360. auto shape = tensor_constant->shape();
  361. (void)tensor_cache->addExTensor(constant_key, tensor_constant, 0, shape, CONSTANT);
  362. }
  363. }
  364. return true;
  365. }
// Builds an OpDefT for every kernel in execution order: packs the
// node's attributes via the OpAttrFactory, caches one device tensor
// per kernel output, and stashes the finished node (as a raw pointer)
// in tmp_op_nodes_ for KernelGraph2MsGraph to pick up later.
//
// NOTE(review): on an early `return false` after some iterations,
// the OpDefT pointers already released into tmp_op_nodes_ are never
// deleted (ReleaseContextRes only clears the vector) — potential
// leak on the failure path; verify against the ownership contract.
bool Kernel2Ms::SetGraphOpTensors(const KernelGraphPtr &kernel_graph_ptr, const TensorCachePtr &tensor_cache,
SubGraphDefT *ms_graph) {
MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
MS_EXCEPTION_IF_NULL(tensor_cache);
MS_EXCEPTION_IF_NULL(ms_graph);
auto kernels = kernel_graph_ptr->execution_order();
if (kernels.empty()) {
MS_LOG(ERROR) << "this graph has no kernels";
return false;
}
for (auto &kernel : kernels) {
if (!predict::utils::FindNodeInMap(node_indexs_, kernel)) {
MS_LOG(ERROR) << "can not find any pk_key in inited node_indexs map";
return false;
}
auto kernel_key = node_indexs_[kernel.get()];
std::unique_ptr<OpDefT> ms_node(new OpDefT);
ms_node->name = kernel->fullname_with_scope();
ms_node->fmkType = mindspore::predict::FmkType_CAFFE;
auto c_name = AnfAlgo::GetCNodeName(kernel);
auto fun = predict::convert::OpAttrFactory::GetInstance()->GetPackFun(c_name);
// A missing attr packer is tolerated (warning); a failing one is fatal.
if (fun == nullptr) {
MS_LOG(WARNING) << "get node [" << kernel->fullname_with_scope() << "] attr failed.";
} else if (!fun(kernel, ms_node.get())) {
MS_LOG(ERROR) << "set node [" << kernel->fullname_with_scope() << "] attr failed.";
return false;
}
auto output_size = AnfAlgo::GetOutputTensorNum(kernel);
int nodeRefCount = SizeToInt(output_size);
for (size_t j = 0; j < output_size; ++j) {
TensorPtr device_tensor;
if (convert_mode_ == kConvertCpuMode) {
device_tensor = predict::utils::GetKernelCpuTensor(kernel, j);
} else if (convert_mode_ == kConvertAscendMode) {
device_tensor = predict::utils::GetKernelAscendTensor(kernel, j);
}
// Note: device_tensor stays null if convert_mode_ is kConvertUnused.
if (device_tensor == nullptr) {
return false;
}
all_output_tensors_[kernel_key].push_back(device_tensor);
if (!SetOpOutputIdx(kernel, device_tensor, tensor_cache, nodeRefCount, j, ms_node.get())) {
return false;
}
}
// Ownership transferred to tmp_op_nodes_ (raw pointer; see NOTE above).
tmp_op_nodes_.emplace_back(ms_node.release());
}
return true;
}
  414. bool Kernel2Ms::KernelGraph2MsGraph(const KernelGraphPtr &kernel_graph_ptr) {
  415. MS_EXCEPTION_IF_NULL(kernel_graph_ptr);
  416. graph_index_ = 0;
  417. all_output_tensors_.clear();
  418. node_indexs_.clear();
  419. index_nodes_.clear();
  420. std::unique_ptr<SubGraphDefT> sub_ms_graph(new SubGraphDefT());
  421. if (!InitGraphIndx(kernel_graph_ptr)) {
  422. return false;
  423. }
  424. TransformGraphIndx();
  425. tensor_cache_ptr_ = std::make_shared<TensorCache>();
  426. // foreach node to init it's real output tensor
  427. if (!SetGraphInputTensors(kernel_graph_ptr, tensor_cache_ptr_, sub_ms_graph.get())) {
  428. return false;
  429. }
  430. // Get KernelGraph value node
  431. if (!SetGraphValueTensors(kernel_graph_ptr, tensor_cache_ptr_)) {
  432. return false;
  433. }
  434. // Get KernelGraph apply_kernel && add opNode
  435. if (!SetGraphOpTensors(kernel_graph_ptr, tensor_cache_ptr_, sub_ms_graph.get())) {
  436. return false;
  437. }
  438. // Get KernelGraph outputs
  439. if (!SetGraphOutputIdx(kernel_graph_ptr, tensor_cache_ptr_, sub_ms_graph.get(), &all_output_tensors_)) {
  440. return false;
  441. }
  442. auto kernels = kernel_graph_ptr->execution_order();
  443. for (size_t i = 0; i < kernels.size(); ++i) {
  444. auto ms_node = tmp_op_nodes_[i];
  445. if (!SetOpInputIdx(kernels[i], tensor_cache_ptr_, ms_node)) {
  446. return false;
  447. }
  448. std::unique_ptr<OpDefT> ms_node_tmp(ms_node);
  449. sub_ms_graph->nodes.emplace_back(std::move(ms_node_tmp));
  450. }
  451. if (!SetAllTensors(tensor_cache_ptr_, sub_ms_graph.get())) {
  452. return false;
  453. }
  454. if (!SetMemResue()) {
  455. return false;
  456. }
  457. sub_ms_graph_ = std::move(sub_ms_graph);
  458. sub_ms_graph_->name = "default_sub_graph";
  459. return true;
  460. }
  461. bool Kernel2Ms::CheckInputSizes(const std::vector<TensorPtr> &input_tensors,
  462. const std::vector<uint32_t> &all_input_idxs) {
  463. if (input_tensors.size() != all_input_idxs.size()) {
  464. MS_LOG(EXCEPTION) << "real input tensors size:" << input_tensors.size()
  465. << "not equal converted tesnors size:" << all_input_idxs.size() << "the graph has changed";
  466. }
  467. for (auto in : all_input_idxs) {
  468. if (in < sub_ms_graph_->allTensors.size()) {
  469. auto real_tensor = input_tensors[in];
  470. auto convert_dims = sub_ms_graph_->allTensors[in]->dims;
  471. auto real_dims = real_tensor->shape();
  472. if (real_dims.size() != convert_dims.size()) {
  473. return false;
  474. } else {
  475. for (size_t i = 0; i < convert_dims.size(); ++i) {
  476. if (convert_dims[i] != real_dims[i]) {
  477. return false;
  478. }
  479. }
  480. }
  481. } else {
  482. MS_LOG(EXCEPTION) << "index: " << in << "in all_input_idxs is valid";
  483. }
  484. }
  485. return true;
  486. }
// Drops all per-conversion state so the singleton can run again.
// NOTE(review): tmp_op_nodes_ holds raw OpDefT* — clear() releases the
// pointers without deleting them. That is safe only if ownership was
// already moved into the sub-graph (the success path); on failure
// paths any remaining nodes leak. Verify against the intended
// ownership contract.
void Kernel2Ms::ReleaseContextRes() {
tmp_op_nodes_.clear();
node_indexs_.clear();
index_nodes_.clear();
tensor_cache_ptr_ = nullptr;
all_output_tensors_.clear();
}
  494. bool Kernel2Ms::KernelInput2MS(const std::vector<TensorPtr> &input_tensors) {
  495. const std::unordered_map<int, std::vector<ExTensorPtr>> &cache_tensors = tensor_cache_ptr_->GetCachedTensor();
  496. if (cache_tensors.empty()) {
  497. return false;
  498. }
  499. auto all_weights_idxs = GetAllInputWeightIdxs();
  500. auto all_input_idxs = GetAllInputIdxs();
  501. auto real_input_size = input_tensors.size();
  502. // check tensor size
  503. bool ret = CheckInputSizes(input_tensors, all_input_idxs);
  504. std::vector<uint32_t> match_to_rel_idxs;
  505. // indx order not matched,macth to it
  506. if (!ret) {
  507. for (auto idx : all_weights_idxs) {
  508. auto macth_idx = real_input_size - idx;
  509. match_to_rel_idxs.push_back(macth_idx);
  510. }
  511. } else {
  512. match_to_rel_idxs = all_weights_idxs;
  513. }
  514. if (match_to_rel_idxs.size() == all_weights_idxs.size()) {
  515. for (size_t j = 0; j < all_weights_idxs.size(); ++j) {
  516. auto cache_idx = all_weights_idxs[j];
  517. auto match_idx = match_to_rel_idxs[j];
  518. auto real_tensor = input_tensors[match_idx];
  519. auto real_size = LongToSize(real_tensor->data().nbytes());
  520. auto real_data = real_tensor->data_c();
  521. MS_EXCEPTION_IF_NULL(real_data);
  522. if (sub_ms_graph_->allTensors[cache_idx] != nullptr) {
  523. sub_ms_graph_->allTensors[cache_idx]->data.resize(real_size);
  524. }
  525. if (memcpy_s(sub_ms_graph_->allTensors[cache_idx]->data.data(), real_size, real_data, real_size) != 0) {
  526. MS_LOG(ERROR) << "KernelInput2MS memcpy_s failed";
  527. return false;
  528. }
  529. }
  530. }
  531. ReleaseContextRes();
  532. return true;
  533. }
// Serializes the converted sub-graph into new_ms_graph_ptr and writes
// it to save_path_name via the shared save utility.
// NOTE: sub_ms_graph_.release() transfers ownership of the sub-graph
// out of this object — after this call sub_ms_graph_ is null, so the
// method is effectively single-shot per conversion.
bool Kernel2Ms::SaveDeviceModel(const std::shared_ptr<GraphDefT> &new_ms_graph_ptr, const std::string &save_path_name) {
MS_EXCEPTION_IF_NULL(new_ms_graph_ptr);
return predict::utils::SaveDeviceModelUtil(new_ms_graph_ptr, save_path_name, sub_ms_graph_.release());
}
  538. } // namespace executor
  539. } // namespace mindspore