| @@ -10,7 +10,6 @@ namespace Tensorflow | |||
| /// </summary> | |||
| public class _ElementFetchMapper : _FetchMapper | |||
| { | |||
| private List<object> _unique_fetches = new List<object>(); | |||
| private Func<List<object>, object> _contraction_fn; | |||
| public _ElementFetchMapper(object[] fetches, Func<List<object>, object> contraction_fn) | |||
| @@ -32,7 +31,7 @@ namespace Tensorflow | |||
| /// </summary> | |||
| /// <param name="values"></param> | |||
| /// <returns></returns> | |||
| public NDArray build_results(List<object> values) | |||
| public override NDArray build_results(List<object> values) | |||
| { | |||
| NDArray result = null; | |||
| @@ -51,10 +50,5 @@ namespace Tensorflow | |||
| return result; | |||
| } | |||
| public List<object> unique_fetches() | |||
| { | |||
| return _unique_fetches; | |||
| } | |||
| } | |||
| } | |||
| @@ -10,7 +10,7 @@ namespace Tensorflow | |||
| /// </summary> | |||
| public class _FetchHandler | |||
| { | |||
| private _ElementFetchMapper _fetch_mapper; | |||
| private _FetchMapper _fetch_mapper; | |||
| private List<Tensor> _fetches = new List<Tensor>(); | |||
| private List<bool> _ops = new List<bool>(); | |||
| private List<Tensor> _final_fetches = new List<Tensor>(); | |||
| @@ -18,7 +18,7 @@ namespace Tensorflow | |||
| public _FetchHandler(Graph graph, object fetches, Dictionary<object, object> feeds = null, Action feed_handles = null) | |||
| { | |||
| _fetch_mapper = new _FetchMapper().for_fetch(fetches); | |||
| _fetch_mapper = _FetchMapper.for_fetch(fetches); | |||
| foreach(var fetch in _fetch_mapper.unique_fetches()) | |||
| { | |||
| switch (fetch) | |||
| @@ -58,7 +58,18 @@ namespace Tensorflow | |||
| { | |||
| var value = tensor_values[j]; | |||
| j += 1; | |||
| full_values.Add(value); | |||
| switch (value.dtype.Name) | |||
| { | |||
| case "Int32": | |||
| full_values.Add(value.Data<int>(0)); | |||
| break; | |||
| case "Single": | |||
| full_values.Add(value.Data<float>(0)); | |||
| break; | |||
| case "Double": | |||
| full_values.Add(value.Data<double>(0)); | |||
| break; | |||
| } | |||
| } | |||
| i += 1; | |||
| } | |||
| @@ -1,4 +1,5 @@ | |||
| using System; | |||
| using NumSharp.Core; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| @@ -6,14 +7,26 @@ namespace Tensorflow | |||
| { | |||
| public class _FetchMapper | |||
| { | |||
| public _ElementFetchMapper for_fetch(object fetch) | |||
| protected List<object> _unique_fetches = new List<object>(); | |||
| public static _FetchMapper for_fetch(object fetch) | |||
| { | |||
| var fetches = new object[] { fetch }; | |||
| var fetches = fetch.GetType().IsArray ? (object[])fetch : new object[] { fetch }; | |||
| if (fetch.GetType().IsArray) | |||
| return new _ListFetchMapper(fetches); | |||
| else | |||
| return new _ElementFetchMapper(fetches, (List<object> fetched_vals) => fetched_vals[0]); | |||
| } | |||
| return new _ElementFetchMapper(fetches, (List<object> fetched_vals) => | |||
| { | |||
| return fetched_vals[0]; | |||
| }); | |||
| public virtual NDArray build_results(List<object> values) | |||
| { | |||
| return values.ToArray(); | |||
| } | |||
| public virtual List<object> unique_fetches() | |||
| { | |||
| return _unique_fetches; | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,18 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Text; | |||
| namespace Tensorflow | |||
| { | |||
| public class _ListFetchMapper : _FetchMapper | |||
| { | |||
| private _FetchMapper[] _mappers; | |||
| public _ListFetchMapper(object[] fetches) | |||
| { | |||
| _mappers = fetches.Select(fetch => _FetchMapper.for_fetch(fetch)).ToArray(); | |||
| _unique_fetches.AddRange(fetches); | |||
| } | |||
| } | |||
| } | |||
| @@ -40,11 +40,7 @@ namespace TensorFlowNET.Examples | |||
| var pred = tf.nn.softmax(tf.matmul(x, W) + b); // Softmax | |||
| // Minimize error using cross entropy | |||
| var log = tf.log(pred); | |||
| var mul = y * log; | |||
| var sum = tf.reduce_sum(mul, reduction_indices: 1); | |||
| var neg = -sum; | |||
| var cost = tf.reduce_mean(neg); | |||
| var cost = tf.reduce_mean(-tf.reduce_sum(y * tf.log(pred), reduction_indices: 1)); | |||
| // Gradient Descent | |||
| var optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost); | |||
| @@ -68,14 +64,23 @@ namespace TensorFlowNET.Examples | |||
| { | |||
| var (batch_xs, batch_ys) = mnist.train.next_batch(batch_size); | |||
| // Run optimization op (backprop) and cost op (to get loss value) | |||
| var (_, c) = sess.run(optimizer, | |||
| var result = sess.run(new object[] { optimizer, cost }, | |||
| new FeedItem(x, batch_xs), | |||
| new FeedItem(y, batch_ys)); | |||
| var c = (float)result[1]; | |||
| // Compute average loss | |||
| avg_cost += c / total_batch; | |||
| } | |||
| // Display logs per epoch step | |||
| if ((epoch + 1) % display_step == 0) | |||
| print($"Epoch: {(epoch + 1).ToString("D4")} cost= {avg_cost.ToString("G9")}"); | |||
| } | |||
| print("Optimization Finished!"); | |||
| // Test model | |||
| }); | |||
| } | |||
| } | |||
| @@ -52,7 +52,28 @@ namespace TensorFlowNET.Examples.Utility | |||
| // Finished epoch | |||
| _epochs_completed += 1; | |||
| throw new NotImplementedException("next_batch"); | |||
| // Get the rest examples in this epoch | |||
| var rest_num_examples = _num_examples - start; | |||
| var images_rest_part = _images[np.arange(start, _num_examples)]; | |||
| var labels_rest_part = _labels[np.arange(start, _num_examples)]; | |||
| // Shuffle the data | |||
| if (shuffle) | |||
| { | |||
| var perm = np.arange(_num_examples); | |||
| np.random.shuffle(perm); | |||
| _images = images[perm]; | |||
| _labels = labels[perm]; | |||
| } | |||
| start = 0; | |||
| _index_in_epoch = batch_size - rest_num_examples; | |||
| var end = _index_in_epoch; | |||
| var images_new_part = _images[np.arange(start, end)]; | |||
| var labels_new_part = _labels[np.arange(start, end)]; | |||
| /*return (np.concatenate(new float[][] { images_rest_part.Data<float>(), images_new_part.Data<float>() }, axis: 0), | |||
| np.concatenate(new float[][] { labels_rest_part.Data<float>(), labels_new_part.Data<float>() }, axis: 0));*/ | |||
| return (images_new_part, labels_new_part); | |||
| } | |||
| else | |||
| { | |||