| @@ -10,7 +10,6 @@ namespace Tensorflow | |||||
| /// </summary> | /// </summary> | ||||
| public class _ElementFetchMapper : _FetchMapper | public class _ElementFetchMapper : _FetchMapper | ||||
| { | { | ||||
| private List<object> _unique_fetches = new List<object>(); | |||||
| private Func<List<object>, object> _contraction_fn; | private Func<List<object>, object> _contraction_fn; | ||||
| public _ElementFetchMapper(object[] fetches, Func<List<object>, object> contraction_fn) | public _ElementFetchMapper(object[] fetches, Func<List<object>, object> contraction_fn) | ||||
| @@ -32,7 +31,7 @@ namespace Tensorflow | |||||
| /// </summary> | /// </summary> | ||||
| /// <param name="values"></param> | /// <param name="values"></param> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| public NDArray build_results(List<object> values) | |||||
| public override NDArray build_results(List<object> values) | |||||
| { | { | ||||
| NDArray result = null; | NDArray result = null; | ||||
| @@ -51,10 +50,5 @@ namespace Tensorflow | |||||
| return result; | return result; | ||||
| } | } | ||||
| public List<object> unique_fetches() | |||||
| { | |||||
| return _unique_fetches; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -10,7 +10,7 @@ namespace Tensorflow | |||||
| /// </summary> | /// </summary> | ||||
| public class _FetchHandler | public class _FetchHandler | ||||
| { | { | ||||
| private _ElementFetchMapper _fetch_mapper; | |||||
| private _FetchMapper _fetch_mapper; | |||||
| private List<Tensor> _fetches = new List<Tensor>(); | private List<Tensor> _fetches = new List<Tensor>(); | ||||
| private List<bool> _ops = new List<bool>(); | private List<bool> _ops = new List<bool>(); | ||||
| private List<Tensor> _final_fetches = new List<Tensor>(); | private List<Tensor> _final_fetches = new List<Tensor>(); | ||||
| @@ -18,7 +18,7 @@ namespace Tensorflow | |||||
| public _FetchHandler(Graph graph, object fetches, Dictionary<object, object> feeds = null, Action feed_handles = null) | public _FetchHandler(Graph graph, object fetches, Dictionary<object, object> feeds = null, Action feed_handles = null) | ||||
| { | { | ||||
| _fetch_mapper = new _FetchMapper().for_fetch(fetches); | |||||
| _fetch_mapper = _FetchMapper.for_fetch(fetches); | |||||
| foreach(var fetch in _fetch_mapper.unique_fetches()) | foreach(var fetch in _fetch_mapper.unique_fetches()) | ||||
| { | { | ||||
| switch (fetch) | switch (fetch) | ||||
| @@ -58,7 +58,18 @@ namespace Tensorflow | |||||
| { | { | ||||
| var value = tensor_values[j]; | var value = tensor_values[j]; | ||||
| j += 1; | j += 1; | ||||
| full_values.Add(value); | |||||
| switch (value.dtype.Name) | |||||
| { | |||||
| case "Int32": | |||||
| full_values.Add(value.Data<int>(0)); | |||||
| break; | |||||
| case "Single": | |||||
| full_values.Add(value.Data<float>(0)); | |||||
| break; | |||||
| case "Double": | |||||
| full_values.Add(value.Data<double>(0)); | |||||
| break; | |||||
| } | |||||
| } | } | ||||
| i += 1; | i += 1; | ||||
| } | } | ||||
| @@ -1,4 +1,5 @@ | |||||
| using System; | |||||
| using NumSharp.Core; | |||||
| using System; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Text; | using System.Text; | ||||
| @@ -6,14 +7,26 @@ namespace Tensorflow | |||||
| { | { | ||||
| public class _FetchMapper | public class _FetchMapper | ||||
| { | { | ||||
| public _ElementFetchMapper for_fetch(object fetch) | |||||
| protected List<object> _unique_fetches = new List<object>(); | |||||
| public static _FetchMapper for_fetch(object fetch) | |||||
| { | { | ||||
| var fetches = new object[] { fetch }; | |||||
| var fetches = fetch.GetType().IsArray ? (object[])fetch : new object[] { fetch }; | |||||
| if (fetch.GetType().IsArray) | |||||
| return new _ListFetchMapper(fetches); | |||||
| else | |||||
| return new _ElementFetchMapper(fetches, (List<object> fetched_vals) => fetched_vals[0]); | |||||
| } | |||||
| return new _ElementFetchMapper(fetches, (List<object> fetched_vals) => | |||||
| { | |||||
| return fetched_vals[0]; | |||||
| }); | |||||
| public virtual NDArray build_results(List<object> values) | |||||
| { | |||||
| return values.ToArray(); | |||||
| } | |||||
| public virtual List<object> unique_fetches() | |||||
| { | |||||
| return _unique_fetches; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,18 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Linq; | |||||
| using System.Text; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public class _ListFetchMapper : _FetchMapper | |||||
| { | |||||
| private _FetchMapper[] _mappers; | |||||
| public _ListFetchMapper(object[] fetches) | |||||
| { | |||||
| _mappers = fetches.Select(fetch => _FetchMapper.for_fetch(fetch)).ToArray(); | |||||
| _unique_fetches.AddRange(fetches); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -40,11 +40,7 @@ namespace TensorFlowNET.Examples | |||||
| var pred = tf.nn.softmax(tf.matmul(x, W) + b); // Softmax | var pred = tf.nn.softmax(tf.matmul(x, W) + b); // Softmax | ||||
| // Minimize error using cross entropy | // Minimize error using cross entropy | ||||
| var log = tf.log(pred); | |||||
| var mul = y * log; | |||||
| var sum = tf.reduce_sum(mul, reduction_indices: 1); | |||||
| var neg = -sum; | |||||
| var cost = tf.reduce_mean(neg); | |||||
| var cost = tf.reduce_mean(-tf.reduce_sum(y * tf.log(pred), reduction_indices: 1)); | |||||
| // Gradient Descent | // Gradient Descent | ||||
| var optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost); | var optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost); | ||||
| @@ -68,14 +64,23 @@ namespace TensorFlowNET.Examples | |||||
| { | { | ||||
| var (batch_xs, batch_ys) = mnist.train.next_batch(batch_size); | var (batch_xs, batch_ys) = mnist.train.next_batch(batch_size); | ||||
| // Run optimization op (backprop) and cost op (to get loss value) | // Run optimization op (backprop) and cost op (to get loss value) | ||||
| var (_, c) = sess.run(optimizer, | |||||
| var result = sess.run(new object[] { optimizer, cost }, | |||||
| new FeedItem(x, batch_xs), | new FeedItem(x, batch_xs), | ||||
| new FeedItem(y, batch_ys)); | new FeedItem(y, batch_ys)); | ||||
| var c = (float)result[1]; | |||||
| // Compute average loss | // Compute average loss | ||||
| avg_cost += c / total_batch; | avg_cost += c / total_batch; | ||||
| } | } | ||||
| // Display logs per epoch step | |||||
| if ((epoch + 1) % display_step == 0) | |||||
| print($"Epoch: {(epoch + 1).ToString("D4")} cost= {avg_cost.ToString("G9")}"); | |||||
| } | } | ||||
| print("Optimization Finished!"); | |||||
| // Test model | |||||
| }); | }); | ||||
| } | } | ||||
| } | } | ||||
| @@ -52,7 +52,28 @@ namespace TensorFlowNET.Examples.Utility | |||||
| // Finished epoch | // Finished epoch | ||||
| _epochs_completed += 1; | _epochs_completed += 1; | ||||
| throw new NotImplementedException("next_batch"); | |||||
| // Get the rest examples in this epoch | |||||
| var rest_num_examples = _num_examples - start; | |||||
| var images_rest_part = _images[np.arange(start, _num_examples)]; | |||||
| var labels_rest_part = _labels[np.arange(start, _num_examples)]; | |||||
| // Shuffle the data | |||||
| if (shuffle) | |||||
| { | |||||
| var perm = np.arange(_num_examples); | |||||
| np.random.shuffle(perm); | |||||
| _images = images[perm]; | |||||
| _labels = labels[perm]; | |||||
| } | |||||
| start = 0; | |||||
| _index_in_epoch = batch_size - rest_num_examples; | |||||
| var end = _index_in_epoch; | |||||
| var images_new_part = _images[np.arange(start, end)]; | |||||
| var labels_new_part = _labels[np.arange(start, end)]; | |||||
| /*return (np.concatenate(new float[][] { images_rest_part.Data<float>(), images_new_part.Data<float>() }, axis: 0), | |||||
| np.concatenate(new float[][] { labels_rest_part.Data<float>(), labels_new_part.Data<float>() }, axis: 0));*/ | |||||
| return (images_new_part, labels_new_part); | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||