You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

README.md 14 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305
  1. # MindSpore Lite 端侧风格迁移demo(Android)
  2. 本示例程序演示了如何在端侧利用MindSpore Lite API以及MindSpore Lite风格迁移模型完成端侧推理,根据demo内置的标准图片更换目标图片的艺术风格,并在App图像预览界面中显示出来。
  3. ## 运行依赖
  4. - Android Studio >= 3.2 (推荐4.0以上版本)
  5. - NDK 21.3
  6. - CMake 3.10
  7. - Android SDK >= 26
  8. ## 构建与运行
  9. 1. 在Android Studio中加载本示例源码,并安装相应的SDK(指定SDK版本后,由Android Studio自动安装)。
  10. ![start_home](images/home.png)
  11. 启动Android Studio后,点击`File->Settings->System Settings->Android SDK`,勾选相应的SDK。如下图所示,勾选后,点击`OK`,Android Studio即可自动安装SDK。
  12. ![start_sdk](images/sdk_management.png)
  13. 使用过程中若出现Android Studio配置问题,可参考第5项解决。
  14. 2. 连接Android设备,运行骨应用程序。
  15. 通过USB连接Android设备调试,点击`Run 'app'`即可在你的设备上运行本示例项目。
  16. > 编译过程中Android Studio会自动下载MindSpore Lite、模型文件等相关依赖项,编译过程需做耐心等待。
  17. ![run_app](images/run_app.PNG)
  18. Android Studio连接设备调试操作,可参考<https://developer.android.com/studio/run/device?hl=zh-cn>。
  19. 3. 在Android设备上,点击“继续安装”,安装完即可查看到推理结果。
  20. ![install](images/install.jpg)
  21. 使用风格迁移demo时,用户可先导入或拍摄自己的图片,然后选择一种预置风格,得到推理后的新图片,最后使用还原或保存新图片功能。
  22. 原始图片:
  23. ![sult](images/style_transfer_demo.png)
  24. 风格迁移后的新图片:
  25. ![sult](images/style_transfer_result.png)
  26. 4. Android Studio 配置问题解决方案可参考下表:
  27. | | 报错 | 解决方案 |
  28. | ---- | ------------------------------------------------------------ | ------------------------------------------------------------ |
  29. | 1 | Gradle sync failed: NDK not configured. | 在local.properties中指定安装的ndk目录:ndk.dir={ndk的安装目录} |
  30. | 2 | Requested NDK version did not match the version requested by ndk.dir | 可手动下载相应的[NDK版本](https://developer.android.com/ndk/downloads?hl=zh-cn),并在Project Structure - Android NDK location设置中指定SDK的位置(可参考下图完成) |
  31. | 3 | This version of Android Studio cannot open this project, please retry with Android Studio or newer. | 在工具栏-help-Checkout for Updates中更新版本 |
  32. | 4 | SSL peer shut down incorrectly | 重新构建 |
  33. ![project_structure](images/project_structure.png)
  34. ## 示例程序详细说明
  35. 风格Android示例程序通过Android Camera 2 API实现摄像头获取图像帧,以及相应的图像处理等功能,在[Runtime](https://www.mindspore.cn/tutorial/lite/zh-CN/master/use/runtime.html)中完成模型推理的过程。
  36. ### 示例程序结构
  37. ```text
  38. ├── app
  39. │   ├── build.gradle # 其他Android配置文件
  40. │   ├── download.gradle # APP构建时由gradle自动从HuaWei Server下载依赖的库文件及模型文件
  41. │   ├── proguard-rules.pro
  42. │   └── src
  43. │   ├── main
  44. │   │   ├── AndroidManifest.xml # Android配置文件
  45. │   │   ├── java # java层应用代码
  46. │   │   │   └── com
  47. │   │   │   └── mindspore
  48. │   │   │   └── posenetdemo # 图像处理及推理流程实现
  49. │   │   │   ├── CameraDataDealListener.java
  50. │   │   │   ├── ImageUtils.java
  51. │   │   │   ├── MainActivity.java
  52. │   │   │   ├── PoseNetFragment.java
  53. │   │   │   ├── Posenet.java #
  54. │   │   │   └── TestActivity.java
  55. │   │   └── res # 存放Android相关的资源文件
  56. │   └── test
  57. └── ...
  58. ```
  59. ### 下载及部署模型文件
  60. 从MindSpore Model Hub中下载模型文件,本示例程序中使用的目标检测模型文件为`posenet_model.ms`,同样通过`download.gradle`脚本在APP构建时自动下载,并放置在`app/src/main/assets`工程目录下。
  61. > 若下载失败请手动下载模型文件,style_predict_quant.ms [下载链接](https://download.mindspore.cn/model_zoo/official/lite/style_lite/style_predict_quant.ms),以及style_transfer_quant.ms [下载链接](https://download.mindspore.cn/model_zoo/official/lite/style_lite/style_transfer_quant.ms)。
  62. ### 编写端侧推理代码
  63. 在风格迁移demo中,使用Java API实现端测推理。相比于C++ API,Java API可以直接在Java Class中调用,无需实现JNI层的相关代码,具有更好的便捷性。
  64. 风格迁移demo推理代码流程如下,完整代码请参见:`src/main/java/com/mindspore/styletransferdemo/StyleTransferModelExecutor.java`。
  65. 1. 加载MindSpore Lite模型文件,构建上下文、会话以及用于推理的计算图。
  66. - 加载模型:从文件系统中读取MindSpore Lite模型,并进行模型解析。
  67. ```java
  68. // Load the .ms model.
  69. style_predict_model = new Model();
  70. if (!style_predict_model.loadModel(mContext, "style_predict_quant.ms")) {
  71. Log.e("MS_LITE", "Load style_predict_model failed");
  72. }
  73. style_transform_model = new Model();
  74. if (!style_transform_model.loadModel(mContext, "style_transfer_quant.ms")) {
  75. Log.e("MS_LITE", "Load style_transform_model failed");
  76. }
  77. ```
  78. - 创建配置上下文:创建配置上下文`MSConfig`,保存会话所需的一些基本配置参数,用于指导图编译和图执行。
  79. ```java
  80. msConfig = new MSConfig();
  81. if (!msConfig.init(DeviceType.DT_CPU, NUM_THREADS, CpuBindMode.MID_CPU)) {
  82. Log.e("MS_LITE", "Init context failed");
  83. }
  84. ```
  85. - 创建会话:创建`LiteSession`,并调用`init`方法将上一步得到`MSConfig`配置到会话中。
  86. ```java
  87. // Create the MindSpore lite session.
  88. Predict_session = new LiteSession();
  89. if (!Predict_session.init(msConfig)) {
  90. Log.e("MS_LITE", "Create Predict_session failed");
  91. msConfig.free();
  92. }
  93. Transform_session = new LiteSession();
  94. if (!Transform_session.init(msConfig)) {
  95. Log.e("MS_LITE", "Create Predict_session failed");
  96. msConfig.free();
  97. }
  98. msConfig.free();
  99. ```
  100. - 加载模型文件并构建用于推理的计算图
  101. ```java
  102. // Complile graph.
  103. if (!Predict_session.compileGraph(style_predict_model)) {
  104. Log.e("MS_LITE", "Compile style_predict graph failed");
  105. style_predict_model.freeBuffer();
  106. }
  107. if (!Transform_session.compileGraph(style_transform_model)) {
  108. Log.e("MS_LITE", "Compile style_transform graph failed");
  109. style_transform_model.freeBuffer();
  110. }
  111. // Note: when use model.freeBuffer(), the model can not be complile graph again.
  112. style_predict_model.freeBuffer();
  113. style_transform_model.freeBuffer();
  114. ```
  115. 2. 输入数据: Java目前支持`byte[]`或者`ByteBuffer`两种类型的数据,设置输入Tensor的数据。
  116. - 在输入数据之前,将float数组转换为byte数组。
  117. ```java
  118. public static byte[] floatArrayToByteArray(float[] floats) {
  119. ByteBuffer buffer = ByteBuffer.allocate(4 * floats.length);
  120. buffer.order(ByteOrder.nativeOrder());
  121. FloatBuffer floatBuffer = buffer.asFloatBuffer();
  122. floatBuffer.put(floats);
  123. return buffer.array();
  124. }
  125. ```
  126. - 通过`ByteBuffer`输入数据。`contentImage`为用户提供的图片,`styleBitmap`为预置风格图片。
  127. ```java
  128. public ModelExecutionResult execute(Bitmap contentImage, Bitmap styleBitmap) {
  129. Log.i(TAG, "running models");
  130. fullExecutionTime = SystemClock.uptimeMillis();
  131. preProcessTime = SystemClock.uptimeMillis();
  132. ByteBuffer contentArray =
  133. ImageUtils.bitmapToByteBuffer(contentImage, CONTENT_IMAGE_SIZE, CONTENT_IMAGE_SIZE, 0, 255);
  134. ByteBuffer input = ImageUtils.bitmapToByteBuffer(styleBitmap, STYLE_IMAGE_SIZE, STYLE_IMAGE_SIZE, 0, 255);
  135. ```
  136. 3. 对输入Tensor按照模型进行推理,获取输出Tensor,并进行后处理。
  137. - 使用`runGraph`对预置图片进行模型推理,并获取结果`Predict_results`。
  138. ```java
  139. List<MSTensor> Predict_inputs = Predict_session.getInputs();
  140. if (Predict_inputs.size() != 1) {
  141. return null;
  142. }
  143. MSTensor Predict_inTensor = Predict_inputs.get(0);
  144. Predict_inTensor.setData(input);
  145. preProcessTime = SystemClock.uptimeMillis() - preProcessTime;
  146. stylePredictTime = SystemClock.uptimeMillis();
  147. if (!Predict_session.runGraph()) {
  148. Log.e("MS_LITE", "Run Predict_graph failed");
  149. return null;
  150. }
  151. stylePredictTime = SystemClock.uptimeMillis() - stylePredictTime;
  152. Log.d(TAG, "Style Predict Time to run: " + stylePredictTime);
  153. // Get output tensor values.
  154. List<String> tensorNames = Predict_session.getOutputTensorNames();
  155. Map<String, MSTensor> outputs = Predict_session.getOutputMapByTensor();
  156. Set<Map.Entry<String, MSTensor>> entrys = outputs.entrySet();
  157. float[] Predict_results = null;
  158. for (String tensorName : tensorNames) {
  159. MSTensor output = outputs.get(tensorName);
  160. if (output == null) {
  161. Log.e("MS_LITE", "Can not find Predict_session output " + tensorName);
  162. return null;
  163. }
  164. int type = output.getDataType();
  165. Predict_results = output.getFloatData();
  166. }
  167. ```
  168. - 利用上一步获取的结果,再次对用户图片进行模型推理,得到风格转换的数据`transform_results`。
  169. ```java
  170. List<MSTensor> Transform_inputs = Transform_session.getInputs();
  171. // transform model have 2 input tensor, tensor0: 1*1*1*100, tensor1;1*384*384*3
  172. MSTensor Transform_inputs_inTensor0 = Transform_inputs.get(0);
  173. Transform_inputs_inTensor0.setData(floatArrayToByteArray(Predict_results));
  174. MSTensor Transform_inputs_inTensor1 = Transform_inputs.get(1);
  175. Transform_inputs_inTensor1.setData(contentArray);
  176. styleTransferTime = SystemClock.uptimeMillis();
  177. if (!Transform_session.runGraph()) {
  178. Log.e("MS_LITE", "Run Transform_graph failed");
  179. return null;
  180. }
  181. styleTransferTime = SystemClock.uptimeMillis() - styleTransferTime;
  182. Log.d(TAG, "Style apply Time to run: " + styleTransferTime);
  183. postProcessTime = SystemClock.uptimeMillis();
  184. // Get output tensor values.
  185. List<String> Transform_tensorNames = Transform_session.getOutputTensorNames();
  186. Map<String, MSTensor> Transform_outputs = Transform_session.getOutputMapByTensor();
  187. float[] transform_results = null;
  188. for (String tensorName : Transform_tensorNames) {
  189. MSTensor output1 = Transform_outputs.get(tensorName);
  190. if (output1 == null) {
  191. Log.e("MS_LITE", "Can not find Transform_session output " + tensorName);
  192. return null;
  193. }
  194. transform_results = output1.getFloatData();
  195. }
  196. ```
  197. - 对输出节点的数据进行处理,得到推理后的最终结果。
  198. ```java
  199. float[][][][] outputImage = new float[1][][][]; // 1 384 384 3
  200. for (int x = 0; x < 1; x++) {
  201. float[][][] arrayThree = new float[CONTENT_IMAGE_SIZE][][];
  202. for (int y = 0; y < CONTENT_IMAGE_SIZE; y++) {
  203. float[][] arrayTwo = new float[CONTENT_IMAGE_SIZE][];
  204. for (int z = 0; z < CONTENT_IMAGE_SIZE; z++) {
  205. float[] arrayOne = new float[3];
  206. for (int i = 0; i < 3; i++) {
  207. int n = i + z * 3 + y * CONTENT_IMAGE_SIZE * 3 + x * CONTENT_IMAGE_SIZE * CONTENT_IMAGE_SIZE * 3;
  208. arrayOne[i] = transform_results[n];
  209. }
  210. arrayTwo[z] = arrayOne;
  211. }
  212. arrayThree[y] = arrayTwo;
  213. }
  214. outputImage[x] = arrayThree;
  215. }
  216. Bitmap styledImage =
  217. ImageUtils.convertArrayToBitmap(outputImage, CONTENT_IMAGE_SIZE, CONTENT_IMAGE_SIZE);
  218. postProcessTime = SystemClock.uptimeMillis() - postProcessTime;
  219. fullExecutionTime = SystemClock.uptimeMillis() - fullExecutionTime;
  220. Log.d(TAG, "Time to run everything: $" + fullExecutionTime);
  221. return new ModelExecutionResult(styledImage,
  222. preProcessTime,
  223. stylePredictTime,
  224. styleTransferTime,
  225. postProcessTime,
  226. fullExecutionTime,
  227. formatExecutionLog());
  228. ```