You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

common.cs 3.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Text;
  4. using Tensorflow;
  5. using static Tensorflow.Binding;
  6. namespace TensorFlowNET.Examples.ImageProcessing.YOLO
  7. {
  /// <summary>
  /// Graph-building helpers for the YOLO example: a convolution layer
  /// (optional explicit-pad downsampling, batch normalization, leaky ReLU)
  /// and a two-convolution residual block built on top of it.
  /// </summary>
  class common
  {
      /// <summary>
      /// Adds a 2-D convolution to the graph inside the variable scope <paramref name="name"/>,
      /// followed by batch normalization and, optionally, a leaky ReLU (alpha 0.1).
      /// </summary>
      /// <param name="input_data">Input tensor. When <paramref name="downsample"/> is set, axes 1 and 2
      /// are zero-padded (presumably height and width of an NHWC tensor; confirm against callers).</param>
      /// <param name="filters_shape">Shape of the "weight" variable; elements 0 and 1 are the kernel
      /// sizes used to derive the downsample padding.</param>
      /// <param name="trainable">Tensor forwarded as the <c>training</c> argument of batch normalization.</param>
      /// <param name="name">Variable scope name under which the layer's variables are created.</param>
      /// <param name="downsample">When true, pads the input explicitly and convolves with stride 2 and
      /// "VALID" padding; otherwise uses stride 1 and "SAME" padding.</param>
      /// <param name="activate">When true, applies <c>tf.nn.leaky_relu</c> to the result.</param>
      /// <param name="bn">Must be true; the non-batch-norm path is not implemented.</param>
      /// <returns>The output tensor of the layer.</returns>
      /// <exception cref="NotImplementedException">Thrown when <paramref name="bn"/> is false.</exception>
      public static Tensor convolutional(Tensor input_data, int[] filters_shape, Tensor trainable,
          string name, bool downsample = false, bool activate = true,
          bool bn = true)
      {
          // Fail before touching the graph: previously the weight variable and the
          // conv op were created first and only then was the exception raised.
          if (!bn)
          {
              throw new NotImplementedException(
                  $"{nameof(convolutional)}: the bn = false path is not implemented.");
          }

          return tf_with(tf.variable_scope(name), scope =>
          {
              int[] strides;
              string padding;
              if (downsample)
              {
                  // Floor division keeps the result correct for kernel size 1,
                  // where (k - 2) / 2 is negative.
                  int pad_h = (int)Math.Floor((filters_shape[0] - 2) / 2.0f) + 1;
                  int pad_w = (int)Math.Floor((filters_shape[1] - 2) / 2.0f) + 1;
                  var paddings = tf.constant(new int[,] { { 0, 0 }, { pad_h, pad_h }, { pad_w, pad_w }, { 0, 0 } });
                  input_data = tf.pad(input_data, paddings, "CONSTANT");

                  // The input is already padded by hand, so the convolution itself
                  // must not pad again: stride 2 over axes 1 and 2 with VALID padding.
                  strides = new int[] { 1, 2, 2, 1 };
                  padding = "VALID";
              }
              else
              {
                  strides = new int[] { 1, 1, 1, 1 };
                  padding = "SAME";
              }

              var weight = tf.get_variable(name: "weight", dtype: tf.float32, trainable: true,
                  shape: filters_shape, initializer: tf.random_normal_initializer(stddev: 0.01f));
              var conv = tf.nn.conv2d(input: input_data, filter: weight, strides: strides, padding: padding);

              conv = tf.layers.batch_normalization(conv, beta_initializer: tf.zeros_initializer,
                  gamma_initializer: tf.ones_initializer,
                  moving_mean_initializer: tf.zeros_initializer,
                  moving_variance_initializer: tf.ones_initializer, training: trainable);

              if (activate)
              {
                  conv = tf.nn.leaky_relu(conv, alpha: 0.1f);
              }

              return conv;
          });
      }

      /// <summary>
      /// Builds a residual block inside the variable scope <paramref name="name"/>:
      /// a 1x1 convolution ("conv1") followed by a 3x3 convolution ("conv2"),
      /// whose output is added to the original input.
      /// </summary>
      /// <param name="input_data">Input tensor; also used as the shortcut branch.</param>
      /// <param name="input_channel">Input channel count of the 1x1 convolution's filter.</param>
      /// <param name="filter_num1">Output channels of "conv1" and input channels of "conv2".</param>
      /// <param name="filter_num2">Output channels of "conv2".</param>
      /// <param name="trainable">Tensor forwarded to both convolutions.</param>
      /// <param name="name">Variable scope name for the block.</param>
      /// <returns>The sum of the second convolution's output and the original input.</returns>
      public static Tensor residual_block(Tensor input_data, int input_channel, int filter_num1,
          int filter_num2, Tensor trainable, string name)
      {
          // Keep a handle on the untouched input; input_data is reassigned below.
          var short_cut = input_data;

          return tf_with(tf.variable_scope(name), scope =>
          {
              input_data = convolutional(input_data, filters_shape: new int[] { 1, 1, input_channel, filter_num1 },
                  trainable: trainable, name: "conv1");
              input_data = convolutional(input_data, filters_shape: new int[] { 3, 3, filter_num1, filter_num2 },
                  trainable: trainable, name: "conv2");

              return input_data + short_cut;
          });
      }
  }
  64. }