Browse Source

Support mutiple inputs of keras modek.predict.

tags/v0.100.5-BERT-load
Yaohui Liu Haiping 2 years ago
parent
commit
3db092b929
4 changed files with 18 additions and 24 deletions
  1. +1
    -2
      src/TensorFlowNET.Core/Data/OwnedIterator.cs
  2. +1
    -1
      src/TensorFlowNET.Core/Keras/Engine/IModel.cs
  3. +3
    -3
      src/TensorFlowNET.Keras/Engine/Model.Predict.cs
  4. +13
    -18
      test/TensorFlowNET.Keras.UnitTest/MultiInputModelTest.cs

+ 1
- 2
src/TensorFlowNET.Core/Data/OwnedIterator.cs View File

@@ -13,7 +13,7 @@ namespace Tensorflow
IDatasetV2 _dataset; IDatasetV2 _dataset;
TensorSpec[] _element_spec; TensorSpec[] _element_spec;
dataset_ops ops = new dataset_ops(); dataset_ops ops = new dataset_ops();
Tensor _deleter;
//Tensor _deleter;
Tensor _iterator_resource; Tensor _iterator_resource;


public OwnedIterator(IDatasetV2 dataset) public OwnedIterator(IDatasetV2 dataset)
@@ -26,7 +26,6 @@ namespace Tensorflow
dataset = dataset.apply_options(); dataset = dataset.apply_options();
_dataset = dataset; _dataset = dataset;
_element_spec = dataset.element_spec; _element_spec = dataset.element_spec;
// _flat_output_types =
_iterator_resource = ops.anonymous_iterator_v3(_dataset.output_types, _dataset.output_shapes); _iterator_resource = ops.anonymous_iterator_v3(_dataset.output_types, _dataset.output_shapes);
// TODO(Rinne): deal with graph mode. // TODO(Rinne): deal with graph mode.
ops.make_iterator(dataset.variant_tensor, _iterator_resource); ops.make_iterator(dataset.variant_tensor, _iterator_resource);


+ 1
- 1
src/TensorFlowNET.Core/Keras/Engine/IModel.cs View File

@@ -62,7 +62,7 @@ public interface IModel : ILayer
bool use_multiprocessing = false, bool use_multiprocessing = false,
bool return_dict = false); bool return_dict = false);


Tensors predict(Tensor x,
Tensors predict(Tensors x,
int batch_size = -1, int batch_size = -1,
int verbose = 0, int verbose = 0,
int steps = -1, int steps = -1,


+ 3
- 3
src/TensorFlowNET.Keras/Engine/Model.Predict.cs View File

@@ -49,7 +49,7 @@ namespace Tensorflow.Keras.Engine
/// <param name="workers"></param> /// <param name="workers"></param>
/// <param name="use_multiprocessing"></param> /// <param name="use_multiprocessing"></param>
/// <returns></returns> /// <returns></returns>
public Tensors predict(Tensor x,
public Tensors predict(Tensors x,
int batch_size = -1, int batch_size = -1,
int verbose = 0, int verbose = 0,
int steps = -1, int steps = -1,
@@ -115,12 +115,12 @@ namespace Tensorflow.Keras.Engine
Tensors run_predict_step(OwnedIterator iterator) Tensors run_predict_step(OwnedIterator iterator)
{ {
var data = iterator.next(); var data = iterator.next();
var outputs = predict_step(data[0]);
var outputs = predict_step(data);
tf_with(ops.control_dependencies(new object[0]), ctl => _predict_counter.assign_add(1)); tf_with(ops.control_dependencies(new object[0]), ctl => _predict_counter.assign_add(1));
return outputs; return outputs;
} }


Tensors predict_step(Tensor data)
Tensors predict_step(Tensors data)
{ {
return Apply(data, training: false); return Apply(data, training: false);
} }


+ 13
- 18
test/TensorFlowNET.Keras.UnitTest/MultiInputModelTest.cs View File

@@ -1,27 +1,17 @@
using Microsoft.VisualStudio.TestPlatform.Utilities;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System; using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using System.Xml.Linq;
using Tensorflow.Operations;
using static Tensorflow.Binding;
using static Tensorflow.KerasApi;
using Tensorflow.NumPy;
using Microsoft.VisualBasic;
using static HDF.PInvoke.H5T;
using Tensorflow.Keras.UnitTest.Helpers;
using Tensorflow;
using Tensorflow.Keras.Optimizers; using Tensorflow.Keras.Optimizers;
using Tensorflow.NumPy;
using static Tensorflow.KerasApi;


namespace Tensorflow.Keras.UnitTest
namespace TensorFlowNET.Keras.UnitTest
{ {
[TestClass] [TestClass]
public class MultiInputModelTest public class MultiInputModelTest
{ {
[TestMethod] [TestMethod]
public void SimpleModel()
public void LeNetModel()
{ {
var inputs = keras.Input((28, 28, 1)); var inputs = keras.Input((28, 28, 1));
var conv1 = keras.layers.Conv2D(16, (3, 3), activation: "relu", padding: "same").Apply(inputs); var conv1 = keras.layers.Conv2D(16, (3, 3), activation: "relu", padding: "same").Apply(inputs);
@@ -40,7 +30,7 @@ namespace Tensorflow.Keras.UnitTest
var concat = keras.layers.Concatenate().Apply((flat1, flat1_2)); var concat = keras.layers.Concatenate().Apply((flat1, flat1_2));
var dense1 = keras.layers.Dense(512, activation: "relu").Apply(concat); var dense1 = keras.layers.Dense(512, activation: "relu").Apply(concat);
var dense2 = keras.layers.Dense(128, activation: "relu").Apply(dense1); var dense2 = keras.layers.Dense(128, activation: "relu").Apply(dense1);
var dense3 = keras.layers.Dense(10, activation: "relu").Apply(dense2);
var dense3 = keras.layers.Dense(10, activation: "relu").Apply(dense2);
var output = keras.layers.Softmax(-1).Apply(dense3); var output = keras.layers.Softmax(-1).Apply(dense3);


var model = keras.Model((inputs, inputs_2), output); var model = keras.Model((inputs, inputs_2), output);
@@ -52,7 +42,7 @@ namespace Tensorflow.Keras.UnitTest
{ {
TrainDir = "mnist", TrainDir = "mnist",
OneHot = false, OneHot = false,
ValidationSize = 59000,
ValidationSize = 59900,
}).Result; }).Result;


var loss = keras.losses.SparseCategoricalCrossentropy(); var loss = keras.losses.SparseCategoricalCrossentropy();
@@ -64,6 +54,11 @@ namespace Tensorflow.Keras.UnitTest


var x = new NDArray[] { x1, x2 }; var x = new NDArray[] { x1, x2 };
model.fit(x, dataset.Train.Labels, batch_size: 8, epochs: 3); model.fit(x, dataset.Train.Labels, batch_size: 8, epochs: 3);

x1 = np.ones((1, 28, 28, 1), TF_DataType.TF_FLOAT);
x2 = np.zeros((1, 28, 28, 1), TF_DataType.TF_FLOAT);
var pred = model.predict((x1, x2));
Console.WriteLine(pred);
} }
} }
} }

Loading…
Cancel
Save