You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

kernel_pack.cc 7.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <fstream>
  17. #include "mindspore/ccsrc/kernel/kernel.h"
  18. #include "kernel/kernel.h"
  19. #include "kernel/akg/akgkernelbuild.h"
  20. #include "nlohmann/json.hpp"
  21. #include "securec/include/securec.h"
  22. #include "pipeline/parse/python_adapter.h"
  23. #include "utils/log_adapter.h"
  24. #include "utils/convert_utils.h"
  25. namespace mindspore {
  26. namespace kernel {
  // Python module / function pair that CheckHash calls to compute the sha256 of a kernel binary.
  constexpr const char *kUtilsModule = "mindspore._extends.utils";
  constexpr const char *kCalSha256Func = "cal_sha256";
  29. namespace {
  30. bool CheckHash(const std::string &json_file, const std::string &bin_file, const nlohmann::json &js) {
  31. if (js.find("sha256") == js.end()) {
  32. MS_LOG(ERROR) << "No sha256 found in " << json_file;
  33. return false;
  34. }
  35. std::string sha256_str = js["sha256"];
  36. py::object ret = parse::python_adapter::CallPyFn(kUtilsModule, kCalSha256Func, bin_file);
  37. std::string sha256_cal = py::cast<std::string>(ret);
  38. if (sha256_cal.empty()) {
  39. MS_LOG(ERROR) << "Cal sha256 of " << bin_file << " failed.";
  40. return false;
  41. }
  42. if (sha256_cal != sha256_str) {
  43. MS_LOG(ERROR) << "Cal sha256 of " << bin_file << " failed.";
  44. return false;
  45. }
  46. return true;
  47. }
  48. } // namespace
  49. const std::string KernelPack::Serialize() const {
  50. std::string buffer;
  51. (void)buffer.append((const char *)json_, json_->len + sizeof(json_->len));
  52. (void)buffer.append((const char *)kernel_, kernel_->len + sizeof(kernel_->len));
  53. return buffer;
  54. }
  55. bool KernelPack::ReadFromJsonFileHelper(std::ifstream &kernelbin) {
  56. size_t binsize = LongToSize(kernelbin.seekg(0, std::ios::end).tellg());
  57. // free old data
  58. if (kernel_ != nullptr) {
  59. delete[] kernel_;
  60. kernel_ = nullptr;
  61. }
  62. void *ptr = static_cast<void *>(new (std::nothrow) uint8_t[sizeof(KernelPack) + binsize]);
  63. if (ptr != nullptr) {
  64. kernel_ = static_cast<FlexArray *>(ptr);
  65. }
  66. if (kernel_ == nullptr) {
  67. MS_LOG(ERROR) << "memory malloc failed.";
  68. kernelbin.close();
  69. return false;
  70. }
  71. if (memset_s(kernel_, sizeof(KernelPack) + binsize, 0, sizeof(KernelPack) + binsize) != EOK) {
  72. MS_LOG(ERROR) << "memset kernel_ failed.";
  73. delete[] kernel_;
  74. kernel_ = nullptr;
  75. kernelbin.close();
  76. return false;
  77. }
  78. kernel_->len = binsize;
  79. MS_LOG(INFO) << "kernel len:" << kernel_->len;
  80. (void)kernelbin.seekg(0, std::ios::beg);
  81. (void)kernelbin.read(kernel_->contents, SizeToLong(kernel_->len));
  82. return true;
  83. }
  84. bool KernelPack::ReadFromJsonFile(const std::string &json_f, const std::string &processor) {
  85. if (json_f.length() <= strlen(kJsonSuffix)) {
  86. MS_LOG(ERROR) << "please check json path.";
  87. return false;
  88. }
  89. std::ifstream kerneljson(json_f);
  90. if (!kerneljson.is_open()) {
  91. MS_LOG(DEBUG) << "read json file error, please check kernelmeta.";
  92. return false;
  93. }
  94. nlohmann::json js;
  95. kerneljson >> js;
  96. size_t binsize = LongToSize(kerneljson.seekg(0, std::ios::end).tellg());
  97. void *ptr = static_cast<void *>(new (std::nothrow) uint8_t[sizeof(KernelPack) + binsize]);
  98. if (ptr != nullptr) {
  99. json_ = static_cast<FlexArray *>(ptr);
  100. }
  101. if (json_ == nullptr) {
  102. MS_LOG(ERROR) << "memory malloc failed.";
  103. kerneljson.close();
  104. return false;
  105. }
  106. json_->len = binsize;
  107. (void)kerneljson.seekg(0, std::ios::beg);
  108. (void)kerneljson.read(json_->contents, SizeToLong(json_->len));
  109. if (processor == kProcessorCuda) {
  110. std::string bin_f = json_f.substr(0, json_f.length() - 5) + ".ptx";
  111. std::ifstream kernelbin(bin_f);
  112. if (!kernelbin.is_open()) {
  113. MS_LOG(ERROR) << "read kernel ptx file error, please check kernelmeta.";
  114. kerneljson.close();
  115. return false;
  116. }
  117. if (ReadFromJsonFileHelper(kernelbin) == false) {
  118. delete[] json_;
  119. json_ = nullptr;
  120. kerneljson.close();
  121. return false;
  122. }
  123. kerneljson.close();
  124. if (!CheckHash(json_f, bin_f, js)) {
  125. return false;
  126. }
  127. return true;
  128. }
  129. std::string binfilesuffix = js["binFileSuffix"];
  130. std::string bin_f = json_f.substr(0, json_f.length() - 5) + binfilesuffix;
  131. if (binfilesuffix.compare(".so") == 0) {
  132. // change "xx/xx.so" -> "xx/libxx.so"
  133. auto sp = bin_f.rfind('/');
  134. if (sp == std::string::npos) {
  135. MS_LOG(ERROR) << "illegal bin file path " << bin_f;
  136. kerneljson.close();
  137. return false;
  138. }
  139. bin_f = bin_f.substr(0, sp + 1) + "lib" + bin_f.substr(sp + 1, bin_f.length() - sp - 1);
  140. }
  141. std::ifstream kernelbin(bin_f, std::ios::binary);
  142. if (!kernelbin.is_open()) {
  143. MS_LOG(ERROR) << "read kernel binary file error, please check kernelmeta.";
  144. kerneljson.close();
  145. delete[] json_;
  146. json_ = nullptr;
  147. return false;
  148. }
  149. MS_LOG(INFO) << "kernelbin_name:" << bin_f;
  150. if (ReadFromJsonFileHelper(kernelbin) == false) {
  151. delete[] json_;
  152. json_ = nullptr;
  153. kerneljson.close();
  154. return false;
  155. }
  156. kerneljson.close();
  157. if (!CheckHash(json_f, bin_f, js)) {
  158. return false;
  159. }
  160. return true;
  161. }
  162. void KernelPack::ParseKernelJson(const nlohmann::json &js) {
  163. kernel_json_info_.bin_file_name = js["binFileName"];
  164. kernel_json_info_.bin_file_suffix = js["binFileSuffix"];
  165. kernel_json_info_.block_dim = js["blockDim"];
  166. kernel_json_info_.kernel_name = js["kernelName"];
  167. kernel_json_info_.magic = js["magic"];
  168. if (js.find("parameters") != js.end()) {
  169. if (!js.at("parameters").is_array()) {
  170. MS_LOG(DEBUG) << "Format error!,parameters should be array.";
  171. }
  172. std::vector<size_t> sizes = js.at("parameters");
  173. for (auto size : sizes) {
  174. MS_LOG(INFO) << "parameter " << size;
  175. kernel_json_info_.parameters.push_back(size);
  176. }
  177. }
  178. if (js.find("workspace") != js.end()) {
  179. auto workspace = js.at("workspace");
  180. std::vector<size_t> sizes = workspace.at("size");
  181. for (auto size : sizes) {
  182. MS_LOG(INFO) << "workspace_size_list " << size;
  183. kernel_json_info_.workspaces.push_back(size);
  184. }
  185. }
  186. kernel_json_info_.sha256 = js["sha256"];
  187. }
  188. bool KernelPack::LoadKernelMeta(const std::string &json_f, const std::string &processor) {
  189. if (json_f.length() <= strlen(kJsonSuffix)) {
  190. MS_LOG(ERROR) << "please check json path.";
  191. return false;
  192. }
  193. std::ifstream kernel_json(json_f);
  194. if (!kernel_json.is_open()) {
  195. MS_LOG(DEBUG) << "read json file error, please check kernelmeta.";
  196. return false;
  197. }
  198. nlohmann::json js;
  199. kernel_json >> js;
  200. ParseKernelJson(js);
  201. kernel_json.close();
  202. std::string bin_f = json_f.substr(0, json_f.length() - 5) + kernel_json_info_.bin_file_suffix;
  203. if (kernel_json_info_.bin_file_suffix == ".so") {
  204. // change "xx/xx.so" -> "xx/libxx.so"
  205. auto sp = bin_f.rfind('/');
  206. if (sp == std::string::npos) {
  207. MS_LOG(ERROR) << "illegal bin file path " << bin_f;
  208. return false;
  209. }
  210. bin_f = bin_f.substr(0, sp + 1) + "lib" + bin_f.substr(sp + 1, bin_f.length() - sp - 1);
  211. }
  212. std::ifstream kernelbin(bin_f, std::ios::binary);
  213. if (!kernelbin.is_open()) {
  214. MS_LOG(ERROR) << "read kernel binary file error, please check kernelmeta.";
  215. return false;
  216. }
  217. MS_LOG(INFO) << "kernelbin_name:" << bin_f;
  218. if (!ReadFromJsonFileHelper(kernelbin)) {
  219. return false;
  220. }
  221. return CheckHash(json_f, bin_f, js);
  222. }
  223. KernelJsonInfo KernelPack::kernel_json_info() const { return kernel_json_info_; }
  224. } // namespace kernel
  225. } // namespace mindspore