You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

abstract_value.cc 32 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015
  1. /**
  2. * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
  3. *
  4. * Copyright 2019 Huawei Technologies Co., Ltd
  5. *
  6. * Licensed under the Apache License, Version 2.0 (the "License");
  7. * you may not use this file except in compliance with the License.
  8. * You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  #include "pipeline/static_analysis/abstract_value.h"

  #include <algorithm>
  #include <iterator>
  #include <numeric>
  #include <sstream>
  #include <string>
  #include <unordered_map>
  #include <utility>
  #include <vector>

  #include "utils/symbolic.h"
  #include "pipeline/static_analysis/static_analysis.h"
  #include "pipeline/static_analysis/utils.h"
  23. namespace mindspore {
  24. namespace abstract {
  25. bool AbstractBase::operator==(const AbstractBase &other) const {
  26. if (tid() != other.tid()) {
  27. return false;
  28. }
  29. if (value_ == nullptr || other.value_ == nullptr) {
  30. MS_LOG(EXCEPTION) << "If value_ is nullptr, AbstractBase::operator== should not be called. this: "
  31. << this->ToString() << ", other: " << other.ToString();
  32. }
  33. bool value_equal = *value_ == *other.value_;
  34. bool type_equal = *type_ == *other.type_;
  35. bool shape_equal = *shape_ == *other.shape_;
  36. return value_equal && type_equal && shape_equal;
  37. }
  38. ValuePtr AbstractBase::BuildValue() const {
  39. if (value_ == nullptr) {
  40. return RealBuildValue();
  41. }
  42. return value_;
  43. }
  44. AbstractBasePtr AbstractBase::Broaden() const {
  45. AbstractBasePtr clone = Clone();
  46. clone->set_value(kAnyValue);
  47. return clone;
  48. }
  49. std::string AbstractBase::ToString() const {
  50. std::ostringstream buffer;
  51. std::string value = std::string("value is null");
  52. if (value_ != nullptr) {
  53. value = value_->ToString();
  54. }
  55. MS_EXCEPTION_IF_NULL(type_);
  56. MS_EXCEPTION_IF_NULL(shape_);
  57. buffer << type_name() << "("
  58. << "Type: " << type_->ToString() << " Value: " << value << " Shape: " << shape_->ToString() << ")";
  59. return buffer.str();
  60. }
  61. AbstractBasePtr AbstractScalar::Broaden() const {
  62. AbstractBasePtr clone = Clone();
  63. MS_EXCEPTION_IF_NULL(clone);
  64. auto value_track = clone->GetValueTrack();
  65. MS_EXCEPTION_IF_NULL(value_track);
  66. if (value_track->isa<SymbolicKeyInstance>()) {
  67. return clone;
  68. }
  69. return AbstractBase::Broaden();
  70. }
  71. AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) {
  72. MS_EXCEPTION_IF_NULL(other);
  73. if (*this == *other) {
  74. return shared_from_base<AbstractBase>();
  75. }
  76. auto value_self = GetValueTrack();
  77. MS_EXCEPTION_IF_NULL(value_self);
  78. ValuePtr res_value = ValueJoin(value_self, other->GetValueTrack());
  79. TypePtr res_type = TypeJoin(GetTypeTrack(), other->GetTypeTrack());
  80. if (res_value == value_self) {
  81. return shared_from_base<AbstractBase>();
  82. }
  83. return std::make_shared<AbstractScalar>(res_value, res_type);
  84. }
  85. AbstractBasePtr AbstractType::Clone() const {
  86. ValuePtr value_self = GetValueTrack();
  87. if (value_self == nullptr || !value_self->isa<Type>()) {
  88. return nullptr;
  89. }
  90. TypePtr type_self = value_self->cast<TypePtr>();
  91. return std::make_shared<AbstractType>(type_self->Clone());
  92. }
  93. bool AbstractType::operator==(const AbstractBase &other) const {
  94. if (tid() != other.tid()) {
  95. return false;
  96. }
  97. // Have to compare TypePtr with value;
  98. ValuePtr value_self = GetValueTrack();
  99. ValuePtr value_other = other.GetValueTrack();
  100. if (value_self == nullptr || value_other == nullptr) {
  101. MS_LOG(EXCEPTION) << "AbstractType value should not be nullptr. this: " << this->ToString()
  102. << ", other: " << other.ToString();
  103. }
  104. if (!value_self->isa<Type>() || !value_other->isa<Type>()) {
  105. return false;
  106. }
  107. TypePtr type_self = value_self->cast<TypePtr>();
  108. TypePtr type_other = value_other->cast<TypePtr>();
  109. bool value_equal = *type_self == *type_other;
  110. return value_equal;
  111. }
  112. std::string AbstractType::ToString() const {
  113. std::ostringstream buffer;
  114. ValuePtr value_self = GetValueTrack();
  115. if (value_self == nullptr) {
  116. buffer << "AbstractType value: nullptr";
  117. return buffer.str();
  118. }
  119. if (!value_self->isa<Type>()) {
  120. buffer << type_name() << "(Value: nullptr)";
  121. return buffer.str();
  122. }
  123. TypePtr type_self = value_self->cast<TypePtr>();
  124. MS_EXCEPTION_IF_NULL(type_self);
  125. buffer << type_name() << "("
  126. << "Value: " << type_self->ToString() << ")";
  127. return buffer.str();
  128. }
  129. std::string AbstractError::ToString() const {
  130. std::ostringstream buffer;
  131. auto value_track = GetValueTrack();
  132. MS_EXCEPTION_IF_NULL(value_track);
  133. buffer << type_name() << "("
  134. << "Value: " << value_track->ToString() << ", Node: " << node_->DebugString() << ")";
  135. return buffer.str();
  136. }
  137. AbstractBasePtr AbstractFunction::Join(const AbstractBasePtr &other) {
  138. MS_EXCEPTION_IF_NULL(other);
  139. auto other_func = dyn_cast<AbstractFunction>(other);
  140. if (other_func == nullptr) {
  141. MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString();
  142. }
  143. return Join(other_func);
  144. }
  145. bool AbstractFunction::operator==(const AbstractBase &other) const {
  146. if (!other.isa<AbstractFunction>()) {
  147. return false;
  148. }
  149. const auto &other_func = static_cast<const AbstractFunction &>(other);
  150. bool value_equal = (*this == other_func);
  151. return value_equal;
  152. }
  153. const AbstractBasePtr AbstractSequeue::operator[](const std::size_t &dim) const {
  154. if (dim >= size()) {
  155. MS_LOG(EXCEPTION) << "Index [" << dim << "] Out of the size [" << size() << "] of the list.";
  156. }
  157. return elements_[dim];
  158. }
  159. std::string AbstractSequeue::ToString() const {
  160. std::ostringstream buffer;
  161. int i = 0;
  162. for (const auto &ele : elements_) {
  163. MS_EXCEPTION_IF_NULL(ele);
  164. buffer << "element[" << i << "]: " << ele->ToString() << ",";
  165. i++;
  166. }
  167. return buffer.str();
  168. }
  169. TypePtrList AbstractSequeue::ElementsType() const {
  170. TypePtrList element_type_list;
  171. for (const auto &ele : elements_) {
  172. MS_EXCEPTION_IF_NULL(ele);
  173. TypePtr element_type = ele->BuildType();
  174. element_type_list.push_back(element_type);
  175. }
  176. return element_type_list;
  177. }
  178. BaseShapePtrList AbstractSequeue::ElementsShape() const {
  179. BaseShapePtrList element_shape_list;
  180. for (const auto &ele : elements_) {
  181. MS_EXCEPTION_IF_NULL(ele);
  182. BaseShapePtr element_shape = ele->BuildShape();
  183. element_shape_list.push_back(element_shape);
  184. }
  185. return element_shape_list;
  186. }
  187. AbstractBasePtrList AbstractSequeue::ElementsClone() const {
  188. AbstractBasePtrList ele_list;
  189. for (const auto &ele : elements_) {
  190. MS_EXCEPTION_IF_NULL(ele);
  191. AbstractBasePtr clone = ele->Clone();
  192. ele_list.push_back(clone);
  193. }
  194. return ele_list;
  195. }
  196. AbstractBasePtrList AbstractSequeue::ElementsBroaden() const {
  197. AbstractBasePtrList ele_list;
  198. for (const auto &ele : elements_) {
  199. MS_EXCEPTION_IF_NULL(ele);
  200. AbstractBasePtr broadend = ele->Broaden();
  201. ele_list.push_back(broadend);
  202. }
  203. return ele_list;
  204. }
  205. template <typename T>
  206. ValuePtr AbstractSequeue::ElementsBuildValue() const {
  207. std::vector<ValuePtr> element_value_list;
  208. for (const auto &ele : elements_) {
  209. ValuePtr element_value = ele->BuildValue();
  210. if (element_value->isa<AnyValue>()) {
  211. return kAnyValue;
  212. }
  213. element_value_list.push_back(element_value);
  214. }
  215. return std::make_shared<T>(element_value_list);
  216. }
  217. template ValuePtr AbstractSequeue::ElementsBuildValue<ValueTuple>() const;
  218. template ValuePtr AbstractSequeue::ElementsBuildValue<ValueList>() const;
  219. template <typename T>
  220. AbstractBasePtr AbstractSequeue::ElementsJoin(const AbstractBasePtr &other) {
  221. auto other_sequeue = dyn_cast<T>(other);
  222. if (other_sequeue == nullptr) {
  223. MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString();
  224. }
  225. auto joined_list = AbstractJoin(elements_, other_sequeue->elements_);
  226. bool changes = false;
  227. for (std::size_t i = 0; i < elements_.size(); i++) {
  228. if (elements_[i] != joined_list[i]) {
  229. changes = true;
  230. break;
  231. }
  232. }
  233. if (!changes) {
  234. return shared_from_base<AbstractBase>();
  235. }
  236. return std::make_shared<T>(joined_list);
  237. }
  238. template AbstractBasePtr AbstractSequeue::ElementsJoin<AbstractList>(const AbstractBasePtr &);
  239. template AbstractBasePtr AbstractSequeue::ElementsJoin<AbstractTuple>(const AbstractBasePtr &);
  240. std::size_t AbstractSequeue::hash() const {
  241. std::size_t hash_sum = hash_combine(tid(), std::hash<size_t>{}(elements_.size()));
  242. // Hashing all elements is costly, so only take at most 4 elements into account based on
  243. // some experiments.
  244. for (size_t i = 0; (i < elements_.size()) && (i < 4); i++) {
  245. hash_sum = hash_combine(hash_sum, elements_[i]->hash());
  246. }
  247. return hash_sum;
  248. }
  249. bool AbstractTuple::operator==(const AbstractTuple &other) const {
  250. if (&other == this) {
  251. return true;
  252. }
  253. if (elements_.size() != other.elements_.size()) {
  254. return false;
  255. }
  256. for (size_t i = 0; i < elements_.size(); i++) {
  257. if (!(*(elements_[i]) == *(other.elements_[i]))) {
  258. return false;
  259. }
  260. }
  261. return true;
  262. }
  263. bool AbstractTuple::operator==(const AbstractBase &other) const {
  264. if (&other == this) {
  265. return true;
  266. }
  267. if (other.isa<AbstractTuple>()) {
  268. auto other_tuple = static_cast<const AbstractTuple *>(&other);
  269. return *this == *other_tuple;
  270. }
  271. return false;
  272. }
  273. bool AbstractList::operator==(const AbstractList &other) const {
  274. if (&other == this) {
  275. return true;
  276. }
  277. if (elements_.size() != other.elements_.size()) {
  278. return false;
  279. }
  280. for (size_t i = 0; i < elements_.size(); i++) {
  281. if (!(*(elements_[i]) == *(other.elements_[i]))) {
  282. return false;
  283. }
  284. }
  285. return true;
  286. }
  287. bool AbstractList::operator==(const AbstractBase &other) const {
  288. if (&other == this) {
  289. return true;
  290. }
  291. if (other.isa<AbstractList>()) {
  292. auto other_list = static_cast<const AbstractList *>(&other);
  293. return *this == *other_list;
  294. }
  295. return false;
  296. }
  297. TypePtr AbstractSlice::BuildType() const {
  298. MS_EXCEPTION_IF_NULL(start_);
  299. MS_EXCEPTION_IF_NULL(stop_);
  300. MS_EXCEPTION_IF_NULL(step_);
  301. TypePtr start = start_->BuildType();
  302. TypePtr stop = stop_->BuildType();
  303. TypePtr step = step_->BuildType();
  304. return std::make_shared<Slice>(start, stop, step);
  305. }
  306. bool AbstractSlice::operator==(const AbstractSlice &other) const {
  307. if (&other == this) {
  308. return true;
  309. }
  310. return (*start_ == *other.start_ && *stop_ == *other.stop_ && *step_ == *other.step_);
  311. }
  312. bool AbstractSlice::operator==(const AbstractBase &other) const {
  313. if (&other == this) {
  314. return true;
  315. }
  316. if (!other.isa<AbstractSlice>()) {
  317. return false;
  318. }
  319. auto other_slice = static_cast<const AbstractSlice *>(&other);
  320. return *this == *other_slice;
  321. }
  322. AbstractBasePtr AbstractSlice::Clone() const {
  323. MS_EXCEPTION_IF_NULL(start_);
  324. MS_EXCEPTION_IF_NULL(stop_);
  325. MS_EXCEPTION_IF_NULL(step_);
  326. AbstractBasePtr start = start_->Clone();
  327. AbstractBasePtr stop = stop_->Clone();
  328. AbstractBasePtr step = step_->Clone();
  329. return std::make_shared<AbstractSlice>(start, stop, step);
  330. }
  331. AbstractBasePtr AbstractSlice::Broaden() const {
  332. MS_EXCEPTION_IF_NULL(start_);
  333. MS_EXCEPTION_IF_NULL(stop_);
  334. MS_EXCEPTION_IF_NULL(step_);
  335. AbstractBasePtr start = start_->Broaden();
  336. AbstractBasePtr stop = stop_->Broaden();
  337. AbstractBasePtr step = step_->Broaden();
  338. return std::make_shared<AbstractSlice>(start, stop, step);
  339. }
  340. std::string AbstractSlice::ToString() const {
  341. std::ostringstream buffer;
  342. buffer << type_name() << "[";
  343. MS_EXCEPTION_IF_NULL(start_);
  344. buffer << start_->ToString() << " : ";
  345. MS_EXCEPTION_IF_NULL(stop_);
  346. buffer << stop_->ToString() << " : ";
  347. MS_EXCEPTION_IF_NULL(step_);
  348. buffer << step_->ToString();
  349. buffer << "]";
  350. return buffer.str();
  351. }
  352. ValuePtr AbstractSlice::RealBuildValue() const {
  353. MS_EXCEPTION_IF_NULL(start_);
  354. MS_EXCEPTION_IF_NULL(stop_);
  355. MS_EXCEPTION_IF_NULL(step_);
  356. ValuePtr start = start_->BuildValue();
  357. ValuePtr stop = stop_->BuildValue();
  358. ValuePtr step = step_->BuildValue();
  359. if (start->isa<AnyValue>() || stop->isa<AnyValue>() || step->isa<AnyValue>()) {
  360. return kAnyValue;
  361. }
  362. return std::make_shared<ValueSlice>(start, stop, step);
  363. }
  364. std::size_t AbstractSlice::hash() const {
  365. MS_EXCEPTION_IF_NULL(start_);
  366. MS_EXCEPTION_IF_NULL(stop_);
  367. MS_EXCEPTION_IF_NULL(step_);
  368. return hash_combine({tid(), start_->hash(), stop_->hash(), step_->hash()});
  369. }
  370. TypePtr AbstractTensor::BuildType() const {
  371. MS_EXCEPTION_IF_NULL(element_);
  372. TypePtr element_type = element_->BuildType();
  373. return std::make_shared<TensorType>(element_type);
  374. }
  375. BaseShapePtr AbstractTensor::BuildShape() const {
  376. auto shape = GetShapeTrack();
  377. // Guard from using set_shape(nullptr)
  378. if (shape == nullptr) {
  379. return kNoShape;
  380. }
  381. return shape;
  382. }
  383. AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) {
  384. auto other_tensor = dyn_cast<AbstractTensor>(other);
  385. if (other_tensor == nullptr) {
  386. MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString();
  387. }
  388. auto element = element_->Join(other_tensor->element_);
  389. auto shape = ShapeJoin(this->shape(), other_tensor->shape());
  390. return std::make_shared<AbstractTensor>(element, shape);
  391. }
  392. bool AbstractTensor::operator==(const AbstractTensor &other) const {
  393. if (&other == this) {
  394. return true;
  395. }
  396. auto v1 = GetValueTrack();
  397. auto v2 = other.GetValueTrack();
  398. if (v1 == nullptr || v2 == nullptr) {
  399. MS_LOG(EXCEPTION) << "The value of AbstractTensor is nullptr";
  400. }
  401. bool is_value_equal = (v1 == v2);
  402. if (v1->isa<AnyValue>() && v2->isa<AnyValue>()) {
  403. is_value_equal = true;
  404. }
  405. return (*element_ == *other.element_) && (*shape() == *other.shape()) && is_value_equal;
  406. }
  407. bool AbstractTensor::operator==(const AbstractBase &other) const {
  408. if (&other == this) {
  409. return true;
  410. }
  411. if (other.isa<AbstractTensor>()) {
  412. auto other_tensor = static_cast<const AbstractTensor *>(&other);
  413. return *this == *other_tensor;
  414. } else {
  415. return false;
  416. }
  417. }
  418. AbstractBasePtr AbstractTensor::Clone() const {
  419. MS_EXCEPTION_IF_NULL(element_);
  420. auto clone = std::make_shared<AbstractTensor>(element_->Clone());
  421. ShapePtr shp = shape();
  422. clone->set_shape(shp->Clone());
  423. clone->set_value(GetValueTrack());
  424. return clone;
  425. }
  426. AbstractBasePtr AbstractTensor::Broaden() const {
  427. MS_EXCEPTION_IF_NULL(element_);
  428. auto broaden = std::make_shared<AbstractTensor>(element_->Broaden());
  429. auto shp = shape();
  430. broaden->set_shape(shp->Clone());
  431. broaden->set_value(kAnyValue);
  432. return broaden;
  433. }
  434. AbstractBasePtr AbstractTensor::BroadenWithShape() const {
  435. MS_EXCEPTION_IF_NULL(element_);
  436. auto broaden = std::make_shared<AbstractTensor>(element_->Broaden());
  437. auto shp = shape()->Clone();
  438. shp->Broaden();
  439. broaden->set_shape(shp);
  440. broaden->set_value(kAnyValue);
  441. return broaden;
  442. }
  443. ShapePtr AbstractTensor::shape() const {
  444. auto shp = dyn_cast<Shape>(GetShapeTrack());
  445. if (shp == nullptr) {
  446. MS_LOG(EXCEPTION) << "Tensor should have a shape.";
  447. }
  448. return shp;
  449. }
  450. std::string AbstractTensor::ToString() const {
  451. std::ostringstream buffer;
  452. BaseShapePtr shape_track = GetShapeTrack();
  453. MS_EXCEPTION_IF_NULL(shape_track);
  454. MS_EXCEPTION_IF_NULL(element_);
  455. auto value_track = GetValueTrack();
  456. MS_EXCEPTION_IF_NULL(value_track);
  457. buffer << type_name() << "("
  458. << "shape: " << shape_track->ToString() << ", element: " << element_->ToString()
  459. << ", value_ptr: " << value_track << ", value: " << value_track->ToString() << ")";
  460. return buffer.str();
  461. }
  462. TypePtr AbstractDictionary::BuildType() const {
  463. std::vector<std::pair<std::string, TypePtr>> key_values;
  464. for (const auto &item : key_values_) {
  465. MS_EXCEPTION_IF_NULL(item.second);
  466. TypePtr type = item.second->BuildType();
  467. key_values.emplace_back(item.first, type);
  468. }
  469. return std::make_shared<Dictionary>(key_values);
  470. }
  471. bool AbstractDictionary::operator==(const AbstractDictionary &other) const {
  472. if (key_values_.size() != other.key_values_.size()) {
  473. return false;
  474. }
  475. for (size_t index = 0; index < key_values_.size(); index++) {
  476. if (key_values_[index].first != other.key_values_[index].first) {
  477. return false;
  478. }
  479. if (!(*key_values_[index].second == *other.key_values_[index].second)) {
  480. return false;
  481. }
  482. }
  483. return true;
  484. }
  485. bool AbstractDictionary::operator==(const AbstractBase &other) const {
  486. if (&other == this) {
  487. return true;
  488. }
  489. if (other.isa<AbstractDictionary>()) {
  490. auto other_class = static_cast<const AbstractDictionary *>(&other);
  491. return *this == *other_class;
  492. }
  493. return false;
  494. }
  495. AbstractBasePtr AbstractDictionary::Clone() const {
  496. std::vector<AbstractAttribute> kv;
  497. (void)std::transform(key_values_.begin(), key_values_.end(), std::back_inserter(kv),
  498. [](const AbstractAttribute &item) {
  499. MS_EXCEPTION_IF_NULL(item.second);
  500. return std::make_pair(item.first, item.second->Clone());
  501. });
  502. return std::make_shared<AbstractDictionary>(kv);
  503. }
  504. AbstractBasePtr AbstractDictionary::Broaden() const {
  505. std::vector<AbstractAttribute> kv;
  506. (void)std::transform(key_values_.begin(), key_values_.end(), std::back_inserter(kv),
  507. [](const AbstractAttribute &item) {
  508. MS_EXCEPTION_IF_NULL(item.second);
  509. return std::make_pair(item.first, item.second->Broaden());
  510. });
  511. return std::make_shared<AbstractDictionary>(kv);
  512. }
  513. std::string AbstractDictionary::ToString() const {
  514. std::ostringstream buffer;
  515. buffer << type_name() << "{ ";
  516. for (const auto &kv : key_values_) {
  517. MS_EXCEPTION_IF_NULL(kv.second);
  518. buffer << "(" << kv.first << ": " << kv.second->ToString() << ") ";
  519. }
  520. buffer << "}";
  521. return buffer.str();
  522. }
  523. std::size_t AbstractDictionary::hash() const {
  524. std::size_t hash_sum = std::accumulate(key_values_.begin(), key_values_.end(), tid(),
  525. [](std::size_t hash_sum, const AbstractAttribute &item) {
  526. hash_sum = hash_combine(hash_sum, std::hash<std::string>()(item.first));
  527. MS_EXCEPTION_IF_NULL(item.second);
  528. hash_sum = hash_combine(hash_sum, item.second->hash());
  529. return hash_sum;
  530. });
  531. return hash_sum;
  532. }
  533. ValuePtr AbstractDictionary::RealBuildValue() const {
  534. std::vector<std::pair<std::string, ValuePtr>> key_values;
  535. for (const auto &item : key_values_) {
  536. MS_EXCEPTION_IF_NULL(item.second);
  537. auto element_value = item.second->BuildValue();
  538. MS_EXCEPTION_IF_NULL(element_value);
  539. if (element_value->isa<AnyValue>()) {
  540. return kAnyValue;
  541. }
  542. key_values.emplace_back(item.first, element_value);
  543. }
  544. return std::make_shared<ValueDictionary>(key_values);
  545. }
  546. TypePtr AbstractClass::BuildType() const {
  547. ClassAttrVector attributes_type;
  548. for (auto attr : attributes_) {
  549. MS_EXCEPTION_IF_NULL(attr.second);
  550. TypePtr type = attr.second->BuildType();
  551. std::pair<std::string, TypePtr> elem(attr.first, type);
  552. attributes_type.push_back(elem);
  553. }
  554. return std::make_shared<Class>(tag_, attributes_type, methods_);
  555. }
  556. bool AbstractClass::operator==(const AbstractClass &other) const {
  557. if (!(tag_ == other.tag_)) {
  558. return false;
  559. }
  560. if (attributes_.size() != other.attributes_.size()) {
  561. return false;
  562. }
  563. for (size_t i = 0; i < attributes_.size(); i++) {
  564. MS_EXCEPTION_IF_NULL(attributes_[i].second);
  565. MS_EXCEPTION_IF_NULL(other.attributes_[i].second);
  566. if (!(*attributes_[i].second == *other.attributes_[i].second)) {
  567. MS_LOG(DEBUG) << "attr " << attributes_[i].first << " not equal, arg1:" << attributes_[i].second->ToString()
  568. << " arg2:" << other.attributes_[i].second->ToString();
  569. return false;
  570. }
  571. }
  572. // method compare;
  573. if (methods_.size() != other.methods_.size()) {
  574. return false;
  575. }
  576. for (const auto &iter : methods_) {
  577. auto iter_other = other.methods_.find(iter.first);
  578. if (iter_other == other.methods_.end()) {
  579. return false;
  580. }
  581. if (!(*iter.second == *iter_other->second)) {
  582. return false;
  583. }
  584. }
  585. return true;
  586. }
  587. bool AbstractClass::operator==(const AbstractBase &other) const {
  588. if (other.isa<AbstractClass>()) {
  589. auto other_class = static_cast<const AbstractClass *>(&other);
  590. return *this == *other_class;
  591. }
  592. return false;
  593. }
  594. AbstractBasePtr AbstractClass::GetAttribute(const std::string &name) {
  595. auto it = std::find_if(attributes_.begin(), attributes_.end(),
  596. [name](const AbstractAttribute &pair) -> bool { return pair.first == name; });
  597. if (it != attributes_.end()) {
  598. return it->second;
  599. }
  600. return nullptr;
  601. }
  602. ValuePtr AbstractClass::GetMethod(const std::string &name) {
  603. auto method_pair = methods_.find(name);
  604. if (method_pair != methods_.end()) {
  605. return method_pair->second;
  606. }
  607. return kAnyValue;
  608. }
  609. AbstractBasePtr AbstractClass::Clone() const {
  610. std::vector<AbstractAttribute> attributes_clone;
  611. for (auto attr : attributes_) {
  612. MS_EXCEPTION_IF_NULL(attr.second);
  613. AbstractBasePtr clone = attr.second->Clone();
  614. AbstractAttribute elem(attr.first, clone);
  615. attributes_clone.push_back(elem);
  616. }
  617. return std::make_shared<AbstractClass>(tag_, attributes_clone, methods_);
  618. }
  619. AbstractBasePtr AbstractClass::Broaden() const {
  620. std::vector<AbstractAttribute> attributes_clone;
  621. for (auto attr : attributes_) {
  622. MS_EXCEPTION_IF_NULL(attr.second);
  623. AbstractBasePtr clone = attr.second->Broaden();
  624. AbstractAttribute elem(attr.first, clone);
  625. attributes_clone.push_back(elem);
  626. }
  627. return std::make_shared<AbstractClass>(tag_, attributes_clone, methods_);
  628. }
  629. std::string AbstractClass::ToString() const {
  630. std::ostringstream buffer;
  631. buffer << type_name() << "(tag: " << tag_ << ") attrs:(";
  632. bool append_comma = false;
  633. for (const auto &attr : attributes_) {
  634. if (append_comma) {
  635. buffer << ", ";
  636. } else {
  637. append_comma = true;
  638. }
  639. MS_EXCEPTION_IF_NULL(attr.second);
  640. buffer << attr.first << ":" << attr.second->ToString();
  641. }
  642. buffer << ") method:(";
  643. append_comma = false;
  644. for (const auto &iter : methods_) {
  645. if (append_comma) {
  646. buffer << ", ";
  647. } else {
  648. append_comma = true;
  649. }
  650. MS_EXCEPTION_IF_NULL(iter.second);
  651. buffer << iter.first << ":" << iter.second->ToString();
  652. }
  653. buffer << ")";
  654. return buffer.str();
  655. }
  656. std::size_t AbstractClass::hash() const {
  657. std::size_t hash_sum = std::accumulate(attributes_.begin(), attributes_.end(), hash_combine(tid(), tag_.hash()),
  658. [](std::size_t hash_sum, const AbstractAttribute &item) {
  659. MS_EXCEPTION_IF_NULL(item.second);
  660. return hash_combine(hash_sum, item.second->hash());
  661. });
  662. return hash_sum;
  663. }
  664. ValuePtr AbstractClass::RealBuildValue() const {
  665. auto cls = BuildType()->cast<ClassPtr>();
  666. std::unordered_map<std::string, ValuePtr> attributes_value_map;
  667. for (const auto &attr : attributes_) {
  668. MS_EXCEPTION_IF_NULL(attr.second);
  669. ValuePtr _value = attr.second->BuildValue();
  670. if (_value->isa<AnyValue>()) {
  671. return kAnyValue;
  672. }
  673. attributes_value_map[attr.first] = _value;
  674. }
  675. cls->set_value(attributes_value_map);
  676. return cls;
  677. }
  678. TypePtr AbstractJTagged::BuildType() const {
  679. MS_EXCEPTION_IF_NULL(element_);
  680. TypePtr subtype = element_->BuildType();
  681. return std::make_shared<JTagged>(subtype);
  682. }
  683. AbstractBasePtr AbstractJTagged::Join(const AbstractBasePtr &other) {
  684. auto other_jtagged = dyn_cast<AbstractJTagged>(other);
  685. if (other_jtagged == nullptr) {
  686. MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString();
  687. }
  688. auto joined_elem = element_->Join(other_jtagged->element_);
  689. return std::make_shared<AbstractJTagged>(joined_elem);
  690. }
  691. bool AbstractJTagged::operator==(const AbstractJTagged &other) const {
  692. MS_EXCEPTION_IF_NULL(element_);
  693. MS_EXCEPTION_IF_NULL(other.element_);
  694. return (*element_ == *other.element_);
  695. }
  696. bool AbstractJTagged::operator==(const AbstractBase &other) const {
  697. if (other.isa<AbstractJTagged>()) {
  698. auto other_jtagged = static_cast<const AbstractJTagged *>(&other);
  699. return *this == *other_jtagged;
  700. }
  701. return false;
  702. }
  703. std::string AbstractJTagged::ToString() const {
  704. std::ostringstream buffer;
  705. MS_EXCEPTION_IF_NULL(element_);
  706. buffer << type_name() << "("
  707. << "element: " << element_->ToString() << ")";
  708. return buffer.str();
  709. }
  710. TypePtr AbstractRef::BuildType() const {
  711. TypePtr subtype = ref_->BuildType();
  712. TypePtr subtype_origin = ref_origin_->BuildType();
  713. return std::make_shared<RefType>(subtype, subtype_origin);
  714. }
  715. bool AbstractRef::operator==(const AbstractRef &other) const {
  716. return (*ref_ == *other.ref_) && (*ref_key_ == *other.ref_key_) && (*ref_origin_ == *other.ref_origin_);
  717. }
  718. bool AbstractRef::operator==(const AbstractBase &other) const {
  719. if (other.isa<AbstractRef>()) {
  720. auto other_conf = static_cast<const AbstractRef *>(&other);
  721. return *this == *other_conf;
  722. }
  723. return false;
  724. }
  725. std::string AbstractRef::ToString() const {
  726. std::ostringstream buffer;
  727. buffer << type_name() << "("
  728. << "key: " << ref_key_->ToString() << "ref_value: " << ref_->ToString()
  729. << "origin_value: " << ref_origin_->ToString();
  730. auto value = GetValueTrack();
  731. if (value) {
  732. buffer << ", value: " << value->ToString();
  733. }
  734. buffer << ")";
  735. return buffer.str();
  736. }
  737. bool AbstractNone::operator==(const AbstractNone &) const { return true; }
  738. bool AbstractNone::operator==(const AbstractBase &other) const {
  739. if (other.isa<AbstractNone>()) {
  740. auto other_none = static_cast<const AbstractNone *>(&other);
  741. return *this == *other_none;
  742. }
  743. return false;
  744. }
  745. std::string AbstractNone::ToString() const {
  746. std::ostringstream buffer;
  747. buffer << type_name() << "(Value: None)";
  748. return buffer.str();
  749. }
  750. ValuePtr AbstractNone::RealBuildValue() const { return kNone; }
  751. bool AbstractRefKey::operator==(const AbstractRefKey &other) const {
  752. ValuePtr value_self = GetValueTrack();
  753. ValuePtr value_other = other.GetValueTrack();
  754. if (value_self != nullptr && value_other != nullptr) {
  755. if (value_self->isa<AnyValue>() && value_other->isa<AnyValue>()) {
  756. return true;
  757. }
  758. if (!value_self->isa<RefKey>() || !value_other->isa<RefKey>()) {
  759. return false;
  760. }
  761. RefKeyPtr type_self = value_self->cast<RefKeyPtr>();
  762. RefKeyPtr type_other = value_other->cast<RefKeyPtr>();
  763. return *type_self == *type_other;
  764. } else if (value_self != nullptr || value_other != nullptr) {
  765. return false;
  766. }
  767. return true;
  768. }
  769. bool AbstractRefKey::operator==(const AbstractBase &other) const {
  770. if (other.isa<AbstractRefKey>()) {
  771. auto other_confkey = static_cast<const AbstractRefKey *>(&other);
  772. return *this == *other_confkey;
  773. } else {
  774. return false;
  775. }
  776. }
  777. std::string AbstractRefKey::ToString() const {
  778. std::ostringstream buffer;
  779. buffer << type_name();
  780. auto value = GetValueTrack();
  781. if (value) {
  782. buffer << "(value: " << value->ToString() << ")";
  783. }
  784. return buffer.str();
  785. }
  786. bool AbstractNull::operator==(const AbstractNull &) const { return true; }
  787. bool AbstractNull::operator==(const AbstractBase &other) const {
  788. if (&other == this) {
  789. return true;
  790. }
  791. if (other.isa<AbstractNull>()) {
  792. auto other_none = static_cast<const AbstractNull *>(&other);
  793. return *this == *other_none;
  794. } else {
  795. return false;
  796. }
  797. }
  798. std::string AbstractNull::ToString() const {
  799. std::ostringstream buffer;
  800. buffer << type_name() << "(Value: Null)";
  801. return buffer.str();
  802. }
  803. bool AbstractEllipsis::operator==(const AbstractEllipsis &) const { return true; }
  804. bool AbstractEllipsis::operator==(const AbstractBase &other) const {
  805. if (&other == this) {
  806. return true;
  807. }
  808. if (other.isa<AbstractEllipsis>()) {
  809. auto other_none = static_cast<const AbstractEllipsis *>(&other);
  810. return *this == *other_none;
  811. } else {
  812. return false;
  813. }
  814. }
  815. std::string AbstractEllipsis::ToString() const {
  816. std::ostringstream buffer;
  817. buffer << type_name() << "(Value: Ellipsis)";
  818. return buffer.str();
  819. }
  820. TypePtr AbstractKeywordArg::BuildType() const {
  821. MS_EXCEPTION_IF_NULL(arg_value_);
  822. TypePtr type = arg_value_->BuildType();
  823. return std::make_shared<Keyword>(arg_name_, type);
  824. }
  825. AbstractBasePtr AbstractKeywordArg::Clone() const {
  826. MS_EXCEPTION_IF_NULL(arg_value_);
  827. return std::make_shared<AbstractKeywordArg>(arg_name_, arg_value_->Clone());
  828. }
  829. AbstractBasePtr AbstractKeywordArg::Broaden() const {
  830. MS_EXCEPTION_IF_NULL(arg_value_);
  831. return std::make_shared<AbstractKeywordArg>(arg_name_, arg_value_->Broaden());
  832. }
  833. std::size_t AbstractKeywordArg::hash() const {
  834. MS_EXCEPTION_IF_NULL(arg_value_);
  835. return hash_combine({tid(), std::hash<std::string>{}(arg_name_), arg_value_->hash()});
  836. }
  837. std::string AbstractKeywordArg::ToString() const {
  838. std::ostringstream buffer;
  839. MS_EXCEPTION_IF_NULL(arg_value_);
  840. buffer << type_name() << "(";
  841. buffer << "key : " << arg_name_;
  842. buffer << "value : " << arg_value_->ToString();
  843. buffer << ")";
  844. return buffer.str();
  845. }
  846. bool AbstractKeywordArg::operator==(const AbstractBase &other) const {
  847. if (&other == this) {
  848. return true;
  849. }
  850. if (other.isa<AbstractKeywordArg>()) {
  851. auto other_tuple = static_cast<const AbstractKeywordArg *>(&other);
  852. return *this == *other_tuple;
  853. }
  854. return false;
  855. }
  856. bool AbstractKeywordArg::operator==(const AbstractKeywordArg &other) const {
  857. if (&other == this) {
  858. return true;
  859. }
  860. MS_EXCEPTION_IF_NULL(arg_value_);
  861. MS_EXCEPTION_IF_NULL(other.arg_value_);
  862. return other.arg_name_ == arg_name_ && *other.arg_value_ == *arg_value_;
  863. }
  864. ValuePtr AbstractKeywordArg::RealBuildValue() const {
  865. MS_EXCEPTION_IF_NULL(arg_value_);
  866. ValuePtr value = arg_value_->BuildValue();
  867. MS_EXCEPTION_IF_NULL(value);
  868. if (value->isa<AnyValue>()) {
  869. return kAnyValue;
  870. }
  871. return std::make_shared<KeywordArg>(arg_name_, value);
  872. }
  873. std::size_t AbstractBasePtrListHash(const AbstractBasePtrList &args_spec_list) {
  874. std::size_t hash_value = 0;
  875. // Hashing all elements is costly, so only take at most 4 elements into account based on
  876. // some experiments.
  877. for (size_t i = 0; (i < args_spec_list.size()) && (i < 4); i++) {
  878. MS_EXCEPTION_IF_NULL(args_spec_list[i]);
  879. hash_value = hash_combine(hash_value, args_spec_list[i]->hash());
  880. }
  881. return hash_value;
  882. }
  883. bool AbstractBasePtrListDeepEqual(const AbstractBasePtrList &lhs, const AbstractBasePtrList &rhs) {
  884. if (lhs.size() != rhs.size()) {
  885. return false;
  886. }
  887. std::size_t size = lhs.size();
  888. for (std::size_t i = 0; i < size; i++) {
  889. MS_EXCEPTION_IF_NULL(lhs[i]);
  890. MS_EXCEPTION_IF_NULL(rhs[i]);
  891. if (!(*lhs[i] == *rhs[i])) {
  892. return false;
  893. }
  894. }
  895. return true;
  896. }
  897. std::size_t AbstractBasePtrListHasher::operator()(const AbstractBasePtrList &args_spec_list) const {
  898. return AbstractBasePtrListHash(args_spec_list);
  899. }
  900. bool AbstractBasePtrListEqual::operator()(const AbstractBasePtrList &lhs, const AbstractBasePtrList &rhs) const {
  901. return AbstractBasePtrListDeepEqual(lhs, rhs);
  902. }
  903. } // namespace abstract
  904. } // namespace mindspore