You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

op_adapter.h 32 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef TRANSFORM_OP_ADAPTER_H_
  17. #define TRANSFORM_OP_ADAPTER_H_
  18. #include <memory>
  19. #include <vector>
  20. #include <string>
  21. #include <unordered_map>
  22. #include "transform/op_adapter_util.h"
  23. #include "utils/utils.h"
  24. namespace mindspore {
  25. namespace transform {
  26. static uint32_t CustomInferFunc(const Operator&) { return 0; }
  27. template <typename T>
  28. class OpAdapter : public BaseOpAdapter {
  29. public:
  30. using OpType = T;
  31. OpAdapter() {}
  32. explicit OpAdapter(const ExtraAttr& extra_attr) : extra_attr_(extra_attr) {}
  33. ~OpAdapter() override {}
  34. bool IsCustomOp(const OperatorPtr& op) {
  35. MS_EXCEPTION_IF_NULL(op);
  36. auto it = cus_input_map_.find(op->GetOpType());
  37. if (it == cus_input_map_.end()) {
  38. return false;
  39. }
  40. return true;
  41. }
  42. Status GenerateCustomOpInputMap(const CusOperatorPtr& op, const PrimitivePtr& prim) {
  43. MS_EXCEPTION_IF_NULL(op);
  44. MS_EXCEPTION_IF_NULL(prim);
  45. // Create the map of custom op from input index to input name.
  46. std::unordered_map<int, std::string> input_map;
  47. auto value = prim->GetAttr("input_names");
  48. if (value == nullptr) {
  49. cus_output_map_[prim->name()] = input_map;
  50. return NOT_FOUND;
  51. }
  52. auto input_names = GetValue<const std::vector<std::string>>(value);
  53. for (size_t i = 0; i < input_names.size(); ++i) {
  54. // input_map begin form 1
  55. input_map[i + 1] = input_names[i];
  56. op->CustomInputRegister(input_names[i]);
  57. }
  58. if (cus_input_map_.find(prim->name()) == cus_input_map_.end()) {
  59. cus_input_map_[prim->name()] = input_map;
  60. }
  61. return SUCCESS;
  62. }
  63. Status GenerateCustomOpOutputMap(const CusOperatorPtr& op, const PrimitivePtr& prim) {
  64. MS_EXCEPTION_IF_NULL(op);
  65. MS_EXCEPTION_IF_NULL(prim);
  66. // Create the map of custom op from output index to output name.
  67. std::unordered_map<int, std::string> output_map;
  68. auto value = prim->GetAttr("output_names");
  69. if (value == nullptr) {
  70. // generate a empty output_map for it
  71. cus_output_map_[prim->name()] = output_map;
  72. return NOT_FOUND;
  73. }
  74. auto output_names = GetValue<const std::vector<std::string>>(value);
  75. for (size_t i = 0; i < output_names.size(); ++i) {
  76. // output_map begin form 0
  77. output_map[i] = output_names[i];
  78. op->CustomOutputRegister(output_names[i]);
  79. }
  80. if (cus_output_map_.find(prim->name()) == cus_output_map_.end()) {
  81. cus_output_map_[prim->name()] = output_map;
  82. }
  83. return SUCCESS;
  84. }
  85. // Convert ME UserCustom AnfNode to GE CustomOp. And set it's attrs.
  86. OperatorPtr GenerateCustomOp(const AnfNodePtr anf) {
  87. MS_EXCEPTION_IF_NULL(anf);
  88. auto node = anf->cast<CNodePtr>();
  89. if (node == nullptr) {
  90. return nullptr;
  91. }
  92. if (node->inputs().empty()) {
  93. MS_LOG(EXCEPTION) << "length of node inputs is empty";
  94. }
  95. auto prim = GetValueNode<PrimitivePtr>(node->inputs()[0]);
  96. MS_EXCEPTION_IF_NULL(prim);
  97. auto op = std::make_shared<ge::CustomOperator>(node->fullname_with_scope(), prim->name());
  98. if (GenerateCustomOpInputMap(op, prim) != SUCCESS) {
  99. MS_LOG(WARNING) << "Custom op node has no input_names, op[" << prim->name() << "].";
  100. }
  101. if (GenerateCustomOpOutputMap(op, prim) != SUCCESS) {
  102. MS_LOG(WARNING) << "Custom op node has no output_names, op[" << prim->name() << "].";
  103. }
  104. op->CustomInferFuncRegister(CustomInferFunc);
  105. return op;
  106. }
  107. OperatorPtr GenerateNormalOp(const AnfNodePtr& anf) {
  108. OperatorPtr op = nullptr;
  109. // There are duplicate names in ANF graph, do not assign ANF node name to GE
  110. // GE will generate unique name automatically
  111. if (anf != nullptr && anf->fullname_with_scope() != "") {
  112. MS_LOG(DEBUG) << anf->fullname_with_scope();
  113. op = std::make_shared<T>(anf->fullname_with_scope());
  114. } else {
  115. MS_LOG(DEBUG) << "no fullname_with_scope";
  116. op = std::make_shared<T>();
  117. }
  118. // set dynamic output num if op use DYNAMIC_OUTPUT
  119. if ((op != nullptr) && (!dyn_output_map_.empty()) && (anf != nullptr)) {
  120. TypePtr type = anf->Type();
  121. if (type == nullptr) {
  122. MS_LOG(EXCEPTION) << "Dynamic output node:" << op->GetName() << "'s Type is a nullptr!";
  123. }
  124. size_t num = type->isa<Tuple>() ? (type->cast<std::shared_ptr<Tuple>>()->size()) : 1;
  125. MS_LOG(INFO) << "create_dyn_output for node:" << anf->ToString() << ", type:" << type->ToString()
  126. << ", num:" << num;
  127. dyn_output_map_.begin()->second.create_dyn_output(op, static_cast<unsigned int>(num));
  128. }
  129. return op;
  130. }
  131. OperatorPtr generate(const AnfNodePtr& anf) override {
  132. OperatorPtr op = nullptr;
  133. if (IsCustomCNode(anf)) {
  134. op = GenerateCustomOp(anf);
  135. } else {
  136. op = GenerateNormalOp(anf);
  137. }
  138. return op;
  139. }
  140. OperatorPtr generate(const std::string& op_name) override { return std::make_shared<T>(op_name); }
  141. const std::unordered_map<int, InputDesc>& getInputMap() override { return input_map_; }
  142. const std::unordered_map<unsigned int, AttrDesc>& getInputAttrMap() override { return input_attr_map_; }
  143. const std::unordered_map<int, DynInputDesc>& getDynInputMap() override { return dyn_input_map_; }
  144. const std::unordered_map<int, OutputDesc>& getOutputMap() override { return output_map_; }
  145. Status SetCustomOpInput(const CusOperatorPtr& op, int index, const OperatorPtr& input) {
  146. MS_EXCEPTION_IF_NULL(op);
  147. MS_EXCEPTION_IF_NULL(input);
  148. auto it = cus_input_map_.find(op->GetOpType());
  149. if (it == cus_input_map_.end()) {
  150. return NOT_FOUND;
  151. }
  152. std::unordered_map<int, std::string>& input_map = it->second;
  153. if ((input_map.find(index) != input_map.end())) {
  154. MS_LOG(DEBUG) << "Link op " << input->GetName() << " to " << op->GetName() << ":" << input_map[index];
  155. (void)op->SetInput(input_map[index], *input);
  156. return SUCCESS;
  157. }
  158. return NOT_FOUND;
  159. }
  160. Status SetNormalOpInput(const OperatorPtr& op, int index, const OperatorPtr& input) {
  161. MS_EXCEPTION_IF_NULL(op);
  162. auto it = input_map_.find(index);
  163. if (it != input_map_.end()) {
  164. MS_EXCEPTION_IF_NULL(input);
  165. MS_LOG(DEBUG) << "Link op " << input->GetName() << " to " << op->GetName() << ":" << it->second.name;
  166. it->second.set_op(op, input);
  167. return SUCCESS;
  168. }
  169. return NOT_FOUND;
  170. }
  171. int setInput(const OperatorPtr& op, int index, const OperatorPtr& input) override {
  172. if (IsCustomOp(op)) {
  173. auto cus_op = std::dynamic_pointer_cast<CustomOperator>(op);
  174. return static_cast<int>(SetCustomOpInput(cus_op, index, input));
  175. } else {
  176. return static_cast<int>(SetNormalOpInput(op, index, input));
  177. }
  178. }
  179. Status SetCustomOpInput(const CusOperatorPtr& op, int index, const OutHandler& handle) {
  180. MS_EXCEPTION_IF_NULL(op);
  181. auto it = cus_input_map_.find(op->GetOpType());
  182. if (it == cus_input_map_.end()) {
  183. return NOT_FOUND;
  184. }
  185. std::unordered_map<int, std::string>& input_map = it->second;
  186. if ((handle.op != nullptr) && (input_map.find(index) != input_map.end())) {
  187. if (handle.out.empty()) {
  188. MS_LOG(DEBUG) << "Link op " << handle.op->GetName() << " to " << op->GetName() << ":" << input_map[index];
  189. (void)op->SetInput(input_map[index], *(handle.op));
  190. } else {
  191. MS_LOG(DEBUG) << "Link op " << handle.op->GetName() << ":" << handle.out << " to " << op->GetName() << ":"
  192. << input_map[index];
  193. (void)op->SetInput(input_map[index], *(handle.op), handle.out);
  194. }
  195. return SUCCESS;
  196. }
  197. return NOT_FOUND;
  198. }
  199. Status SetNormalOpInput(const OperatorPtr& op, int index, const OutHandler& handle) {
  200. MS_EXCEPTION_IF_NULL(op);
  201. auto it = input_map_.find(index);
  202. if ((handle.op != nullptr) && (it != input_map_.end())) {
  203. if (handle.out.empty()) {
  204. MS_LOG(DEBUG) << "Link op " << handle.op->GetName() << " to " << op->GetName() << ":" << it->second.name;
  205. it->second.set_op(op, handle.op);
  206. } else {
  207. MS_LOG(DEBUG) << "Link op " << handle.op->GetName() << ":" << handle.out << " to " << op->GetName() << ":"
  208. << it->second.name;
  209. it->second.set_handle(op, handle);
  210. }
  211. return SUCCESS;
  212. }
  213. return NOT_FOUND;
  214. }
  215. int setInput(const OperatorPtr& op, int index, const OutHandler& handle) override {
  216. if (IsCustomOp(op)) {
  217. auto cus_op = std::dynamic_pointer_cast<CustomOperator>(op);
  218. return static_cast<int>(SetCustomOpInput(cus_op, index, handle));
  219. } else {
  220. return static_cast<int>(SetNormalOpInput(op, index, handle));
  221. }
  222. }
  223. int setInput(const OperatorPtr& op, int index, const std::shared_ptr<std::vector<OutHandler>>& handler_vec) override {
  224. MS_EXCEPTION_IF_NULL(handler_vec);
  225. if (IsCustomOp(op)) {
  226. MS_LOG(ERROR) << "Custom Op do not support dynamic input";
  227. return static_cast<int>(FAILED);
  228. }
  229. MS_EXCEPTION_IF_NULL(op);
  230. auto it = dyn_input_map_.find(index);
  231. if (it != dyn_input_map_.end()) {
  232. it->second.create_dyn_input(op, static_cast<unsigned int>(handler_vec->size()));
  233. for (unsigned int i = 0; i < handler_vec->size(); ++i) {
  234. OutHandler h = (*handler_vec)[i];
  235. MS_EXCEPTION_IF_NULL(h.op);
  236. if (h.out.empty()) {
  237. MS_LOG(DEBUG) << "Link op " << h.op->GetName() << " to " << op->GetName() << ":" << it->second.name;
  238. it->second.set_op(op, (i) /* index start from 0 */, h.op);
  239. } else {
  240. MS_LOG(DEBUG) << "Link op " << h.op->GetName() << ":" << h.out << " to " << op->GetName() << ":"
  241. << it->second.name;
  242. it->second.set_handle(op, i, h);
  243. }
  244. }
  245. return 0;
  246. }
  247. return static_cast<int>(NOT_FOUND);
  248. }
  249. OutHandler getOutput(const OperatorPtr& op, int index) override {
  250. MS_EXCEPTION_IF_NULL(op);
  251. if (IsCustomOp(op)) {
  252. return getCustomOutput(op, index);
  253. }
  254. return getNormalOutput(op, index);
  255. }
  256. OutHandler getCustomOutput(const OperatorPtr& op, int index) {
  257. MS_EXCEPTION_IF_NULL(op);
  258. auto it = cus_output_map_.find(op->GetOpType());
  259. if (it == cus_output_map_.end()) {
  260. MS_LOG(ERROR) << "OpAdpator(" << op->GetName() << ") has both OUTPUT is not supported!";
  261. return OutHandler();
  262. }
  263. std::unordered_map<int, std::string>& output_map = it->second;
  264. if ((output_map.find(index) != output_map.end())) {
  265. return OutHandler(op, output_map[index]);
  266. }
  267. MS_LOG(ERROR) << "OpAdpator(" << op->GetName() << ") has no OUTPUT index(" << index << ")!";
  268. return OutHandler();
  269. }
  270. OutHandler getNormalOutput(const OperatorPtr& op, int index) {
  271. MS_EXCEPTION_IF_NULL(op);
  272. if (!dyn_output_map_.empty() && !output_map_.empty()) {
  273. MS_LOG(ERROR) << "OpAdpator(" << op->GetName() << ") has both OUTPUT and DYN_OUTPUT is not supported!";
  274. return OutHandler();
  275. }
  276. auto it = output_map_.find(index);
  277. if (it != output_map_.end()) {
  278. return OutHandler(op, it->second.name);
  279. } else if (!dyn_output_map_.empty()) {
  280. return OutHandler(op, dyn_output_map_.begin()->second.name + std::to_string(index));
  281. } else {
  282. MS_LOG(ERROR) << "OpAdpator(" << op->GetName() << ") has no OUTPUT and DYN_OUTPUT index(" << index << ")!";
  283. return OutHandler();
  284. }
  285. }
  286. Status UpdateSingleOutputDesc(const OperatorPtr& op, const abstract::BaseShapePtr& shp, const TypePtr& type) {
  287. MS_EXCEPTION_IF_NULL(type);
  288. std::string format = "NCHW";
  289. if (op->GetOpType() == kExtractImagePatchesOpName) {
  290. format = "NHWC";
  291. }
  292. auto desc = CreateOutputDesc(dyn_cast<abstract::Shape>(shp), type, format);
  293. if (desc == nullptr) {
  294. MS_LOG(ERROR) << "Update output descriptor failed!";
  295. return FAILED;
  296. }
  297. if (IsCustomOp(op)) {
  298. if (cus_output_map_.find(op->GetOpType()) == cus_output_map_.end() ||
  299. (cus_output_map_[op->GetOpType()].empty())) {
  300. MS_LOG(ERROR) << "This op does not create custom output map";
  301. return FAILED;
  302. }
  303. auto cus_op = std::dynamic_pointer_cast<CustomOperator>(op);
  304. MS_EXCEPTION_IF_NULL(cus_op);
  305. std::unordered_map<int, std::string> output_map = cus_output_map_[op->GetOpType()];
  306. (void)cus_op->UpdateOutputDesc(output_map[0], *desc);
  307. } else {
  308. if (output_map_.empty()) {
  309. MS_LOG(INFO) << "This op does not have output map";
  310. return FAILED;
  311. }
  312. output_map_.begin()->second.update_out_desc(op, *desc);
  313. }
  314. return SUCCESS;
  315. }
  316. size_t GetCustomOpOutputSize(const CusOperatorPtr& cus_op) {
  317. MS_EXCEPTION_IF_NULL(cus_op);
  318. if (cus_output_map_.find(cus_op->GetOpType()) == cus_output_map_.end()) {
  319. MS_LOG(ERROR) << "This op does not create custom output map";
  320. return 0;
  321. }
  322. size_t output_size = cus_output_map_[cus_op->GetOpType()].size();
  323. return output_size;
  324. }
  325. std::shared_ptr<GeTensorDesc> CreateOutputDesc(const abstract::ShapePtr& shape_ptr, const TypePtr& type,
  326. const std::string& format) {
  327. if (shape_ptr == nullptr) {
  328. MS_LOG(ERROR) << "Shape ptr is nullptr";
  329. return nullptr;
  330. }
  331. if (type == nullptr) {
  332. MS_LOG(ERROR) << "Type ptr is nullptr";
  333. return nullptr;
  334. }
  335. TypeId me_type = type->type_id();
  336. if (kObjectTypeTensorType == me_type) {
  337. me_type = dyn_cast<TensorType>(type)->element()->type_id();
  338. }
  339. auto desc = TransformUtil::GetGeTensorDesc(shape_ptr->shape(), me_type, format);
  340. return desc;
  341. }
  342. Status UpdateMultiOutputDesc(const OperatorPtr& op, const abstract::BaseShapePtr& shp, const TypePtr& type) {
  343. auto tuple_shp = dyn_cast<abstract::TupleShape>(shp);
  344. MS_EXCEPTION_IF_NULL(tuple_shp);
  345. size_t output_size = 0;
  346. bool is_custom_op = IsCustomOp(op);
  347. if (is_custom_op) {
  348. output_size = GetCustomOpOutputSize(std::dynamic_pointer_cast<CustomOperator>(op));
  349. } else {
  350. output_size = output_map_.size();
  351. }
  352. if (output_size == 0) {
  353. MS_LOG(INFO) << "This op does not have output map";
  354. return FAILED;
  355. }
  356. if (output_size != tuple_shp->shape().size()) {
  357. MS_LOG(ERROR) << "output_map is not equal tuple_shape size";
  358. return FAILED;
  359. }
  360. std::string format = "NCHW";
  361. if (op->GetOpType() == kTopKOpName) {
  362. format = "NHWC";
  363. }
  364. for (size_t i = 0; i < tuple_shp->shape().size(); ++i) {
  365. auto tuple_type = dyn_cast<Tuple>(type);
  366. MS_EXCEPTION_IF_NULL(tuple_type);
  367. TypePtr type_elem = tuple_type->elements()[i];
  368. auto desc = CreateOutputDesc(dyn_cast<abstract::Shape>(tuple_shp->shape()[i]), type_elem, format);
  369. if (desc == nullptr) {
  370. MS_LOG(ERROR) << "Create output descriptor failed!";
  371. return FAILED;
  372. }
  373. if (is_custom_op) {
  374. (void)std::dynamic_pointer_cast<CustomOperator>(op)->UpdateOutputDesc(cus_output_map_[op->GetOpType()][i],
  375. *desc);
  376. } else {
  377. auto it = output_map_.find(i);
  378. if (it != output_map_.end()) {
  379. it->second.update_out_desc(op, *desc);
  380. }
  381. }
  382. }
  383. return SUCCESS;
  384. }
  385. std::shared_ptr<GeTensorDesc> CreateNodeDesc(const AnfNodePtr& node) {
  386. MS_EXCEPTION_IF_NULL(node);
  387. TypeId me_type = node->Type()->type_id();
  388. if (kObjectTypeTensorType == me_type) {
  389. me_type = dyn_cast<TensorType>(node->Type())->element()->type_id();
  390. }
  391. if (me_type <= kNumberTypeBegin || me_type >= kNumberTypeEnd) {
  392. return nullptr;
  393. }
  394. std::vector<int> shape;
  395. auto shape_ptr = dyn_cast<abstract::Shape>(node->Shape());
  396. if (nullptr != shape_ptr) {
  397. shape = shape_ptr->shape();
  398. }
  399. auto desc = TransformUtil::GetGeTensorDesc(shape, me_type, "NCHW");
  400. if (desc == nullptr) {
  401. MS_LOG(ERROR) << "Update output descriptor failed!";
  402. return nullptr;
  403. }
  404. return desc;
  405. }
  406. void UpdateNormalOpInputDesc(const OperatorPtr& op, const AnfNodePtr node) {
  407. if (op == nullptr) {
  408. MS_LOG(ERROR) << "op is nullptr";
  409. return;
  410. }
  411. MS_EXCEPTION_IF_NULL(node);
  412. auto inputs = node->cast<CNodePtr>()->inputs();
  413. for (size_t i = 1; i < inputs.size(); ++i) {
  414. auto it = input_map_.find(i);
  415. if (it != input_map_.end()) {
  416. auto desc = CreateNodeDesc(inputs[i]);
  417. if (desc == nullptr) {
  418. continue;
  419. }
  420. if (op->GetOpType() == kExtractImagePatchesOpName) {
  421. desc->SetFormat(ge::Format::FORMAT_NHWC);
  422. }
  423. it->second.update_input_desc(op, *desc);
  424. }
  425. }
  426. }
  427. void UpdateCustomOpInputDesc(const CusOperatorPtr& op, const AnfNodePtr& node) {
  428. if (op == nullptr) {
  429. MS_LOG(ERROR) << "op is nullptr";
  430. return;
  431. }
  432. MS_EXCEPTION_IF_NULL(node);
  433. if (cus_input_map_.find(op->GetOpType()) == cus_input_map_.end() || (cus_input_map_[op->GetOpType()].empty())) {
  434. MS_LOG(ERROR) << "This op does not create custom input map";
  435. return;
  436. }
  437. std::unordered_map<int, std::string>& input_map = cus_input_map_[op->GetOpType()];
  438. auto inputs = node->cast<CNodePtr>()->inputs();
  439. for (size_t i = 1; i < inputs.size(); ++i) {
  440. if (input_map.find(i) != input_map.end()) {
  441. auto desc = CreateNodeDesc(inputs[i]);
  442. if (desc == nullptr) {
  443. continue;
  444. }
  445. (void)op->UpdateInputDesc(input_map[i], *desc);
  446. }
  447. }
  448. }
  449. void updateInputDesc(const OperatorPtr& op, const AnfNodePtr& node) {
  450. MS_EXCEPTION_IF_NULL(op);
  451. MS_EXCEPTION_IF_NULL(node);
  452. if (IsCustomOp(op)) {
  453. auto cus_op = std::dynamic_pointer_cast<CustomOperator>(op);
  454. UpdateCustomOpInputDesc(cus_op, node);
  455. } else {
  456. UpdateNormalOpInputDesc(op, node);
  457. }
  458. }
  459. void updateOutputDesc(const OperatorPtr& op, const abstract::BaseShapePtr& shp, const TypePtr& type,
  460. const AnfNodePtr& node) override {
  461. if (op == nullptr) {
  462. MS_LOG(ERROR) << "op is nullptr";
  463. return;
  464. }
  465. MS_EXCEPTION_IF_NULL(node);
  466. MS_LOG(INFO) << "Op name is " << op->GetName();
  467. auto normal_shape_ptr = dyn_cast<abstract::Shape>(shp);
  468. auto no_shape_ptr = dyn_cast<abstract::NoShape>(shp);
  469. if ((nullptr != normal_shape_ptr) || (nullptr != no_shape_ptr)) {
  470. if (UpdateSingleOutputDesc(op, shp, type) != SUCCESS) {
  471. return;
  472. }
  473. } else if (nullptr != dyn_cast<abstract::TupleShape>(shp)) {
  474. if (UpdateMultiOutputDesc(op, shp, type) != SUCCESS) {
  475. return;
  476. }
  477. } else {
  478. MS_LOG(WARNING) << "Update output desc failed, unknow output shape type";
  479. return;
  480. }
  481. MS_EXCEPTION_IF_NULL(node);
  482. if (!node->isa<CNode>()) {
  483. return;
  484. }
  485. // Need to update input_desc while the output_desc is updated
  486. updateInputDesc(op, node);
  487. }
  488. int setAttr(const OperatorPtr& op, const std::string& attrKey, const ValuePtr& attrValue) override {
  489. auto it = attr_map_.find(attrKey);
  490. if (it != attr_map_.end()) {
  491. // switch case for each avalilable attribute type
  492. MS_LOG(INFO) << "Set attr: " << attrKey << "(" << it->second.name << "), value: " << attrValue->ToString();
  493. AddAttrToDrawGraph(attrKey + std::string("=") + attrValue->ToString());
  494. it->second.set_attr(op, attrValue);
  495. return 0;
  496. }
  497. return static_cast<int>(NOT_FOUND);
  498. }
  499. int SetCustomOpAttr(const CusOperatorPtr& op, const PrimitivePtr& prim) {
  500. enum ValueType {
  501. SINGLE_VALUE = 0,
  502. SEQUEUE_VALUE,
  503. UNKNOWN_VALUE,
  504. };
  505. MS_EXCEPTION_IF_NULL(prim);
  506. MS_EXCEPTION_IF_NULL(op);
  507. ValueType value_type = SINGLE_VALUE;
  508. for (auto item : prim->attrs()) {
  509. if (item.second->isa<Int32Imm>()) {
  510. (void)op->SetAttr(item.first, GetValue<int>(item.second));
  511. } else if (item.second->isa<StringImm>()) {
  512. (void)op->SetAttr(item.first, GetValue<std::string>(item.second));
  513. } else if (item.second->isa<BoolImm>()) {
  514. (void)op->SetAttr(item.first, GetValue<bool>(item.second));
  515. } else if (item.second->isa<FP32Imm>()) {
  516. (void)op->SetAttr(item.first, GetValue<float>(item.second));
  517. } else if (item.second->isa<ValueSequeue>()) {
  518. value_type = SEQUEUE_VALUE;
  519. auto val_seq = item.second->cast<ValueSequeuePtr>();
  520. if ((*val_seq)[0]->isa<StringImm>()) {
  521. (void)op->SetAttr(item.first, GetValue<const std::vector<std::string>>(item.second));
  522. } else if ((*val_seq)[0]->isa<FP32Imm>()) {
  523. (void)op->SetAttr(item.first, GetValue<const std::vector<float>>(item.second));
  524. } else if ((*val_seq)[0]->isa<Int32Imm>()) {
  525. (void)op->SetAttr(item.first, GetValue<const std::vector<int>>(item.second));
  526. } else if ((*val_seq)[0]->isa<BoolImm>()) {
  527. (void)op->SetAttr(item.first, GetValue<const std::vector<bool>>(item.second));
  528. } else {
  529. MS_LOG(EXCEPTION) << "Unsupported custom attribute type in adaptor, prim name: " << prim->name()
  530. << ", attr name: " << item.first << ", value: " << item.second->ToString();
  531. }
  532. } else {
  533. value_type = UNKNOWN_VALUE;
  534. MS_LOG(WARNING) << "Unsupported custom attribute type in adaptor, prim name: " << prim->name()
  535. << ", attr name: " << item.first << ", value: " << item.second->ToString();
  536. return static_cast<int>(NOT_FOUND);
  537. }
  538. if (value_type == SINGLE_VALUE) {
  539. AddAttrToDrawGraph(item.first + std::string("=") + item.second->ToString());
  540. } else if (value_type == SEQUEUE_VALUE) {
  541. AddAttrToDrawGraph(item.first + std::string("=") + "[...]");
  542. }
  543. }
  544. return 0;
  545. }
  546. int SetNormalOpAttr(const OperatorPtr& op, const PrimitivePtr& prim) {
  547. int ret = 0;
  548. MS_EXCEPTION_IF_NULL(prim);
  549. MS_EXCEPTION_IF_NULL(op);
  550. for (auto& it : attr_map_) {
  551. auto value = prim->GetAttr(it.first);
  552. if (value != nullptr) {
  553. // set attr from primitive
  554. ret = setAttr(op, it.first, value);
  555. if (ret) {
  556. return ret;
  557. }
  558. } else {
  559. // set attr from extra_attr
  560. auto it_extra = extra_attr_.find(it.first);
  561. if (it_extra != extra_attr_.end()) {
  562. ret = setAttr(op, it.first, it_extra->second);
  563. if (ret) {
  564. return ret;
  565. }
  566. }
  567. }
  568. }
  569. return 0;
  570. }
  571. int setAttr(const OperatorPtr& op, const PrimitivePtr& prim) override {
  572. int ret = 0;
  573. if (IsCustomPrim(prim)) {
  574. auto cus_op = std::dynamic_pointer_cast<CustomOperator>(op);
  575. ret = SetCustomOpAttr(cus_op, prim);
  576. } else {
  577. ret = SetNormalOpAttr(op, prim);
  578. }
  579. return ret;
  580. }
  581. int setAttr(const OperatorPtr& op, const AnfNodePtr& node) override {
  582. // no attribute for lonely node
  583. MS_EXCEPTION_IF_NULL(node);
  584. if (!node->isa<CNode>()) {
  585. return 0;
  586. }
  587. auto cnode = node->cast<CNodePtr>();
  588. if (cnode == nullptr) {
  589. return 0;
  590. }
  591. auto& inputs = cnode->inputs();
  592. if (inputs.empty()) {
  593. return 0;
  594. }
  595. // get Attr T from abstract of anfnode first,
  596. // if attr "T" appears in primitive, the primitive T will cover this one
  597. if (attr_map_.find("T") != attr_map_.end()) {
  598. // get dtype from inputs[1], if the node has no inputs, set the attr T with output dtype
  599. TypePtr type;
  600. if (inputs.size() > 1) {
  601. type = inputs[1]->Type();
  602. } else {
  603. type = node->Type();
  604. }
  605. if (type != nullptr) {
  606. (void)setAttr(op, "T", MakeValue(type));
  607. }
  608. }
  609. // set attr from primitive and ExtraAttr
  610. if (IsValueNode<Primitive>(inputs[0])) {
  611. // set attr from primitive
  612. PrimitivePtr prim = GetValueNode<PrimitivePtr>(inputs[0]);
  613. int ret = setAttr(op, prim);
  614. if (ret != 0) {
  615. return ret;
  616. }
  617. }
  618. // set attr from const input
  619. for (auto& it : input_attr_map_) {
  620. if (inputs.size() <= it.first || !inputs[it.first]->isa<ValueNode>()) {
  621. continue;
  622. }
  623. auto const_value = GetValueNode(inputs[it.first]);
  624. MS_LOG(INFO) << "Set attr: input_" << it.first << "(" << it.second.name
  625. << "), value: " << const_value->ToString();
  626. if (const_value->isa<None>()) {
  627. continue;
  628. }
  629. AddAttrToDrawGraph(it.second.name + std::string("=") + const_value->ToString());
  630. it.second.set_attr(op, const_value);
  631. }
  632. return 0;
  633. }
  634. std::unordered_map<std::string, ValuePtr> GetExtraAttr() override { return extra_attr_; }
  635. private:
  636. template <typename S>
  637. static S ConvertAny(const ValuePtr& value, const AnyTraits<S>&) {
  638. return GetValue<S>(value);
  639. }
  640. // specialization for reverse bool
  641. static bool ConvertAny(const ValuePtr& value, const AnyTraits<bool>&, bool reverse) {
  642. return reverse != GetValue<bool>(value);
  643. }
  644. template <typename P, typename Q>
  645. static Q ConvertAny(const ValuePtr& value, const AnyTraits<P>& traits_from, const AnyTraits<Q>& traits_to) {
  646. return ConvertAnyUtil(value, traits_from, traits_to);
  647. }
  648. // specialization for tensor
  649. static GeTensor ConvertAny(const ValuePtr& value, const AnyTraits<mindspore::tensor::Tensor>& traits) {
  650. // To-DO the format may read from ME tensor
  651. return ConvertAnyUtil(value, traits);
  652. }
  653. // specialization for int
  654. static int64_t ConvertAny(const ValuePtr& value, const AnyTraits<int64_t>) {
  655. return static_cast<int64_t>(GetValue<int>(value));
  656. }
  657. // specialization for int to Vector
  658. static std::vector<int64_t> ConvertAny(const ValuePtr& value, const std::string& name,
  659. const AnyTraits<std::vector<int64_t>> anyTraitsInt) {
  660. return ConvertAnyUtil(value, name, anyTraitsInt);
  661. }
  662. static std::vector<std::vector<int64_t>> ConvertAny(const ValuePtr& value,
  663. const AnyTraits<std::vector<std::vector<int64_t>>>) {
  664. MS_EXCEPTION_IF_NULL(value);
  665. MS_LOG(INFO) << "Value: " << value->type_name();
  666. std::vector<std::vector<int64_t>> list;
  667. if (!value->isa<ValueTuple>()) {
  668. MS_LOG(EXCEPTION) << "Value should be ValueTuple, but got " << value->type_name();
  669. }
  670. auto vec = value->cast<ValueTuplePtr>();
  671. MS_EXCEPTION_IF_NULL(vec);
  672. for (auto& it : vec->value()) {
  673. MS_EXCEPTION_IF_NULL(it);
  674. if (!it->isa<ValueTuple>()) {
  675. MS_LOG(EXCEPTION) << "It should be ValueTuple, but got " << it->type_name();
  676. }
  677. auto sub_vector = it->cast<ValueTuplePtr>();
  678. std::vector<int64_t> sublist;
  679. for (auto& item : sub_vector->value()) {
  680. sublist.push_back(static_cast<int64_t>(GetValue<int>(item)));
  681. }
  682. list.push_back(sublist);
  683. }
  684. return list;
  685. }
  686. static std::vector<int64_t> ConvertAny(const ValuePtr& value, const AnyTraits<std::vector<std::vector<int64_t>>>,
  687. const AnyTraits<std::vector<int64_t>>) {
  688. MS_EXCEPTION_IF_NULL(value);
  689. MS_LOG(DEBUG) << "Value: " << value->type_name();
  690. if (!value->isa<ValueList>()) {
  691. MS_LOG(EXCEPTION) << "Value should be ValueList, but got " << value->type_name();
  692. }
  693. auto vec = value->cast<ValueListPtr>();
  694. std::vector<int64_t> list;
  695. for (auto& it : vec->value()) {
  696. MS_EXCEPTION_IF_NULL(it);
  697. if (!it->isa<ValueList>()) {
  698. MS_LOG(EXCEPTION) << "It should be ValueList, but got " << it->type_name();
  699. }
  700. auto sub_vector = it->cast<ValueListPtr>();
  701. for (auto& item : sub_vector->value()) {
  702. list.push_back(static_cast<int64_t>(GetValue<int>(item)));
  703. }
  704. }
  705. return list;
  706. }
  707. static std::vector<int64_t> ConvertAny(const ValuePtr& value, const AnyTraits<std::vector<int64_t>>,
  708. const AnyTraits<std::vector<int64_t>>) {
  709. MS_EXCEPTION_IF_NULL(value);
  710. MS_LOG(INFO) << "Value: " << value->type_name();
  711. std::vector<int64_t> list;
  712. if (value->isa<ValueSequeue>()) {
  713. auto vec = value->cast<ValueSequeuePtr>();
  714. MS_EXCEPTION_IF_NULL(vec);
  715. for (auto& it : vec->value()) {
  716. list.push_back(static_cast<int64_t>(GetValue<int>(it)));
  717. }
  718. return list;
  719. }
  720. if (value->isa<Scalar>()) {
  721. list.push_back(static_cast<int64_t>(GetValue<int>(value)));
  722. return list;
  723. }
  724. MS_LOG(EXCEPTION) << "Value should be ValueTuple or Scalar, but got " << value->type_name();
  725. }
// Render an int64_t-vector attribute as a string; forwards to ConvertAnyUtil.
static std::string ConvertAny(const ValuePtr& value, const AnyTraits<std::vector<int64_t>> anyTraitsVec,
                              const AnyTraits<std::string> anyTraitsStr) {
  return ConvertAnyUtil(value, anyTraitsVec, anyTraitsStr);
}
// Extract a vector<float> attribute; forwards to ConvertAnyUtil.
static std::vector<float> ConvertAny(const ValuePtr& value, const AnyTraits<std::vector<float>> anyTraitsVec,
                                     const AnyTraits<float> anyTraitsFlo) {
  return ConvertAnyUtil(value, anyTraitsVec, anyTraitsFlo);
}
// Extract a vector<int64_t> attribute, taking a data `format` string into
// account (e.g. layout-dependent ordering); forwards to ConvertAnyUtil.
static std::vector<int64_t> ConvertAny(const ValuePtr& value, const std::string& format,
                                       const AnyTraits<std::vector<int64_t>> anyTraitsVec,
                                       const AnyTraits<int64_t> anyTraitsInt) {
  return ConvertAnyUtil(value, format, anyTraitsVec, anyTraitsInt);
}
// Convert a value list / value tuple of P-typed elements to a vector<Q>;
// forwards to ConvertAnyUtil.
template <typename P, typename Q>
static std::vector<Q> ConvertAny(const ValuePtr& value, const AnyTraits<P>& anyTraitsP,
                                 const AnyTraits<std::vector<Q>> anyTraitsQ) {
  return ConvertAnyUtil(value, anyTraitsP, anyTraitsQ);
}
  745. static int64_t ConvertAny(const ValuePtr& value, const AnyTraits<GeEnum>) {
  746. auto name = GetValue<std::string>(value);
  747. auto it = enum_map_.find(name);
  748. int v = 0;
  749. if (it != enum_map_.end()) {
  750. v = it->second;
  751. }
  752. return v;
  753. }
// Convert a MindSpore type value to the corresponding GE data type;
// forwards to ConvertAnyUtil.
static GeDataType ConvertAny(const ValuePtr& value, const AnyTraits<GEType> anyTraitsGE) {
  return ConvertAnyUtil(value, anyTraitsGE);
}
// Convert any value to a GE tensor; forwards to ConvertAnyUtil.
static GeTensor ConvertAny(const ValuePtr& value, const AnyTraits<AnyValue> anyTraitsValue) {
  return ConvertAnyUtil(value, anyTraitsValue);
}
// Static per-operator metadata tables. They are declared here and given
// default (empty) out-of-class definitions after the class; concrete
// adapter specializations provide populated definitions.
static const std::unordered_map<int, InputDesc> input_map_;         // input index -> input descriptor
static const std::unordered_map<int, DynInputDesc> dyn_input_map_;  // input index -> dynamic-input descriptor
static const std::unordered_map<int, OutputDesc> output_map_;       // output index -> output descriptor
static const std::unordered_map<int, DynOutputDesc> dyn_output_map_;  // output index -> dynamic-output descriptor
static const std::unordered_map<std::string, AttrDesc> attr_map_;   // attr name -> attr descriptor
static const std::unordered_map<std::string, int> enum_map_;        // enum name -> integer code (see ConvertAny(GeEnum))
// convert input from anf graph to Attr in Operators
static const std::unordered_map<unsigned int, AttrDesc> input_attr_map_;
// Custom-operator input/output name tables: op type -> (index -> name).
static std::unordered_map<std::string, std::unordered_map<int, std::string>> cus_input_map_;
static std::unordered_map<std::string, std::unordered_map<int, std::string>> cus_output_map_;
// Per-instance state: extra attributes attached at runtime, and a per-name
// counter (NOTE(review): appears to track name occurrences — confirm usage
// in the rest of the file).
std::unordered_map<std::string, ValuePtr> extra_attr_;
std::unordered_map<std::string, int> name_counts_;
  773. };
// Out-of-class definitions for OpAdapter<T>'s static members. Each
// instantiation gets a default-constructed (empty) table here unless an
// explicit specialization supplies a populated definition elsewhere.
template <typename T>
const std::unordered_map<int, InputDesc> OpAdapter<T>::input_map_;
template <typename T>
const std::unordered_map<int, DynInputDesc> OpAdapter<T>::dyn_input_map_;
template <typename T>
const std::unordered_map<int, OutputDesc> OpAdapter<T>::output_map_;
template <typename T>
const std::unordered_map<int, DynOutputDesc> OpAdapter<T>::dyn_output_map_;
template <typename T>
const std::unordered_map<std::string, AttrDesc> OpAdapter<T>::attr_map_;
template <typename T>
const std::unordered_map<std::string, int> OpAdapter<T>::enum_map_;
template <typename T>
const std::unordered_map<unsigned int, AttrDesc> OpAdapter<T>::input_attr_map_;
template <typename T>
std::unordered_map<std::string, std::unordered_map<int, std::string>> OpAdapter<T>::cus_input_map_;
template <typename T>
std::unordered_map<std::string, std::unordered_map<int, std::string>> OpAdapter<T>::cus_output_map_;
  792. // specialization for method
  793. } // namespace transform
  794. } // namespace mindspore
  795. #endif // TRANSFORM_OP_ADAPTER_H_