You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

RetrainImageClassifier.cs 13 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262
  1. using NumSharp;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Diagnostics;
  5. using System.IO;
  6. using System.Linq;
  7. using System.Text;
  8. using Tensorflow;
  9. using TensorFlowNET.Examples.Utility;
  10. using static Tensorflow.Python;
  11. namespace TensorFlowNET.Examples.ImageProcess
  12. {
  13. /// <summary>
  14. /// In this tutorial, we will reuse the feature extraction capabilities from powerful image classifiers trained on ImageNet
  15. /// and simply train a new classification layer on top. Transfer learning is a technique that shortcuts much of this
  16. /// by taking a piece of a model that has already been trained on a related task and reusing it in a new model.
  17. ///
  18. /// https://www.tensorflow.org/hub/tutorials/image_retraining
  19. /// </summary>
  20. public class RetrainImageClassifier : IExample
  21. {
  // ---- IExample members (their meaning is defined by the interface, which is not visible here) ----
  public int Priority { get { return 16; } }
  public bool Enabled { get; set; } = false;
  public bool ImportGraph { get; set; } = true;
  public string Name { get { return "Retrain Image Classifier"; } }

  // Root working directory; every data path below is built underneath it.
  private const string data_dir = "retrain_images";
  // Created by PrepareData().
  private string summaries_dir = Path.Join(data_dir, "retrain_logs");
  // Folder that create_image_lists() walks to discover labels and files.
  private string image_dir = Path.Join(data_dir, "flower_photos");
  // Created by PrepareData(); bottleneck paths are built under it.
  private string bottleneck_dir = Path.Join(data_dir, "bottleneck");
  // Module identifier; a sanitized form of it becomes part of each bottleneck file name.
  private string tfhub_module = "https://tfhub.dev/google/imagenet/inception_v3/feature_vector/3";

  // Fractions of each label's files assigned to the "testing" / "validation" lists.
  private float testing_percentage = 0.1f;
  private float validation_percentage = 0.1f;

  // Never assigned inside this class: Run() declares a local of the same name that hides it.
  private Tensor resized_image_tensor;
  // label name -> category ("training" / "testing" / "validation") -> file paths; filled by PrepareData().
  private Dictionary<string, Dictionary<string, string[]>> image_lists;
  /// <summary>
  /// Prepares the data, imports the InceptionV3 meta graph and runs
  /// <c>cache_bottlenecks</c> over every image listed by <c>PrepareData</c>.
  /// </summary>
  /// <returns>Always <c>false</c>.</returns>
  public bool Run()
  {
      PrepareData();

      var graph = tf.Graph().as_default();
      tf.train.import_meta_graph("graph/InceptionV3.meta");

      Tensor bottleneck_tensor = graph.OperationByName("module_apply_default/hub_output/feature_vector/SpatialSqueeze");
      // This local hides the same-named field, which this method does not assign.
      Tensor resized_image_tensor = graph.OperationByName("Placeholder");

      with(tf.Session(graph), sess =>
      {
          // Initialize all weights: for the module to their pretrained values,
          // and for the newly added retraining layer to random initial values.
          var init = tf.global_variables_initializer();
          sess.run(init);

          var (jpeg_data_tensor, decoded_image_tensor) = add_jpeg_decoding();

          // We'll make sure we've calculated the 'bottleneck' image summaries and
          // cached them on disk.
          cache_bottlenecks(sess, image_lists, image_dir,
              bottleneck_dir, jpeg_data_tensor,
              decoded_image_tensor, resized_image_tensor,
              bottleneck_tensor, tfhub_module);
      });

      return false;
  }
  /// <summary>
  /// Ensures all the training, testing, and validation bottlenecks are cached.
  /// Walks every label and every category of <paramref name="image_lists"/> and calls
  /// <c>get_or_create_bottleneck</c> once per listed file, reporting progress every 100 files.
  /// </summary>
  /// <param name="sess">Session forwarded to <c>get_or_create_bottleneck</c>.</param>
  /// <param name="image_lists">label name -> category -> file paths.</param>
  /// <param name="image_dir">Root folder of the source images.</param>
  /// <param name="bottleneck_dir">Root folder under which bottleneck paths are built.</param>
  /// <param name="jpeg_data_tensor">Forwarded unchanged to <c>get_or_create_bottleneck</c>.</param>
  /// <param name="decoded_image_tensor">Forwarded unchanged to <c>get_or_create_bottleneck</c>.</param>
  /// <param name="resized_input_tensor">Forwarded unchanged to <c>get_or_create_bottleneck</c>.</param>
  /// <param name="bottleneck_tensor">Forwarded unchanged to <c>get_or_create_bottleneck</c>.</param>
  /// <param name="module_name">Module identifier that ends up in the bottleneck file name.</param>
  private void cache_bottlenecks(Session sess, Dictionary<string, Dictionary<string, string[]>> image_lists,
      string image_dir, string bottleneck_dir, Tensor jpeg_data_tensor, Tensor decoded_image_tensor,
      Tensor resized_input_tensor, Tensor bottleneck_tensor, string module_name)
  {
      var categories = new string[] { "training", "testing", "validation" };
      int how_many_bottlenecks = 0;

      foreach (var (label_name, label_lists) in image_lists)
      {
          foreach (var category in categories)
          {
              var category_list = label_lists[category];
              // Only the position is needed; the callee resolves the file name from image_lists itself.
              for (int index = 0; index < category_list.Length; index++)
              {
                  get_or_create_bottleneck(sess, image_lists, label_name, index, image_dir, category,
                      bottleneck_dir, jpeg_data_tensor, decoded_image_tensor,
                      resized_input_tensor, bottleneck_tensor, module_name);

                  // Progress report, so a long caching pass is visibly alive.
                  how_many_bottlenecks++;
                  if (how_many_bottlenecks % 100 == 0)
                  {
                      print($"{how_many_bottlenecks} bottleneck files processed.");
                  }
              }
          }
      }
  }
  /// <summary>
  /// Creates the label's sub-folder under <paramref name="bottleneck_dir"/> and, when the
  /// bottleneck file for the given image is not on disk yet, calls <c>create_bottleneck_file</c>.
  /// </summary>
  private void get_or_create_bottleneck(Session sess, Dictionary<string, Dictionary<string, string[]>> image_lists,
      string label_name, int index, string image_dir, string category, string bottleneck_dir,
      Tensor jpeg_data_tensor, Tensor decoded_image_tensor, Tensor resized_input_tensor,
      Tensor bottleneck_tensor, string module_name)
  {
      // The result is not needed, but the lookup is kept: an unknown label
      // fails right here, before any directory is created.
      _ = image_lists[label_name];

      Directory.CreateDirectory(Path.Join(bottleneck_dir, label_name));

      var target_path = get_bottleneck_path(image_lists, label_name, index,
          bottleneck_dir, category, module_name);
      if (File.Exists(target_path))
      {
          return;
      }

      create_bottleneck_file(target_path, image_lists, label_name, index,
          image_dir, category, sess, jpeg_data_tensor,
          decoded_image_tensor, resized_input_tensor,
          bottleneck_tensor);
  }
  /// <summary>
  /// Create a single bottleneck file: reads the image bytes, runs them through
  /// <c>run_bottleneck_on_image</c> and writes the resulting values to
  /// <paramref name="bottleneck_path"/> as one comma-separated line.
  /// </summary>
  /// <param name="bottleneck_path">Destination file for the bottleneck values.</param>
  /// <param name="image_lists">label name -> category -> file paths.</param>
  /// <param name="label_name">Label whose image is processed.</param>
  /// <param name="index">Position of the image inside the category list.</param>
  /// <param name="image_dir">Root folder of the source images.</param>
  /// <param name="category">Key of the list to pull the image from.</param>
  /// <param name="sess">Session forwarded to <c>run_bottleneck_on_image</c>.</param>
  /// <param name="jpeg_data_tensor">Forwarded to <c>run_bottleneck_on_image</c>.</param>
  /// <param name="decoded_image_tensor">Forwarded to <c>run_bottleneck_on_image</c>.</param>
  /// <param name="resized_input_tensor">Forwarded to <c>run_bottleneck_on_image</c>.</param>
  /// <param name="bottleneck_tensor">Forwarded to <c>run_bottleneck_on_image</c>.</param>
  private void create_bottleneck_file(string bottleneck_path, Dictionary<string, Dictionary<string, string[]>> image_lists,
      string label_name, int index, string image_dir, string category, Session sess,
      Tensor jpeg_data_tensor, Tensor decoded_image_tensor, Tensor resized_input_tensor, Tensor bottleneck_tensor)
  {
      print("Creating bottleneck at " + bottleneck_path);

      var image_path = get_image_path(image_lists, label_name, index, image_dir, category);
      if (!File.Exists(image_path))
      {
          // Only reported here; the read below still fails for a missing file.
          print($"File does not exist {image_path}");
      }

      var image_data = File.ReadAllBytes(image_path);
      var bottleneck_values = run_bottleneck_on_image(
          sess, image_data, jpeg_data_tensor, decoded_image_tensor,
          resized_input_tensor, bottleneck_tensor);

      // Persist the values so the File.Exists check in get_or_create_bottleneck can
      // skip this image on the next run. Invariant culture keeps ',' unambiguous as
      // the separator regardless of the machine's decimal-separator setting.
      // NOTE(review): assumes the bottleneck tensor is float32 - confirm against the imported graph.
      var bottleneck_string = string.Join(",",
          bottleneck_values.Data<float>()
              .Select(v => v.ToString(System.Globalization.CultureInfo.InvariantCulture)));
      File.WriteAllText(bottleneck_path, bottleneck_string);
  }
  /// <summary>
  /// Runs inference on an image to extract the 'bottleneck' summary layer.
  /// </summary>
  /// <param name="sess">Current active TensorFlow Session.</param>
  /// <param name="image_data">Raw JPEG bytes of the image.</param>
  /// <param name="image_data_tensor">Input data layer in the graph.</param>
  /// <param name="decoded_image_tensor">Output of initial image resizing and preprocessing.</param>
  /// <param name="resized_input_tensor">The input node of the recognition graph.</param>
  /// <param name="bottleneck_tensor">Layer before the final softmax.</param>
  /// <returns>The value of <paramref name="bottleneck_tensor"/> after <c>np.squeeze</c>.</returns>
  private NDArray run_bottleneck_on_image(Session sess, byte[] image_data, Tensor image_data_tensor,
      Tensor decoded_image_tensor, Tensor resized_input_tensor, Tensor bottleneck_tensor)
  {
      // Stage 1: decode the JPEG image, resize it, and rescale the pixel values.
      var preprocessed_image = sess.run(decoded_image_tensor, new FeedItem(image_data_tensor, image_data));

      // Stage 2: feed the preprocessed image through the recognition network.
      var raw_bottleneck = sess.run(bottleneck_tensor, new FeedItem(resized_input_tensor, preprocessed_image));

      return np.squeeze(raw_bottleneck);
  }
  /// <summary>
  /// Builds the bottleneck file path for an image: the image path resolved against
  /// <paramref name="bottleneck_dir"/>, followed by "_", the sanitized module name and ".txt".
  /// </summary>
  private string get_bottleneck_path(Dictionary<string, Dictionary<string, string[]>> image_lists, string label_name, int index,
      string bottleneck_dir, string category, string module_name)
  {
      // Replace everything that is path- or URL-significant with '~'.
      var sanitized_module = module_name
          .Replace("://", "~")   // URL scheme.
          .Replace('/', '~')     // URL and Unix paths.
          .Replace(':', '~')     // Windows paths (drive separator).
          .Replace('\\', '~');   // Windows paths (directory separator).

      var image_based_path = get_image_path(image_lists, label_name, index, bottleneck_dir, category);
      return $"{image_based_path}_{sanitized_module}.txt";
  }
  /// <summary>
  /// Resolves the path of one listed image as image_dir / label_name / file name.
  /// The index wraps around the length of the category list. Missing labels, missing
  /// categories and empty lists are only reported via <c>print</c>; the lookups that
  /// follow still run.
  /// </summary>
  private string get_image_path(Dictionary<string, Dictionary<string, string[]>> image_lists, string label_name,
      int index, string image_dir, string category)
  {
      if (!image_lists.ContainsKey(label_name))
      {
          print($"Label does not exist {label_name}");
      }
      var categories_for_label = image_lists[label_name];

      if (!categories_for_label.ContainsKey(category))
      {
          print($"Category does not exist {category}");
      }
      var files = categories_for_label[category];

      if (files.Length == 0)
      {
          print($"Label {label_name} has no images in the category {category}.");
      }

      // Wrap the index so any non-negative value maps onto an existing entry.
      var wrapped_index = index % files.Length;
      var file_name = files[wrapped_index].Split(Path.DirectorySeparatorChar).Last();

      return Path.Join(image_dir, label_name, file_name);
  }
  /// <summary>
  /// Downloads and extracts the flower photo archive, downloads the InceptionV3 meta graph,
  /// creates the summary and bottleneck directories, and fills <c>image_lists</c>.
  /// </summary>
  public void PrepareData()
  {
      // A set of images to teach the network about the new classes.
      var archive_name = "flower_photos.tgz";
      Web.Download($"http://download.tensorflow.org/models/{archive_name}", data_dir, archive_name);
      Compress.ExtractTGZ(Path.Join(data_dir, archive_name), data_dir);

      // Graph meta data.
      Web.Download(
          "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/InceptionV3.meta",
          "graph",
          "InceptionV3.meta");

      // Directories that can be used during training.
      Directory.CreateDirectory(summaries_dir);
      Directory.CreateDirectory(bottleneck_dir);

      // Look at the folder structure, and create lists of all the images.
      image_lists = create_image_lists();

      switch (image_lists.Count)
      {
          case 0:
              print($"No valid folders of images found at {image_dir}");
              break;
          case 1:
              print($"Only one valid folder of images found at {image_dir} - multiple classes are needed for classification.");
              break;
      }
  }
  /// <summary>
  /// Adds the JPEG decoding and resizing operations to the graph.
  /// </summary>
  /// <returns>
  /// The placeholder that receives the JPEG data, and the output of the bilinear resize.
  /// </returns>
  private (Tensor, Tensor) add_jpeg_decoding()
  {
      // Target input dimensions.
      const int input_height = 299;
      const int input_width = 299;
      const int input_depth = 3;

      var jpeg_input = tf.placeholder(tf.chars, name: "DecodeJPGInput");
      var decoded = tf.image.decode_jpeg(jpeg_input, channels: input_depth);

      // Convert from full range of uint8 to range [0,1] of float32.
      var decoded_as_float = tf.image.convert_image_dtype(decoded, tf.float32);
      var decoded_4d = tf.expand_dims(decoded_as_float, 0);

      var target_size = tf.cast(
          tf.stack(new int[] { input_height, input_width }),
          dtype: tf.int32);
      var resized = tf.image.resize_bilinear(decoded_4d, target_size);

      return (jpeg_input, resized);
  }
  /// <summary>
  /// Builds a list of training images from the file system.
  /// Each directory returned by <c>tf.gfile.Walk(image_dir)</c> becomes one label, named after
  /// the lower-cased last path segment. Its files are split by position into "testing",
  /// "validation" and "training" lists according to <c>testing_percentage</c> and
  /// <c>validation_percentage</c>.
  /// </summary>
  /// <returns>label name -> category -> file paths.</returns>
  private Dictionary<string, Dictionary<string, string[]>> create_image_lists()
  {
      // NOTE(review): every directory yielded by Walk is treated as a label - confirm
      // that the root folder itself is not among them.
      var sub_dirs = tf.gfile.Walk(image_dir)
          .Select(x => x.Item1)
          .OrderBy(x => x)
          .ToArray();

      var result = new Dictionary<string, Dictionary<string, string[]>>();
      foreach (var sub_dir in sub_dirs)
      {
          var dir_name = sub_dir.Split(Path.DirectorySeparatorChar).Last();
          print($"Looking for images in '{dir_name}'");

          var file_list = Directory.GetFiles(sub_dir);
          // GetFiles does not guarantee any order; sort so the positional split below
          // assigns the same files to the same category on every run and platform.
          Array.Sort(file_list, StringComparer.Ordinal);

          if (file_list.Length < 20)
          {
              print("WARNING: Folder has less than 20 images, which may cause issues.");
          }

          // Invariant casing: the label is an identifier, not user-facing text.
          var label_name = dir_name.ToLowerInvariant();

          int testing_count = (int)Math.Floor(file_list.Length * testing_percentage);
          int validation_count = (int)Math.Floor(file_list.Length * validation_percentage);

          result[label_name] = new Dictionary<string, string[]>
          {
              ["testing"] = file_list.Take(testing_count).ToArray(),
              ["validation"] = file_list.Skip(testing_count).Take(validation_count).ToArray(),
              ["training"] = file_list.Skip(testing_count + validation_count).ToArray(),
          };
      }

      return result;
  }
  230. }
  231. }