You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

WordCnn.cs 3.4 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using System.Text;
  5. using Tensorflow;
  6. using static Tensorflow.Python;
  7. namespace TensorFlowNET.Examples.Text
  8. {
  9. public class WordCnn : ITextModel
  10. {
  11. public WordCnn(int vocabulary_size, int document_max_len, int num_class)
  12. {
  13. var embedding_size = 128;
  14. var learning_rate = 0.001f;
  15. var filter_sizes = new int[3, 4, 5];
  16. var num_filters = 100;
  17. var x = tf.placeholder(tf.int32, new TensorShape(-1, document_max_len), name: "x");
  18. var y = tf.placeholder(tf.int32, new TensorShape(-1), name: "y");
  19. var is_training = tf.placeholder(tf.@bool, new TensorShape(), name: "is_training");
  20. var global_step = tf.Variable(0, trainable: false);
  21. var keep_prob = tf.where(is_training, 0.5f, 1.0f);
  22. Tensor x_emb = null;
  23. with(tf.name_scope("embedding"), scope =>
  24. {
  25. var init_embeddings = tf.random_uniform(new int[] { vocabulary_size, embedding_size });
  26. var embeddings = tf.get_variable("embeddings", initializer: init_embeddings);
  27. x_emb = tf.nn.embedding_lookup(embeddings, x);
  28. x_emb = tf.expand_dims(x_emb, -1);
  29. });
  30. var pooled_outputs = new List<Tensor>();
  31. for (int len = 0; len < filter_sizes.Rank; len++)
  32. {
  33. int filter_size = filter_sizes.GetLength(len);
  34. var conv = tf.layers.conv2d(
  35. x_emb,
  36. filters: num_filters,
  37. kernel_size: new int[] { filter_size, embedding_size },
  38. strides: new int[] { 1, 1 },
  39. padding: "VALID",
  40. activation: tf.nn.relu());
  41. var pool = tf.layers.max_pooling2d(
  42. conv,
  43. pool_size: new[] { document_max_len - filter_size + 1, 1 },
  44. strides: new[] { 1, 1 },
  45. padding: "VALID");
  46. pooled_outputs.Add(pool);
  47. }
  48. var h_pool = tf.concat(pooled_outputs, 3);
  49. var h_pool_flat = tf.reshape(h_pool, new TensorShape(-1, num_filters * filter_sizes.Rank));
  50. Tensor h_drop = null;
  51. with(tf.name_scope("dropout"), delegate
  52. {
  53. h_drop = tf.nn.dropout(h_pool_flat, keep_prob);
  54. });
  55. Tensor logits = null;
  56. Tensor predictions = null;
  57. with(tf.name_scope("output"), delegate
  58. {
  59. logits = tf.layers.dense(h_drop, num_class);
  60. predictions = tf.argmax(logits, -1, output_type: tf.int32);
  61. });
  62. with(tf.name_scope("loss"), delegate
  63. {
  64. var sscel = tf.nn.sparse_softmax_cross_entropy_with_logits(logits: logits, labels: y);
  65. var loss = tf.reduce_mean(sscel);
  66. var adam = tf.train.AdamOptimizer(learning_rate);
  67. var optimizer = adam.minimize(loss, global_step: global_step);
  68. });
  69. with(tf.name_scope("accuracy"), delegate
  70. {
  71. var correct_predictions = tf.equal(predictions, y);
  72. var accuracy = tf.reduce_mean(tf.cast(correct_predictions, TF_DataType.TF_FLOAT), name: "accuracy");
  73. });
  74. }
  75. }
  76. }