You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

CnnTextTrain.cs 2.3 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
  1. using NumSharp.Core;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Text;
  5. using Tensorflow;
  namespace TensorFlowNET.Examples.CnnTextClassification
  {
      /// <summary>
      /// CNN text-classification training example. Its hyper-parameters are held in the
      /// fields below. The pipeline is unfinished: <see cref="preprocess"/> loads the
      /// labelled data but then throws <see cref="NotImplementedException"/>, so
      /// <see cref="Run"/> currently always throws as well.
      /// </summary>
      /// <remarks>
      /// The data URLs point at the dennybritz/cnn-text-classification-tf repository. The
      /// field names and the "(default: ...)" comments appear to mirror that project's
      /// flags (assumption; confirm against the upstream script).
      /// </remarks>
      public class CnnTextTrain : Python, IExample
      {
          // Fraction of the training data held out for validation.
          private readonly float dev_sample_percentage = 0.1f;
          // Data source for the positive examples.
          private readonly string positive_data_file = "https://raw.githubusercontent.com/dennybritz/cnn-text-classification-tf/master/data/rt-polaritydata/rt-polarity.pos";
          // Data source for the negative examples.
          private readonly string negative_data_file = "https://raw.githubusercontent.com/dennybritz/cnn-text-classification-tf/master/data/rt-polaritydata/rt-polarity.neg";
          // Dimensionality of the character embedding (default: 128).
          private readonly int embedding_dim = 128;
          // Comma-separated filter sizes (default: "3,4,5"); not parsed anywhere yet.
          private readonly string filter_sizes = "3,4,5";
          // Number of filters per filter size (default: 128).
          private readonly int num_filters = 128;
          // Dropout keep probability (default: 0.5).
          private readonly float dropout_keep_prob = 0.5f;
          // L2 regularization lambda (default: 0.0).
          private readonly float l2_reg_lambda = 0.0f;
          // Batch size (default: 64).
          private readonly int batch_size = 64;
          // Number of training epochs (default: 200).
          private readonly int num_epochs = 200;
          // Evaluate the model on the dev set after this many steps (default: 100).
          private readonly int evaluate_every = 100;
          // Save the model after this many steps (default: 100).
          private readonly int checkpoint_every = 100;
          // Number of checkpoints to keep (default: 5).
          private readonly int num_checkpoints = 5;
          // Allow soft device placement.
          private readonly bool allow_soft_placement = true;
          // Log the placement of ops on devices.
          private readonly bool log_device_placement = false;

          /// <summary>
          /// Entry point of the example. At present it only runs <see cref="preprocess"/>.
          /// </summary>
          /// <exception cref="NotImplementedException">
          /// Always, until <see cref="preprocess"/> is implemented.
          /// </exception>
          public void Run()
          {
              // The split is not consumed yet because the training loop has not been written.
              // The local names document the intended order of the returned tuple.
              var (x_train, y_train, vocab_processor, x_dev, y_dev) = preprocess();
          }

          /// <summary>
          /// Loads the positive/negative data files. It is meant to turn them into training
          /// and dev sets, but that part is not implemented yet.
          /// </summary>
          /// <returns>
          /// Intended: (x_train, y_train, vocab_processor, x_dev, y_dev), in the order
          /// deconstructed by <see cref="Run"/>. Nothing is returned yet.
          /// </returns>
          /// <exception cref="NotImplementedException">Always thrown after loading the data.</exception>
          public (NDArray, NDArray, NDArray, NDArray, NDArray) preprocess()
          {
              // NOTE(review): the result of load_data_and_labels is discarded. It presumably
              // holds the texts and labels that the train/dev split should consume; confirm
              // against DataHelpers.
              DataHelpers.load_data_and_labels(positive_data_file, negative_data_file);

              // A descriptive message replaces the former empty string, which also hid the
              // framework's default "not implemented" text.
              throw new NotImplementedException(
                  $"{nameof(CnnTextTrain)}.{nameof(preprocess)}: building the vocabulary and the train/dev split is not implemented yet.");
          }
      }
  }

TensorFlow 框架的 .NET 版本，提供了丰富的特性和 API，可以借此很方便地在 .NET 平台下搭建深度学习训练与推理流程。