Browse Source

2st stage

tags/v0.12
Oceania2018 6 years ago
parent
commit
3f9b0b19ab
10 changed files with 22 additions and 20 deletions
  1. +5
    -3
      src/TensorFlowNET.Core/Sessions/_FetchHandler.cs
  2. +2
    -2
      src/TensorFlowNET.Core/Tensors/tensor_util.cs
  3. +4
    -4
      src/TensorFlowNET.Core/Train/Saving/Saver.cs
  4. +1
    -1
      test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionCNN.cs
  5. +1
    -1
      test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionNN.cs
  6. +1
    -1
      test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionRNN.cs
  7. +4
    -4
      test/TensorFlowNET.Examples/ImageProcessing/ObjectDetection.cs
  8. +1
    -1
      test/TensorFlowNET.Examples/TextProcessing/BinaryTextClassification.cs
  9. +2
    -2
      test/TensorFlowNET.Examples/TextProcessing/CnnTextClassification.cs
  10. +1
    -1
      test/TensorFlowNET.Examples/TextProcessing/Word2Vec.cs

+ 5
- 3
src/TensorFlowNET.Core/Sessions/_FetchHandler.cs View File

@@ -111,9 +111,11 @@ namespace Tensorflow
case "Double":
full_values.Add(value.Data<double>()[0]);
break;
case "String":
full_values.Add(value.Data<string>()[0]);
break;
/*case "String":
full_values.Add(value.Data<byte>()[0]);
break;*/
default:
throw new NotImplementedException($"build_results tensor_values[0] {tensor_values[0].dtype.Name}");
}
}
else


+ 2
- 2
src/TensorFlowNET.Core/Tensors/tensor_util.cs View File

@@ -296,9 +296,9 @@ namespace Tensorflow
case "Double":
tensor_proto.DoubleVal.AddRange(proto_values.Data<double>());
break;
case "String":
/*case "String":
tensor_proto.StringVal.AddRange(proto_values.Data<string>().Select(x => Google.Protobuf.ByteString.CopyFromUtf8(x.ToString())));
break;
break;*/
default:
throw new Exception("make_tensor_proto Not Implemented");
}


+ 4
- 4
src/TensorFlowNET.Core/Train/Saving/Saver.cs View File

@@ -170,7 +170,7 @@ namespace Tensorflow
{
if (string.IsNullOrEmpty(latest_filename))
latest_filename = "checkpoint";
string model_checkpoint_path = "";
object model_checkpoint_path = "";
string checkpoint_file = "";

if (global_step > 0)
@@ -188,10 +188,10 @@ namespace Tensorflow

if (write_state)
{
_RecordLastCheckpoint(model_checkpoint_path);
_RecordLastCheckpoint(model_checkpoint_path.ToString());
checkpoint_management.update_checkpoint_state_internal(
save_dir: save_path_parent,
model_checkpoint_path: model_checkpoint_path,
model_checkpoint_path: model_checkpoint_path.ToString(),
all_model_checkpoint_paths: _last_checkpoints.Keys.Select(x => x).ToList(),
latest_filename: latest_filename,
save_relative_paths: _save_relative_paths);
@@ -205,7 +205,7 @@ namespace Tensorflow
export_meta_graph(meta_graph_filename, strip_default_attrs: strip_default_attrs, save_debug_info: save_debug_info);
}

return _is_empty ? string.Empty : model_checkpoint_path;
return _is_empty ? string.Empty : model_checkpoint_path.ToString();
}

public (Saver, object) import_meta_graph(string meta_graph_or_file,


+ 1
- 1
test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionCNN.cs View File

@@ -136,7 +136,7 @@ namespace TensorFlowNET.Examples.ImageProcess
public void Train(Session sess)
{
// Number of training iterations in each epoch
var num_tr_iter = y_train.len / batch_size;
var num_tr_iter = y_train.shape[0] / batch_size;

var init = tf.global_variables_initializer();
sess.run(init);


+ 1
- 1
test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionNN.cs View File

@@ -127,7 +127,7 @@ namespace TensorFlowNET.Examples.ImageProcess
public void Train(Session sess)
{
// Number of training iterations in each epoch
var num_tr_iter = mnist.train.labels.len / batch_size;
var num_tr_iter = mnist.train.labels.shape[0] / batch_size;

var init = tf.global_variables_initializer();
sess.run(init);


+ 1
- 1
test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionRNN.cs View File

@@ -88,7 +88,7 @@ namespace TensorFlowNET.Examples.ImageProcess
public void Train(Session sess)
{
// Number of training iterations in each epoch
var num_tr_iter = y_train.len / batch_size;
var num_tr_iter = y_train.shape[0] / batch_size;

var init = tf.global_variables_initializer();
sess.run(init);


+ 4
- 4
test/TensorFlowNET.Examples/ImageProcessing/ObjectDetection.cs View File

@@ -77,7 +77,7 @@ namespace TensorFlowNET.Examples

var results = sess.run(outTensorArr, new FeedItem(imgTensor, imgArr));

NDArray[] resultArr = results.Data<NDArray>();
NDArray[] resultArr = results.GetNDArrays();

buildOutputImage(resultArr);
}
@@ -119,14 +119,14 @@ namespace TensorFlowNET.Examples
// get bitmap
Bitmap bitmap = new Bitmap(Path.Join(imageDir, "input.jpg"));

float[] scores = resultArr[2].Data<float>();
float[] scores = resultArr[2].GetData<float>().ToArray();

for (int i=0; i<scores.Length; i++)
{
float score = scores[i];
if (score > MIN_SCORE)
{
float[] boxes = resultArr[1].Data<float>();
float[] boxes = resultArr[1].GetData<float>().ToArray();
float top = boxes[i * 4] * bitmap.Height;
float left = boxes[i * 4 + 1] * bitmap.Width;
float bottom = boxes[i * 4 + 2] * bitmap.Height;
@@ -140,7 +140,7 @@ namespace TensorFlowNET.Examples
Height = (int)(bottom - top)
};

float[] ids = resultArr[3].Data<float>();
float[] ids = resultArr[3].GetData<float>().ToArray();

string name = pbTxtItems.items.Where(w => w.id == (int)ids[i]).Select(s=>s.display_name).FirstOrDefault();



+ 1
- 1
test/TensorFlowNET.Examples/TextProcessing/BinaryTextClassification.cs View File

@@ -27,7 +27,7 @@ namespace TensorFlowNET.Examples
{
PrepareData();

Console.WriteLine($"Training entries: {train_data.len}, labels: {train_labels.len}");
Console.WriteLine($"Training entries: {train_data.shape[0]}, labels: {train_labels.shape[0]}");

// A dictionary mapping words to an integer index
var word_index = GetWordIndex();


+ 2
- 2
test/TensorFlowNET.Examples/TextProcessing/CnnTextClassification.cs View File

@@ -147,8 +147,8 @@ namespace TensorFlowNET.Examples
Console.WriteLine("\tDONE ");

var (train_x, valid_x, train_y, valid_y) = train_test_split(x, y, test_size: 0.15f);
Console.WriteLine("Training set size: " + train_x.len);
Console.WriteLine("Test set size: " + valid_x.len);
Console.WriteLine("Training set size: " + train_x.shape[0]);
Console.WriteLine("Test set size: " + valid_x.shape[0]);
}

public Graph ImportGraph()


+ 1
- 1
test/TensorFlowNET.Examples/TextProcessing/Word2Vec.cs View File

@@ -120,7 +120,7 @@ namespace TensorFlowNET.Examples
// Generate training batch for the skip-gram model
private (NDArray, NDArray) next_batch(int batch_size, int num_skips, int skip_window)
{
var batch = np.ndarray((batch_size), dtype: np.int32);
var batch = np.ndarray(new Shape(batch_size), dtype: np.int32);
var labels = np.ndarray((batch_size, 1), dtype: np.int32);
// get window size (words left and right + current one)
int span = 2 * skip_window + 1;


Loading…
Cancel
Save