You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

op_adapter_util.cc 10 kB

6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "transform/op_adapter_util.h"
  17. #include <string>
  18. #include <vector>
  19. #include <algorithm>
  20. #include "utils/utils.h"
  21. #include "transform/op_adapter_base.h"
  22. namespace mindspore {
  23. namespace transform {
  24. GeTensor ConvertAnyUtil(const ValuePtr &value, const AnyTraits<mindspore::tensor::Tensor> &) {
  25. // To-DO the format may read from ME tensor
  26. MS_EXCEPTION_IF_NULL(value);
  27. auto me_tensor = value->cast<MeTensorPtr>();
  28. auto ge_tensor = TransformUtil::ConvertTensor(me_tensor, kOpFormat_NCHW);
  29. return ge_tensor == nullptr ? GeTensor() : *ge_tensor;
  30. }
  31. std::vector<int64_t> ConvertAnyUtil(const ValuePtr &value, const std::string &name,
  32. const AnyTraits<std::vector<int64_t>>) {
  33. MS_EXCEPTION_IF_NULL(value);
  34. std::vector<int64_t> list;
  35. if (name == "pad") {
  36. if (!value->isa<ValueSequeue>()) {
  37. MS_LOG(EXCEPTION) << "Value should be ValueTuple, but got" << value->type_name();
  38. }
  39. auto vec = value->cast<ValueSequeuePtr>();
  40. list.resize(vec->value().size() + 2);
  41. list[0] = 1;
  42. list[1] = 1;
  43. (void)std::transform(vec->value().begin(), vec->value().end(), list.begin() + 2,
  44. [](const ValuePtr &val) { return static_cast<int64_t>(GetValue<int>(val)); });
  45. } else {
  46. int64_t data = GetValue<int>(value);
  47. int size = 2; // 2 int in list
  48. list = TransformUtil::ConvertIntToList(data, size);
  49. }
  50. return list;
  51. }
  52. std::string ConvertAnyUtil(const ValuePtr &value, const AnyTraits<std::vector<int64_t>>, const AnyTraits<std::string>) {
  53. MS_EXCEPTION_IF_NULL(value);
  54. auto vec = value->cast<ValueTuplePtr>();
  55. if (nullptr == vec) {
  56. MS_LOG(EXCEPTION) << "not ValueTuplePtr";
  57. }
  58. std::ostringstream buffer;
  59. int i = 0;
  60. for (auto &it : vec->value()) {
  61. if (i != 0) {
  62. buffer << ",";
  63. }
  64. buffer << GetValue<int>(it);
  65. i++;
  66. }
  67. return buffer.str();
  68. }
  69. std::vector<float> ConvertAnyUtil(const ValuePtr &value, const AnyTraits<std::vector<float>>, const AnyTraits<float>) {
  70. MS_EXCEPTION_IF_NULL(value);
  71. auto vec = value->cast<ValueTuplePtr>();
  72. if (nullptr == vec) {
  73. MS_LOG(EXCEPTION) << "not ValueTuplePtr";
  74. }
  75. std::vector<float> list;
  76. list.resize(vec->value().size());
  77. (void)std::transform(vec->value().begin(), vec->value().end(), list.begin(),
  78. [](const ValuePtr &val) { return static_cast<float>(GetValue<float>(val)); });
  79. return list;
  80. }
  81. std::vector<int64_t> ConvertAnyUtil(const ValuePtr &value, const std::string &format,
  82. const AnyTraits<std::vector<int64_t>>, const AnyTraits<int64_t>) {
  83. MS_EXCEPTION_IF_NULL(value);
  84. auto vec = value->cast<ValueTuplePtr>();
  85. if (nullptr == vec) {
  86. MS_LOG(EXCEPTION) << "not ValueTuplePtr";
  87. }
  88. std::vector<int64_t> list;
  89. list.resize(vec->value().size());
  90. (void)std::transform(vec->value().begin(), vec->value().end(), list.begin(),
  91. [](const ValuePtr &val) { return static_cast<int64_t>(GetValue<int>(val)); });
  92. if (format == kOpFormat_NHWC) {
  93. if (list.size() < 4) {
  94. MS_LOG(EXCEPTION) << "The size of list is less than 4";
  95. } else {
  96. int64_t temp = list[1];
  97. list[1] = list[2];
  98. list[2] = list[3];
  99. list[3] = temp;
  100. }
  101. }
  102. return list;
  103. }
  104. GeDataType ConvertAnyUtil(const ValuePtr &value, const AnyTraits<GEType>) {
  105. MS_EXCEPTION_IF_NULL(value);
  106. if (!value->isa<Type>()) {
  107. MS_LOG(EXCEPTION) << "error convert Value to TypePtr for value: " << value->ToString()
  108. << ", type: " << value->type_name() << ", value should be a Typeptr";
  109. }
  110. auto type = value->cast<TypePtr>();
  111. MS_EXCEPTION_IF_NULL(type);
  112. TypeId me_type = type->type_id();
  113. if (kObjectTypeTensorType == me_type) {
  114. me_type = dyn_cast<TensorType>(type)->element()->type_id();
  115. }
  116. return TransformUtil::ConvertDataType(me_type);
  117. }
  118. GeTensor VectorToTensorUtil(const ValuePtr &value) {
  119. // convert tuple or list to ge tensor, only supported one dim for now
  120. MS_EXCEPTION_IF_NULL(value);
  121. auto vec = value->isa<ValueTuple>() ? value->cast<ValueTuplePtr>()->value() : value->cast<ValueListPtr>()->value();
  122. if (vec.empty()) {
  123. MS_LOG(WARNING) << "Convert a none tuple to an empty ge tensor";
  124. return GeTensor();
  125. }
  126. MS_EXCEPTION_IF_NULL(vec[0]);
  127. if (vec[0]->isa<Int32Imm>()) {
  128. MS_LOG(INFO) << "convert value to tensor with data type = Int32";
  129. auto data = ConvertAnyUtil(value, AnyTraits<int32_t>(), AnyTraits<std::vector<int32_t>>());
  130. auto desc = TransformUtil::GetGeTensorDesc({static_cast<int>(vec.size())}, kNumberTypeInt32, kOpFormat_NCHW);
  131. if (desc == nullptr) {
  132. MS_LOG(EXCEPTION) << "Update conversion descriptor failed!";
  133. }
  134. return GeTensor(*desc, reinterpret_cast<uint8_t *>(data.data()), data.size() * sizeof(int32_t));
  135. } else if (vec[0]->isa<FP32Imm>()) {
  136. MS_LOG(INFO) << "convert value to tensor with data type = Float32";
  137. auto data = ConvertAnyUtil(value, AnyTraits<float>(), AnyTraits<std::vector<float>>());
  138. auto desc = TransformUtil::GetGeTensorDesc({static_cast<int>(vec.size())}, kNumberTypeFloat32, kOpFormat_NCHW);
  139. if (desc == nullptr) {
  140. MS_LOG(EXCEPTION) << "Update conversion descriptor failed!";
  141. }
  142. return GeTensor(*desc, reinterpret_cast<uint8_t *>(data.data()), data.size() * sizeof(float));
  143. } else if (vec[0]->isa<BoolImm>()) {
  144. MS_LOG(INFO) << "convert value to tensor with data type = Bool";
  145. // We use uint8_t to save bool type data
  146. auto data = ConvertAnyUtil(value, AnyTraits<bool>(), AnyTraits<std::vector<uint8_t>>());
  147. auto desc = TransformUtil::GetGeTensorDesc({static_cast<int>(vec.size())}, kNumberTypeBool, kOpFormat_NCHW);
  148. if (desc == nullptr) {
  149. MS_LOG(EXCEPTION) << "Update conversion descriptor failed!";
  150. }
  151. return GeTensor(*desc, static_cast<uint8_t *>(data.data()), data.size() * sizeof(uint8_t));
  152. } else {
  153. MS_LOG(EXCEPTION) << "Unsupported data type of tuple or list elements: " << vec[0]->type_name();
  154. }
  155. return GeTensor();
  156. }
  157. GeTensor ConvertAnyUtil(const ValuePtr &value, const AnyTraits<AnyValue>) {
  158. MS_EXCEPTION_IF_NULL(value);
  159. if (value->isa<MeTensor>()) {
  160. // convert me tensor to ge tensor
  161. return ConvertAnyUtil(value, AnyTraits<MeTensor>());
  162. } else if (value->isa<ValueList>() || value->isa<ValueTuple>()) {
  163. return VectorToTensorUtil(value);
  164. } else if (value->isa<Int32Imm>()) {
  165. // convert scalar Int to GeTensor
  166. MS_LOG(INFO) << "convert scalar to tensor with data type = Int32";
  167. GeTensorDesc desc(GeShape(), ge::FORMAT_NCHW, ge::DT_INT32);
  168. auto v = GetValue<int32_t>(value);
  169. desc.SetRealDimCnt(0);
  170. return GeTensor(desc, reinterpret_cast<uint8_t *>(&v), sizeof(int32_t));
  171. } else if (value->isa<Int64Imm>()) {
  172. // convert scalar Int64 to GeTensor
  173. MS_LOG(INFO) << "convert scalar to tensor with data type = Int64";
  174. GeTensorDesc desc(GeShape(), ge::FORMAT_NCHW, ge::DT_INT64);
  175. auto v = GetValue<int64_t>(value);
  176. desc.SetRealDimCnt(0);
  177. return GeTensor(desc, reinterpret_cast<uint8_t *>(&v), sizeof(int64_t));
  178. } else if (value->isa<FP32Imm>()) {
  179. // convert scalar FP32 to GeTensor
  180. MS_LOG(INFO) << "convert scalar to tensor with data type = FP32";
  181. GeTensorDesc desc(GeShape(), ge::FORMAT_NCHW, ge::DT_FLOAT);
  182. auto v = GetValue<float>(value);
  183. desc.SetRealDimCnt(0);
  184. return GeTensor(desc, reinterpret_cast<uint8_t *>(&v), sizeof(float));
  185. } else if (value->isa<BoolImm>()) {
  186. // convert scalar FP32 to GeTensor
  187. MS_LOG(INFO) << "convert scalar to tensor with data type = Bool";
  188. GeTensorDesc desc(GeShape(), ge::FORMAT_NCHW, ge::DT_BOOL);
  189. auto v = GetValue<bool>(value);
  190. desc.SetRealDimCnt(0);
  191. return GeTensor(desc, reinterpret_cast<uint8_t *>(&v), sizeof(bool));
  192. } else if (value->isa<StringImm>()) {
  193. // convert String to GeTensor
  194. MS_LOG(INFO) << "convert string to tensor with data type = String";
  195. std::string v = GetValue<std::string>(value);
  196. std::vector<int64_t> ge_shape;
  197. GeShape shape(ge_shape);
  198. GeTensorDesc desc(shape, ge::FORMAT_NCHW, ge::DT_STRING);
  199. GeTensor str_tensor(desc);
  200. str_tensor.SetData(v);
  201. return str_tensor;
  202. } else {
  203. MS_LOG(WARNING) << "Unsupported value type: " << value->type_name()
  204. << " to convert to tensor. Value: " << value->ToString();
  205. }
  206. return GeTensor();
  207. }
  208. bool IsCustomPrim(const PrimitivePtr &prim) {
  209. if (prim == nullptr) {
  210. return false;
  211. }
  212. ValuePtr flag = prim->GetAttr("_custom_op_flag");
  213. if (flag == nullptr) {
  214. return false;
  215. }
  216. bool is_custom_op = GetValue<bool>(flag);
  217. if (!is_custom_op && prim->GetAttr("_custom_op_impl_config_path") != nullptr) {
  218. MS_LOG(EXCEPTION) << "The custom op flag is false, but the op information config path is not null, non-custom op "
  219. "can not assign the op information config path.";
  220. }
  221. return is_custom_op;
  222. }
  223. bool IsCustomCNode(const AnfNodePtr &anf) {
  224. if (anf == nullptr) {
  225. return false;
  226. }
  227. auto node = anf->cast<CNodePtr>();
  228. if (node == nullptr) {
  229. return false;
  230. }
  231. if (node->inputs().empty()) {
  232. MS_LOG(EXCEPTION) << "length of node inputs is empty";
  233. }
  234. MS_EXCEPTION_IF_NULL(node->inputs()[0]);
  235. if (!node->inputs()[0]->isa<ValueNode>()) {
  236. return false;
  237. }
  238. auto cus_prim = GetValueNode<PrimitivePtr>(node->inputs()[0]);
  239. if (cus_prim == nullptr) {
  240. return false;
  241. }
  242. return IsCustomPrim(cus_prim);
  243. }
  244. } // namespace transform
  245. } // namespace mindspore