diff --git a/src/TensorFlowNET.Core/APIs/tf.nn.cs b/src/TensorFlowNET.Core/APIs/tf.nn.cs
index 9ec1ef58..0bc9d0e5 100644
--- a/src/TensorFlowNET.Core/APIs/tf.nn.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.nn.cs
@@ -75,7 +75,7 @@ namespace Tensorflow
///
/// <returns>A pair (outputs, state)</returns>
public static (Tensor, Tensor) dynamic_rnn(RNNCell cell, Tensor inputs,
- int? sequence_length = null, TF_DataType dtype = TF_DataType.DtInvalid,
+ Tensor sequence_length = null, TF_DataType dtype = TF_DataType.DtInvalid,
int? parallel_iterations = null, bool swap_memory = false, bool time_major = false)
=> rnn.dynamic_rnn(cell, inputs, sequence_length: sequence_length, dtype: dtype,
parallel_iterations: parallel_iterations, swap_memory: swap_memory,
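
A minimal, illustrative call sketch of the updated overload; the variable names, shapes, and the use of tf.placeholder are assumptions rather than part of this diff, and the RNN path in this commit still throws NotImplementedException in places:

    // sequence_length is now a Tensor (or null), matching the signature above.
    var cell = new BasicRNNCell(num_units: 64);
    var inputs = tf.placeholder(tf.float32, new TensorShape(-1, 10, 8)); // [batch, time, depth]
    var (outputs, state) = tf.nn.dynamic_rnn(cell, inputs, dtype: tf.float32);
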
diff --git a/src/TensorFlowNET.Core/Operations/BasicRNNCell.cs b/src/TensorFlowNET.Core/Operations/BasicRNNCell.cs
index 61da061f..554e9f1a 100644
--- a/src/TensorFlowNET.Core/Operations/BasicRNNCell.cs
+++ b/src/TensorFlowNET.Core/Operations/BasicRNNCell.cs
@@ -24,7 +24,8 @@ namespace Tensorflow
int _num_units;
Func<Tensor, Tensor> _activation;
- protected override int state_size => _num_units;
+ public override int state_size => _num_units;
+ public override int output_size => _num_units;
public BasicRNNCell(int num_units,
Func<Tensor, Tensor> activation = null,
diff --git a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs
index 79d1df89..3200e13f 100644
--- a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs
+++ b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs
@@ -24,15 +24,15 @@ namespace Tensorflow.Operations
{
internal class rnn
{
- public static (Tensor, Tensor) dynamic_rnn(RNNCell cell, Tensor inputs,
- int? sequence_length = null, Tensor initial_state = null,
+ public static (Tensor, Tensor) dynamic_rnn(RNNCell cell, Tensor inputs_tensor,
+ Tensor sequence_length = null, Tensor initial_state = null,
TF_DataType dtype = TF_DataType.DtInvalid,
int? parallel_iterations = null, bool swap_memory = false, bool time_major = false)
{
with(tf.variable_scope("rnn"), scope =>
{
VariableScope varscope = scope;
- var flat_input = nest.flatten(inputs);
+ var flat_input = nest.flatten(inputs_tensor);
if (!time_major)
{
@@ -42,24 +42,146 @@ namespace Tensorflow.Operations
parallel_iterations = parallel_iterations ?? 32;
- if (sequence_length.HasValue)
+ if (sequence_length != null)
throw new NotImplementedException("dynamic_rnn sequence_length has value");
var batch_size = _best_effort_input_batch_size(flat_input);
+ Tensor state = null;
if (initial_state != null)
- {
- var state = initial_state;
- }
+ state = initial_state;
else
+ state = cell.get_initial_state(batch_size: batch_size, dtype: dtype);
+
+ var inputs = nest.pack_sequence_as(structure: inputs_tensor, flat_sequence: flat_input);
+
+ var (outputs, final_state) = _dynamic_rnn_loop(
+ cell,
+ inputs as Tensor,
+ state,
+ parallel_iterations: parallel_iterations.Value,
+ swap_memory: swap_memory,
+ sequence_length: sequence_length,
+ dtype: dtype);
+ });
+
+ throw new NotImplementedException("");
+ }
+
+ /// <summary>
+ /// Internal implementation of Dynamic RNN.
+ /// </summary>
+ /// <param name="cell"></param>
+ /// <param name="inputs"></param>
+ /// <param name="initial_state"></param>
+ /// <param name="parallel_iterations"></param>
+ /// <param name="swap_memory"></param>
+ /// <param name="sequence_length"></param>
+ /// <param name="dtype"></param>
+ /// <returns></returns>
+ private static (Tensor, Tensor) _dynamic_rnn_loop(RNNCell cell, Tensor inputs, Tensor initial_state,
+ int parallel_iterations, bool swap_memory, Tensor sequence_length = null, TF_DataType dtype = TF_DataType.DtInvalid)
+ {
+ var state = initial_state;
+ var state_size = cell.state_size;
+
+ var flat_input = nest.flatten(inputs);
+ var flat_output_size = nest.flatten(cell.output_size);
+
+ // Construct an initial output
+ var input_shape = array_ops.shape(flat_input[0]);
+ var time_steps = input_shape.slice(0);
+ var batch_size = _best_effort_input_batch_size(flat_input);
+ var inputs_got_shape = flat_input.Select(input_ => input_.TensorShape.with_rank_at_least(3)).ToArray();
+
+ var dims = inputs_got_shape[0].Dimensions.Take(2).ToArray();
+ var (const_time_steps, const_batch_size) = (dims[0], dims[1]);
+
+ foreach(var shape in inputs_got_shape)
+ {
+ if (shape[2] == -1)
+ throw new ValueError("Input size (depth of inputs) must be accessible via shape inference," +
+ " but saw value None.");
+
+ var got_time_steps = shape.dims[0];
+ var got_batch_size = shape.dims[1];
+
+ if (const_time_steps != got_time_steps)
+ throw new ValueError("Time steps is not the same for all the elements in the input in a " +
+ "batch.");
+
+ if (const_batch_size != got_batch_size)
+ throw new ValueError("Batch_size is not the same for all the elements in the input.");
+ }
+
+ Func<int, Tensor> _create_zero_arrays = (size_) =>
+ {
+ var size = rnn_cell_impl._concat(batch_size, size_);
+ return array_ops.zeros(
+ array_ops.stack(size), dtype: _infer_state_dtype(dtype, state));
+ };
+
+ // Prepare dynamic conditional copying of state & output
+ var flat_zero_output = flat_output_size.Select(output => _create_zero_arrays(output)).ToArray();
+ var zero_output = nest.pack_sequence_as(structure: cell.output_size, flat_sequence: flat_zero_output);
+
+ Tensor min_sequence_length = null, max_sequence_length = null;
+ if (sequence_length != null)
+ {
+ min_sequence_length = math_ops.reduce_min(sequence_length);
+ max_sequence_length = math_ops.reduce_max(sequence_length);
+ }
+ else
+ {
+ max_sequence_length = time_steps;
+ }
+
+ var time = array_ops.constant(0, dtype: dtypes.int32, name: "time");
+
+ string base_name = null;
+ with(ops.name_scope("dynamic_rnn"), scope => base_name = scope);
+
+ Func<string, TensorShape, TF_DataType, TensorArray> _create_ta = (name, element_shape, dtype_) =>
+ {
+ new TensorArray(dtype: dtype_,
+ size: time_steps,
+ element_shape: element_shape,
+ tensor_array_name: base_name + name);
+ throw new NotImplementedException("");
+ };
+
+ bool in_graph_mode = true;
+ if (in_graph_mode)
+ {
+ foreach(var (i, out_size) in enumerate(flat_output_size))
{
- cell.get_initial_state(batch_size: batch_size, dtype: dtype);
+ _create_ta($"output_{i}",
+ new TensorShape(const_batch_size).concatenate(
+ _maybe_tensor_shape_from_tensor(out_size)),
+ _infer_state_dtype(dtype, state));
+
+
+
}
- });
+ }
throw new NotImplementedException("");
}
+ private static TensorShape _maybe_tensor_shape_from_tensor(Tensor shape)
+ => shape.TensorShape;
+
+ private static TensorShape _maybe_tensor_shape_from_tensor(int shape)
+ => new TensorShape(shape);
+
+ private static TF_DataType _infer_state_dtype(TF_DataType explicit_dtype, Tensor state)
+ {
+ if (explicit_dtype != TF_DataType.DtInvalid)
+ return explicit_dtype;
+
+ throw new NotImplementedException("_infer_state_dtype");
+ }
+
/// <summary>
/// Transposes the batch and time dimensions of a Tensor.
/// </summary>
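
For orientation, a shape walk-through of _dynamic_rnn_loop with illustrative sizes (batch = 32, time = 10, input depth = 8, num_units = 64; the numbers are assumptions, not taken from the diff):

    // inputs after the time-major transpose : [10, 32, 8]  (dims[0] = time_steps, dims[1] = batch_size)
    // zero_output per step                  : [32, 64]     (zeros over _concat(batch_size, output_size))
    // initial state                         : [32, 64]     (cell.get_initial_state(batch_size, dtype))
    // each output TensorArray               : size = 10, element_shape = [32, 64]
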
diff --git a/src/TensorFlowNET.Core/Operations/NnOps/rnn_cell_impl.cs b/src/TensorFlowNET.Core/Operations/NnOps/rnn_cell_impl.cs
new file mode 100644
index 00000000..bd210ecd
--- /dev/null
+++ b/src/TensorFlowNET.Core/Operations/NnOps/rnn_cell_impl.cs
@@ -0,0 +1,57 @@
+/*****************************************************************************
+ Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+******************************************************************************/
+
+using System;
+
+namespace Tensorflow.Operations
+{
+ public class rnn_cell_impl
+ {
+ public BasicRNNCell BasicRNNCell(int num_units)
+ => new BasicRNNCell(num_units);
+
+ public static Tensor _concat(Tensor prefix, int suffix, bool @static = false)
+ {
+ var p = prefix;
+ var p_static = tensor_util.constant_value(prefix);
+ if (p.NDims == 0)
+ p = array_ops.expand_dims(p, 0);
+ else if (p.NDims != 1)
+ throw new ValueError($"prefix tensor must be either a scalar or vector, but saw tensor: {p}");
+
+ var s_tensor_shape = new TensorShape(suffix);
+ var s_static = s_tensor_shape.NDim > -1 ?
+ s_tensor_shape.Dimensions :
+ null;
+ var s = s_tensor_shape.is_fully_defined() ?
+ constant_op.constant(s_tensor_shape.Dimensions, dtype: dtypes.int32) :
+ null;
+
+ if (@static)
+ {
+ if (p_static is null) return null;
+ var shape = new TensorShape(p_static).concatenate(s_static);
+ throw new NotImplementedException("RNNCell _concat");
+ }
+ else
+ {
+ if (p is null || s is null)
+ throw new ValueError($"Provided a prefix or suffix of None: {prefix} and {suffix}");
+ return array_ops.concat(new[] { p, s }, 0);
+ }
+ }
+ }
+}
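
A hedged sketch of how _concat is meant to be used; it mirrors RNNCell._zero_state_tensors later in this diff, assuming batch_size is a scalar Tensor and num_units an int:

    // Build the dynamic shape [batch, num_units], then a matching zero state.
    var c = rnn_cell_impl._concat(batch_size, num_units);
    var zeros = array_ops.zeros(c, dtype: TF_DataType.TF_FLOAT);
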
diff --git a/src/TensorFlowNET.Core/Operations/RNNCell.cs b/src/TensorFlowNET.Core/Operations/RNNCell.cs
index 3b841087..57f46e7b 100644
--- a/src/TensorFlowNET.Core/Operations/RNNCell.cs
+++ b/src/TensorFlowNET.Core/Operations/RNNCell.cs
@@ -15,6 +15,7 @@
******************************************************************************/
using System;
+using Tensorflow.Operations;
using Tensorflow.Util;
using static Tensorflow.Python;
@@ -48,7 +49,9 @@ namespace Tensorflow
/// difference between TF and Keras RNN cell.
/// </summary>
protected bool _is_tf_rnn_cell = false;
- protected virtual int state_size { get; }
+ public virtual int state_size { get; }
+
+ public virtual int output_size { get; }
public RNNCell(bool trainable = true,
string name = null,
@@ -89,12 +92,18 @@ namespace Tensorflow
private Tensor _zero_state_tensors(int state_size, Tensor batch_size, TF_DataType dtype)
{
- nest.map_structure(x =>
+ var output = nest.map_structure(s =>
{
- throw new NotImplementedException("");
+ var c = rnn_cell_impl._concat(batch_size, s);
+ var size = array_ops.zeros(c, dtype: dtype);
+
+ var c_static = rnn_cell_impl._concat(batch_size, s, @static: true);
+ size.set_shape(c_static);
+
+ return size;
}, state_size);
- throw new NotImplementedException("");
+ return output;
}
}
}
diff --git a/src/TensorFlowNET.Core/Operations/TensorArray.cs b/src/TensorFlowNET.Core/Operations/TensorArray.cs
new file mode 100644
index 00000000..858dac47
--- /dev/null
+++ b/src/TensorFlowNET.Core/Operations/TensorArray.cs
@@ -0,0 +1,54 @@
+/*****************************************************************************
+ Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+******************************************************************************/
+
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Operations
+{
+ /// <summary>
+ /// TensorArray is designed to hide an underlying implementation object
+ /// and as such accesses many of that object's hidden fields.
+ ///
+ /// "Class wrapping dynamic-sized, per-time-step, write-once Tensor arrays.
+ /// This class is meant to be used with dynamic iteration primitives such as
+ /// `while_loop` and `map_fn`. It supports gradient back-propagation via special
+ /// "flow" control flow dependencies.
+ /// </summary>
+ public class TensorArray
+ {
+ _GraphTensorArray _implementation;
+
+ public TensorArray(TF_DataType dtype, Tensor size = null, bool? clear_after_read = null, bool? dynamic_size = null,
+ string tensor_array_name = null, Tensor handle = null, Tensor flow = null,
+ bool infer_shape = true, TensorShape element_shape = null,
+ bool colocate_with_first_write_call = true, string name = null)
+ {
+ _implementation = new _GraphTensorArray(dtype,
+ size: size,
+ dynamic_size: dynamic_size,
+ clear_after_read: clear_after_read,
+ tensor_array_name: tensor_array_name,
+ handle: handle,
+ flow: flow,
+ infer_shape: infer_shape,
+ element_shape: element_shape,
+ colocate_with_first_write_call: colocate_with_first_write_call,
+ name: name);
+ }
+ }
+}
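
A construction sketch mirroring _create_ta in rnn.cs above; names and sizes are illustrative, and only the constructor is wired up in this commit (reads and writes are not implemented yet):

    var ta = new TensorArray(dtype: TF_DataType.TF_FLOAT,
        size: time_steps,                          // scalar Tensor holding the step count
        element_shape: new TensorShape(32, 64),    // [batch, num_units], illustrative
        tensor_array_name: base_name + "output_0");
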
diff --git a/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs b/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs
new file mode 100644
index 00000000..b4619c05
--- /dev/null
+++ b/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs
@@ -0,0 +1,102 @@
+/*****************************************************************************
+ Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+******************************************************************************/
+
+using System;
+using System.Collections.Generic;
+using System.Text;
+using static Tensorflow.Python;
+
+namespace Tensorflow.Operations
+{
+ internal class _GraphTensorArray
+ {
+ TF_DataType _dtype;
+
+ /// <summary>
+ /// Used to keep track of what tensors the TensorArray should be
+ /// colocated with. We choose to colocate the TensorArray with the
+ /// first tensor written to it.
+ /// </summary>
+ bool _colocate_with_first_write_call;
+
+ bool _infer_shape;
+ List<TensorShape> _element_shape;
+
+ object _colocate_with;
+
+ public _GraphTensorArray(TF_DataType dtype, Tensor size, bool? dynamic_size = null,
+ bool? clear_after_read = null, string tensor_array_name = null, Tensor handle = null, Tensor flow = null,
+ bool infer_shape = true, TensorShape element_shape = null,
+ bool colocate_with_first_write_call = true, string name = null)
+ {
+ clear_after_read = clear_after_read ?? true;
+ dynamic_size = dynamic_size ?? false;
+
+ _dtype = dtype;
+
+ _colocate_with_first_write_call = colocate_with_first_write_call;
+ if (colocate_with_first_write_call)
+ _colocate_with = new Tensor[0];
+
+ // Record the current static shape for the array elements. The element
+ // shape is defined either by `element_shape` or the shape of the tensor
+ // of the first write. If `infer_shape` is true, all writes are checked for
+ // shape equality.
+ if(element_shape == null)
+ {
+ _infer_shape = infer_shape;
+ _element_shape = new List<TensorShape> { };
+ }
+ else
+ {
+ _infer_shape = true;
+ _element_shape = new List<TensorShape> { element_shape };
+ }
+
+ with(ops.name_scope(name, "", new { handle, size, flow }), scope =>
+ {
+ if(handle != null)
+ {
+
+ }
+ else
+ {
+ Func<(Tensor, Tensor)> create = () => gen_data_flow_ops.tensor_array_v3(size,
+ dtype: dtype,
+ element_shape: element_shape,
+ identical_element_shapes: infer_shape,
+ dynamic_size: dynamic_size.Value,
+ clear_after_read: clear_after_read.Value,
+ tensor_array_name: tensor_array_name,
+ name: scope);
+
+ // Construct the TensorArray with an empty device. The first
+ // write into the TensorArray from a Tensor with a set device
+ // will retroactively set the device value of this op.
+ if (colocate_with_first_write_call)
+ {
+ ops.colocate_with(ignore_existing: true);
+ create();
+ }
+ else
+ {
+
+ }
+ }
+ });
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Operations/array_ops.py.cs b/src/TensorFlowNET.Core/Operations/array_ops.py.cs
index 98b36bc6..c3f52cb8 100644
--- a/src/TensorFlowNET.Core/Operations/array_ops.py.cs
+++ b/src/TensorFlowNET.Core/Operations/array_ops.py.cs
@@ -29,6 +29,17 @@ namespace Tensorflow
public static Tensor prevent_gradient(Tensor input, string message = "", string name = null)
=> gen_array_ops.prevent_gradient(input, message: message, name: name);
+ internal static Tensor constant(object value,
+ TF_DataType dtype = TF_DataType.DtInvalid,
+ int[] shape = null,
+ string name = "Const",
+ bool verify_shape = false) => constant_op._constant_impl(value,
+ dtype,
+ shape,
+ name,
+ verify_shape: verify_shape,
+ allow_broadcast: false);
+
public static Tensor zeros(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null)
{
dtype = dtype.as_base_dtype();
diff --git a/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.py.cs b/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs
similarity index 63%
rename from src/TensorFlowNET.Core/Operations/gen_data_flow_ops.py.cs
rename to src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs
index 17f64a29..2cb9aac6 100644
--- a/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.py.cs
+++ b/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs
@@ -27,5 +27,23 @@ namespace Tensorflow
return _op.outputs[0];
}
+
+ public static (Tensor, Tensor) tensor_array_v3(Tensor size, TF_DataType dtype = TF_DataType.DtInvalid,
+ int[] element_shape = null, bool dynamic_size = false, bool clear_after_read = true,
+ bool identical_element_shapes = false, string tensor_array_name = "tensor_array_name", string name = null)
+ {
+ var _op = _op_def_lib._apply_op_helper("TensorArrayV3", name, new
+ {
+ size,
+ dtype,
+ element_shape,
+ dynamic_size,
+ clear_after_read,
+ identical_element_shapes,
+ tensor_array_name
+ });
+
+ return (null, null);
+ }
}
}
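
The wrapper above builds the op but returns (null, null) for now. Upstream, TensorArrayV3 produces two outputs, a resource handle and a scalar float "flow" tensor, so a hedged guess at the eventual wiring (not part of this commit) would be:

    return (_op.outputs[0], _op.outputs[1]);
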
diff --git a/src/TensorFlowNET.Core/Operations/rnn_cell_impl.cs b/src/TensorFlowNET.Core/Operations/rnn_cell_impl.cs
deleted file mode 100644
index 72f4b866..00000000
--- a/src/TensorFlowNET.Core/Operations/rnn_cell_impl.cs
+++ /dev/null
@@ -1,8 +0,0 @@
-namespace Tensorflow.Operations
-{
- public class rnn_cell_impl
- {
- public BasicRNNCell BasicRNNCell(int num_units)
- => new BasicRNNCell(num_units);
- }
-}
diff --git a/src/TensorFlowNET.Core/Tensors/Tensor.cs b/src/TensorFlowNET.Core/Tensors/Tensor.cs
index e7049e7e..aebca212 100644
--- a/src/TensorFlowNET.Core/Tensors/Tensor.cs
+++ b/src/TensorFlowNET.Core/Tensors/Tensor.cs
@@ -110,6 +110,11 @@ namespace Tensorflow
this.shape = shape.Dimensions;
}
+ public void set_shape(Tensor shape)
+ {
+ this.shape = shape is null ? null : shape.shape;
+ }
+
public int[] dims => shape;
/// <summary>
diff --git a/src/TensorFlowNET.Core/Tensors/TensorShape.cs b/src/TensorFlowNET.Core/Tensors/TensorShape.cs
index 8559cbd4..c19ecae7 100644
--- a/src/TensorFlowNET.Core/Tensors/TensorShape.cs
+++ b/src/TensorFlowNET.Core/Tensors/TensorShape.cs
@@ -9,6 +9,8 @@ namespace Tensorflow
/// </summary>
public class TensorShape : Shape
{
+ public int[] dims => Dimensions;
+
public TensorShape(TensorShapeProto proto)
{
if (proto.UnknownRank) return;
@@ -45,6 +47,29 @@ namespace Tensorflow
throw new NotImplementedException("TensorShape is_compatible_with");
}
+ public TensorShape with_rank_at_least(int rank)
+ {
+ if (this.NDim < rank)
+ throw new ValueError($"Shape {this} must have rank at least {rank}");
+ else
+ return this;
+ }
+
+ /// <summary>
+ /// Returns the concatenation of the dimensions in `self` and `other`.
+ /// </summary>
+ /// <param name="other_"></param>
+ /// <returns></returns>
+ public TensorShape concatenate(int[] other_)
+ {
+ var other = new TensorShape(other_);
+
+ if (NDim < 0 || other.NDim < 0)
+ return new TensorShape();
+ else
+ {
+ var dims = new int[NDim + other.NDim];
+ Array.Copy(Dimensions, 0, dims, 0, NDim);
+ Array.Copy(other.Dimensions, 0, dims, NDim, other.NDim);
+ return new TensorShape(dims);
+ }
+ }
+
public static implicit operator TensorShape(int[] dims) => new TensorShape(dims);
public static implicit operator TensorShape((int, int) dims) => new TensorShape(dims.Item1, dims.Item2);
public static implicit operator TensorShape((int, int, int) dims) => new TensorShape(dims.Item1, dims.Item2, dims.Item3);
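
A worked example of concatenate as it is used by _create_ta in rnn.cs earlier in this diff (sizes are illustrative):

    var elem = new TensorShape(32).concatenate(new[] { 64 });  // yields a [32, 64] TensorShape
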
diff --git a/src/TensorFlowNET.Core/Util/nest.py.cs b/src/TensorFlowNET.Core/Util/nest.py.cs
index ae13b31c..5f782ba2 100644
--- a/src/TensorFlowNET.Core/Util/nest.py.cs
+++ b/src/TensorFlowNET.Core/Util/nest.py.cs
@@ -223,31 +223,27 @@ namespace Tensorflow.Util
private static void _flatten_recursive<T>(T obj, List<T> list)
{
- if (obj is string)
- {
- list.Add(obj);
- return;
- }
- if (obj is IDictionary)
- {
- var dict = obj as IDictionary;
- foreach (var key in _sorted(dict))
- _flatten_recursive((T)dict[key], list);
- return;
- }
- if (obj is NDArray)
- {
- list.Add(obj);
- return;
- }
- if (obj is IEnumerable)
+
+ switch(obj)
{
- var structure = obj as IEnumerable;
- foreach (var child in structure)
- _flatten_recursive((T)child, list);
- return;
+ case IDictionary dict:
+ foreach (var key in _sorted(dict))
+ _flatten_recursive((T)dict[key], list);
+ break;
+ case String str:
+ list.Add(obj);
+ break;
+ case NDArray nd:
+ list.Add(obj);
+ break;
+ case IEnumerable structure:
+ foreach (var child in structure)
+ _flatten_recursive((T)child, list);
+ break;
+ default:
+ list.Add(obj);
+ break;
}
- list.Add(obj);
}
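
One detail worth keeping in mind in the rewritten dispatch: string implements IEnumerable<char> and NDArray is also enumerable, so their cases must precede the generic IEnumerable case for them to stay leaves. In summary:

    // IDictionary        -> recurse over _sorted(dict) keys
    // string, NDArray    -> added as a single leaf (not iterated element-wise)
    // other IEnumerable  -> recurse over the children
    // anything else      -> added as a leaf
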
diff --git a/src/TensorFlowNET.Core/ops.py.cs b/src/TensorFlowNET.Core/ops.py.cs
index 90cda74e..8f7fce29 100644
--- a/src/TensorFlowNET.Core/ops.py.cs
+++ b/src/TensorFlowNET.Core/ops.py.cs
@@ -314,6 +314,11 @@ namespace Tensorflow
return uid_number++;
}
+ public static void colocate_with(bool ignore_existing = false)
+ {
+ _colocate_with_for_gradient(null, null, ignore_existing);
+ }
+
public static void colocate_with(Operation op, bool ignore_existing = false)
{
_colocate_with_for_gradient(op, null, ignore_existing);
diff --git a/tensorflowlib/README.md b/tensorflowlib/README.md
index 77a78a66..63cba815 100644
--- a/tensorflowlib/README.md
+++ b/tensorflowlib/README.md
@@ -40,7 +40,7 @@ Before running verify you installed CUDA and cuDNN
https://www.tensorflow.org/install/source_windows
-pacman -S git patch unzip
+`pacman -S git patch unzip`
1. Build static library
diff --git a/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionCNN.cs b/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionCNN.cs
index 25ffc46a..2dc355c4 100644
--- a/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionCNN.cs
+++ b/test/TensorFlowNET.Examples/ImageProcessing/DigitRecognitionCNN.cs
@@ -42,7 +42,7 @@ namespace TensorFlowNET.Examples.ImageProcess
int n_channels = 1;
// Hyper-parameters
- int epochs = 10;
+ int epochs = 5; // accuracy > 98%
int batch_size = 100;
float learning_rate = 0.001f;
Datasets<MnistDataSet> mnist;
@@ -84,7 +84,7 @@ namespace TensorFlowNET.Examples.ImageProcess
Test(sess);
});
- return loss_test < 0.09 && accuracy_test > 0.95;
+ return loss_test < 0.05 && accuracy_test > 0.98;
}
public Graph BuildGraph()