| @@ -54,6 +54,34 @@ namespace Tensorflow | |||
| status.Check(true); | |||
| } | |||
| public virtual void run(Operation op, params FeedItem[] feed_dict) | |||
| { | |||
| _run(op, feed_dict); | |||
| } | |||
| public virtual NDArray run(Tensor fetche, params FeedItem[] feed_dict) | |||
| { | |||
| return _run(fetche, feed_dict)[0]; | |||
| } | |||
| public virtual (NDArray, NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) | |||
| { | |||
| var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3, fetches.Item4 }, feed_dict); | |||
| return (results[0], results[1], results[2], results[3]); | |||
| } | |||
| public virtual (NDArray, NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) | |||
| { | |||
| var results = _run(new object[] { fetches.Item1, fetches.Item2, fetches.Item3 }, feed_dict); | |||
| return (results[0], results[1], results[2]); | |||
| } | |||
| public virtual (NDArray, NDArray) run((ITensorOrOperation, ITensorOrOperation) fetches, params FeedItem[] feed_dict) | |||
| { | |||
| var results = _run(new object[] { fetches.Item1, fetches.Item2 }, feed_dict); | |||
| return (results[0], results[1]); | |||
| } | |||
| public virtual NDArray[] run(object fetches, params FeedItem[] feed_dict) | |||
| { | |||
| return _run(fetches, feed_dict); | |||
| @@ -13,5 +13,8 @@ | |||
| Key = key; | |||
| Value = val; | |||
| } | |||
| public static implicit operator FeedItem((object, object) feed) | |||
| => new FeedItem(feed.Item1, feed.Item2); | |||
| } | |||
| } | |||
| @@ -377,7 +377,7 @@ namespace Tensorflow | |||
| "`eval(session=sess)`."); | |||
| } | |||
| return session.run(tensor, feed_dict)[0]; | |||
| return session.run(tensor, feed_dict); | |||
| } | |||
| /// <summary> | |||
| @@ -91,16 +91,16 @@ namespace TensorFlowNET.Examples | |||
| { | |||
| var c = sess.run(cost, | |||
| new FeedItem(X, train_X), | |||
| new FeedItem(Y, train_Y))[0]; | |||
| Console.WriteLine($"Epoch: {epoch + 1} cost={c} " + $"W={sess.run(W)[0]} b={sess.run(b)[0]}"); | |||
| new FeedItem(Y, train_Y)); | |||
| Console.WriteLine($"Epoch: {epoch + 1} cost={c} " + $"W={sess.run(W)} b={sess.run(b)}"); | |||
| } | |||
| } | |||
| Console.WriteLine("Optimization Finished!"); | |||
| var training_cost = sess.run(cost, | |||
| new FeedItem(X, train_X), | |||
| new FeedItem(Y, train_Y))[0]; | |||
| Console.WriteLine($"Training cost={training_cost} W={sess.run(W)[0]} b={sess.run(b)[0]}"); | |||
| new FeedItem(Y, train_Y)); | |||
| Console.WriteLine($"Training cost={training_cost} W={sess.run(W)} b={sess.run(b)}"); | |||
| // Testing example | |||
| var test_X = np.array(6.83f, 4.668f, 8.9f, 7.91f, 5.7f, 8.7f, 3.1f, 2.1f); | |||
| @@ -108,7 +108,7 @@ namespace TensorFlowNET.Examples | |||
| Console.WriteLine("Testing... (Mean square loss Comparison)"); | |||
| var testing_cost = sess.run(tf.reduce_sum(tf.pow(pred - Y, 2.0f)) / (2.0f * test_X.shape[0]), | |||
| new FeedItem(X, test_X), | |||
| new FeedItem(Y, test_Y))[0]; | |||
| new FeedItem(Y, test_Y)); | |||
| Console.WriteLine($"Testing cost={testing_cost}"); | |||
| var diff = Math.Abs((float)training_cost - (float)testing_cost); | |||
| Console.WriteLine($"Absolute mean square loss difference: {diff}"); | |||
| @@ -90,11 +90,10 @@ namespace TensorFlowNET.Examples | |||
| { | |||
| var (batch_xs, batch_ys) = mnist.Train.GetNextBatch(batch_size); | |||
| // Run optimization op (backprop) and cost op (to get loss value) | |||
| var result = sess.run(new object[] { optimizer, cost }, | |||
| new FeedItem(x, batch_xs), | |||
| new FeedItem(y, batch_ys)); | |||
| (_, float c) = sess.run((optimizer, cost), | |||
| (x, batch_xs), | |||
| (y, batch_ys)); | |||
| float c = result[1]; | |||
| // Compute average loss | |||
| avg_cost += c / total_batch; | |||
| } | |||
| @@ -115,7 +114,7 @@ namespace TensorFlowNET.Examples | |||
| var correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1)); | |||
| // Calculate accuracy | |||
| var accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)); | |||
| float acc = accuracy.eval(new FeedItem(x, mnist.Test.Data), new FeedItem(y, mnist.Test.Labels)); | |||
| float acc = accuracy.eval((x, mnist.Test.Data), (y, mnist.Test.Labels)); | |||
| print($"Accuracy: {acc.ToString("F4")}"); | |||
| return acc > 0.9; | |||
| @@ -64,7 +64,7 @@ namespace TensorFlowNET.Examples | |||
| foreach(int i in range(Xte.shape[0])) | |||
| { | |||
| // Get nearest neighbor | |||
| long nn_index = sess.run(pred, new FeedItem(xtr, Xtr), new FeedItem(xte, Xte[i]))[0]; | |||
| long nn_index = sess.run(pred, (xtr, Xtr), (xte, Xte[i])); | |||
| // Get nearest neighbor class label and compare it to its true label | |||
| int index = (int)nn_index; | |||
| @@ -72,7 +72,7 @@ namespace TensorFlowNET.Examples | |||
| print($"Test {i} Prediction: {np.argmax(Ytr[index])} True Class: {np.argmax(Yte[i])}"); | |||
| // Calculate accuracy | |||
| if ((int)np.argmax(Ytr[index]) == (int)np.argmax(Yte[i])) | |||
| if (np.argmax(Ytr[index]) == np.argmax(Yte[i])) | |||
| accuracy += 1f/ Xte.shape[0]; | |||
| } | |||
| @@ -103,10 +103,8 @@ namespace TensorFlowNET.Examples | |||
| // [train_op, gs, loss], | |||
| // feed_dict={features: xy, labels: y_} | |||
| // ) | |||
| var result = sess.run(new ITensorOrOperation[] { train_op, global_step, loss }, new FeedItem(features, data), new FeedItem(labels, y_)); | |||
| loss_value = result[2]; | |||
| step = result[1]; | |||
| if (step % 1000 == 0) | |||
| (_, step, loss_value) = sess.run((train_op, global_step, loss), (features, data), (labels, y_)); | |||
| if (step == 1 || step % 1000 == 0) | |||
| Console.WriteLine($"Step {step} loss: {loss_value}"); | |||
| } | |||
| Console.WriteLine($"Final loss: {loss_value}"); | |||
| @@ -136,10 +134,8 @@ namespace TensorFlowNET.Examples | |||
| var y_ = np.array(new int[] { 1, 0, 0, 1 }, dtype: np.int32); | |||
| while (step < num_steps) | |||
| { | |||
| var result = sess.run(new ITensorOrOperation[] { train_op, gs, loss }, new FeedItem(features, data), new FeedItem(labels, y_)); | |||
| loss_value = result[2]; | |||
| step = result[1]; | |||
| if (step % 1000 == 0) | |||
| (_, step, loss_value) = sess.run((train_op, gs, loss), (features, data), (labels, y_)); | |||
| if (step == 1 || step % 1000 == 0) | |||
| Console.WriteLine($"Step {step} loss: {loss_value}"); | |||
| } | |||
| Console.WriteLine($"Final loss: {loss_value}"); | |||
| @@ -53,8 +53,8 @@ namespace TensorFlowNET.Examples | |||
| new FeedItem(b, (short)3) | |||
| }; | |||
| // Run every operation with variable input | |||
| Console.WriteLine($"Addition with variables: {sess.run(add, feed_dict)[0]}"); | |||
| Console.WriteLine($"Multiplication with variables: {sess.run(mul, feed_dict)[0]}"); | |||
| Console.WriteLine($"Addition with variables: {sess.run(add, feed_dict)}"); | |||
| Console.WriteLine($"Multiplication with variables: {sess.run(mul, feed_dict)}"); | |||
| } | |||
| // ---------------- | |||
| @@ -91,7 +91,7 @@ namespace TensorFlowNET.Examples | |||
| // The output of the op is returned in 'result' as a numpy `ndarray` object. | |||
| using (sess = tf.Session()) | |||
| { | |||
| var result = sess.run(product)[0]; | |||
| var result = sess.run(product); | |||
| Console.WriteLine(result.ToString()); // ==> [[ 12.]] | |||
| }; | |||
| @@ -136,7 +136,7 @@ namespace TensorFlowNET.Examples | |||
| var checkTensor = np.array<float>(0, 6, 0, 15, 0, 24, 3, 1, 6, 4, 9, 7, 6, 0, 15, 0, 24, 0); | |||
| using (var sess = tf.Session()) | |||
| { | |||
| var result = sess.run(batchMul)[0]; | |||
| var result = sess.run(batchMul); | |||
| Console.WriteLine(result.ToString()); | |||
| // | |||
| // ==> array([[[0, 6], | |||
| @@ -28,7 +28,7 @@ namespace TensorFlowNET.Examples | |||
| using (var sess = tf.Session()) | |||
| { | |||
| // Run the op | |||
| var result = sess.run(hello)[0]; | |||
| var result = sess.run(hello); | |||
| Console.WriteLine(result.ToString()); | |||
| return result.ToString().Equals(str); | |||
| } | |||
| @@ -160,7 +160,7 @@ namespace TensorFlowNET.Examples | |||
| var (x_batch, y_batch) = mnist.GetNextBatch(x_train, y_train, start, end); | |||
| // Run optimization op (backprop) | |||
| sess.run(optimizer, new FeedItem(x, x_batch), new FeedItem(y, y_batch)); | |||
| sess.run(optimizer, (x, x_batch), (y, y_batch)); | |||
| if (iteration % display_freq == 0) | |||
| { | |||
| @@ -174,9 +174,7 @@ namespace TensorFlowNET.Examples | |||
| } | |||
| // Run validation after every epoch | |||
| var results1 = sess.run(new[] { loss, accuracy }, new FeedItem(x, x_valid), new FeedItem(y, y_valid)); | |||
| loss_val = results1[0]; | |||
| accuracy_val = results1[1]; | |||
| (loss_val, accuracy_val) = sess.run((loss, accuracy), (x, x_valid), (y, y_valid)); | |||
| print("---------------------------------------------------------"); | |||
| print($"Epoch: {epoch + 1}, validation loss: {loss_val.ToString("0.0000")}, validation accuracy: {accuracy_val.ToString("P")}"); | |||
| print("---------------------------------------------------------"); | |||
| @@ -185,9 +183,7 @@ namespace TensorFlowNET.Examples | |||
| public void Test(Session sess) | |||
| { | |||
| var result = sess.run(new[] { loss, accuracy }, new FeedItem(x, x_test), new FeedItem(y, y_test)); | |||
| loss_test = result[0]; | |||
| accuracy_test = result[1]; | |||
| (loss_test, accuracy_test) = sess.run((loss, accuracy), (x, x_test), (y, y_test)); | |||
| print("---------------------------------------------------------"); | |||
| print($"Test loss: {loss_test.ToString("0.0000")}, test accuracy: {accuracy_test.ToString("P")}"); | |||
| print("---------------------------------------------------------"); | |||
| @@ -148,23 +148,18 @@ namespace TensorFlowNET.Examples | |||
| var (x_batch, y_batch) = mnist.GetNextBatch(x_train, y_train, start, end); | |||
| // Run optimization op (backprop) | |||
| sess.run(optimizer, new FeedItem(x, x_batch), new FeedItem(y, y_batch)); | |||
| sess.run(optimizer, (x, x_batch), (y, y_batch)); | |||
| if (iteration % display_freq == 0) | |||
| { | |||
| // Calculate and display the batch loss and accuracy | |||
| var result = sess.run(new[] { loss, accuracy }, new FeedItem(x, x_batch), new FeedItem(y, y_batch)); | |||
| loss_val = result[0]; | |||
| accuracy_val = result[1]; | |||
| (loss_val, accuracy_val) = sess.run((loss, accuracy), (x, x_batch), (y, y_batch)); | |||
| print($"iter {iteration.ToString("000")}: Loss={loss_val.ToString("0.0000")}, Training Accuracy={accuracy_val.ToString("P")}"); | |||
| } | |||
| } | |||
| // Run validation after every epoch | |||
| var results1 = sess.run(new[] { loss, accuracy }, new FeedItem(x, mnist.Validation.Data), new FeedItem(y, mnist.Validation.Labels)); | |||
| loss_val = results1[0]; | |||
| accuracy_val = results1[1]; | |||
| (loss_val, accuracy_val) = sess.run((loss, accuracy), (x, mnist.Validation.Data), (y, mnist.Validation.Labels)); | |||
| print("---------------------------------------------------------"); | |||
| print($"Epoch: {epoch + 1}, validation loss: {loss_val.ToString("0.0000")}, validation accuracy: {accuracy_val.ToString("P")}"); | |||
| print("---------------------------------------------------------"); | |||
| @@ -173,9 +168,7 @@ namespace TensorFlowNET.Examples | |||
| public void Test(Session sess) | |||
| { | |||
| var result = sess.run(new[] { loss, accuracy }, new FeedItem(x, mnist.Test.Data), new FeedItem(y, mnist.Test.Labels)); | |||
| loss_test = result[0]; | |||
| accuracy_test = result[1]; | |||
| (loss_test, accuracy_test) = sess.run((loss, accuracy), (x, mnist.Test.Data), (y, mnist.Test.Labels)); | |||
| print("---------------------------------------------------------"); | |||
| print($"Test loss: {loss_test.ToString("0.0000")}, test accuracy: {accuracy_test.ToString("P")}"); | |||
| print("---------------------------------------------------------"); | |||
| @@ -51,7 +51,7 @@ namespace TensorFlowNET.Examples | |||
| { | |||
| sw.Restart(); | |||
| var results = sess.run(output_operation.outputs[0], new FeedItem(input_operation.outputs[0], nd))[0]; | |||
| var results = sess.run(output_operation.outputs[0], (input_operation.outputs[0], nd)); | |||
| results = np.squeeze(results); | |||
| int idx = np.argmax(results); | |||
| @@ -81,7 +81,7 @@ namespace TensorFlowNET.Examples | |||
| var normalized = tf.divide(sub, new float[] { input_std }); | |||
| using (var sess = tf.Session(graph)) | |||
| return sess.run(normalized)[0]; | |||
| return sess.run(normalized); | |||
| } | |||
| public void PrepareData() | |||
| @@ -108,7 +108,7 @@ namespace TensorFlowNET.Examples | |||
| var dims_expander = tf.expand_dims(casted, 0); | |||
| using (var sess = tf.Session(graph)) | |||
| return sess.run(dims_expander)[0]; | |||
| return sess.run(dims_expander); | |||
| } | |||
| private void buildOutputImage(NDArray[] resultArr) | |||
| @@ -124,13 +124,13 @@ namespace TensorFlowNET.Examples | |||
| var (eval_session, _, bottleneck_input, ground_truth_input, evaluation_step, | |||
| prediction) = build_eval_session(class_count); | |||
| var results = eval_session.run(new Tensor[] { evaluation_step, prediction }, | |||
| new FeedItem(bottleneck_input, test_bottlenecks), | |||
| new FeedItem(ground_truth_input, test_ground_truth)); | |||
| (float accuracy, NDArray prediction1) = eval_session.run((evaluation_step, prediction), | |||
| (bottleneck_input, test_bottlenecks), | |||
| (ground_truth_input, test_ground_truth)); | |||
| print($"final test accuracy: {((float)results[0] * 100).ToString("G4")}% (N={len(test_bottlenecks)})"); | |||
| print($"final test accuracy: {(accuracy * 100).ToString("G4")}% (N={len(test_bottlenecks)})"); | |||
| return (results[0], results[1]); | |||
| return (accuracy, prediction1); | |||
| } | |||
| private (Session, Tensor, Tensor, Tensor, Tensor, Tensor) | |||
| @@ -661,11 +661,9 @@ namespace TensorFlowNET.Examples | |||
| bool is_last_step = (i + 1 == how_many_training_steps); | |||
| if ((i % eval_step_interval) == 0 || is_last_step) | |||
| { | |||
| results = sess.run( | |||
| new Tensor[] { evaluation_step, cross_entropy }, | |||
| new FeedItem(bottleneck_input, train_bottlenecks), | |||
| new FeedItem(ground_truth_input, train_ground_truth)); | |||
| (float train_accuracy, float cross_entropy_value) = (results[0], results[1]); | |||
| (float train_accuracy, float cross_entropy_value) = sess.run((evaluation_step, cross_entropy), | |||
| (bottleneck_input, train_bottlenecks), | |||
| (ground_truth_input, train_ground_truth)); | |||
| print($"{DateTime.Now}: Step {i + 1}: Train accuracy = {train_accuracy * 100}%, Cross entropy = {cross_entropy_value.ToString("G4")}"); | |||
| var (validation_bottlenecks, validation_ground_truth, _) = get_random_cached_bottlenecks( | |||
| @@ -676,12 +674,10 @@ namespace TensorFlowNET.Examples | |||
| // Run a validation step and capture training summaries for TensorBoard | |||
| // with the `merged` op. | |||
| results = sess.run(new Tensor[] { merged, evaluation_step }, | |||
| new FeedItem(bottleneck_input, validation_bottlenecks), | |||
| new FeedItem(ground_truth_input, validation_ground_truth)); | |||
| (_, float validation_accuracy) = sess.run((merged, evaluation_step), | |||
| (bottleneck_input, validation_bottlenecks), | |||
| (ground_truth_input, validation_ground_truth)); | |||
| //(string validation_summary, float validation_accuracy) = (results[0], results[1]); | |||
| float validation_accuracy = results[1]; | |||
| // validation_writer.add_summary(validation_summary, i); | |||
| print($"{DateTime.Now}: Step {i + 1}: Validation accuracy = {validation_accuracy * 100}% (N={len(validation_bottlenecks)}) {sw.ElapsedMilliseconds}ms"); | |||
| sw.Restart(); | |||
| @@ -741,10 +737,10 @@ namespace TensorFlowNET.Examples | |||
| using (var sess = tf.Session(graph)) | |||
| { | |||
| var result = sess.run(output, new FeedItem(input, fileBytes)); | |||
| var result = sess.run(output, (input, fileBytes)); | |||
| var prob = np.squeeze(result); | |||
| var idx = np.argmax(prob); | |||
| print($"Prediction result: [{labels[idx]} {prob[idx][0]}] for {img_path}."); | |||
| print($"Prediction result: [{labels[idx]} {prob[idx]}] for {img_path}."); | |||
| } | |||
| } | |||
| @@ -213,22 +213,13 @@ namespace TensorFlowNET.Examples | |||
| Tensor global_step = graph.OperationByName("Variable"); | |||
| Tensor accuracy = graph.OperationByName("accuracy/accuracy"); | |||
| stopwatch = Stopwatch.StartNew(); | |||
| int i = 0; | |||
| int step = 0; | |||
| foreach (var (x_batch, y_batch, total) in train_batches) | |||
| { | |||
| i++; | |||
| var train_feed_dict = new FeedDict | |||
| { | |||
| [model_x] = x_batch, | |||
| [model_y] = y_batch, | |||
| [is_training] = true, | |||
| }; | |||
| var result = sess.run(new ITensorOrOperation[] { optimizer, global_step, loss }, train_feed_dict); | |||
| loss_value = result[2]; | |||
| var step = (int)result[1]; | |||
| if (step % 10 == 0) | |||
| Console.WriteLine($"Training on batch {i}/{total} loss: {loss_value.ToString("0.0000")}."); | |||
| (_, step, loss_value) = sess.run((optimizer, global_step, loss), | |||
| (model_x, x_batch), (model_y, y_batch), (is_training, true)); | |||
| if (step == 1 || step % 10 == 0) | |||
| Console.WriteLine($"Training on batch {step}/{total} loss: {loss_value.ToString("0.0000")}."); | |||
| if (step % 100 == 0) | |||
| { | |||
| @@ -243,8 +234,7 @@ namespace TensorFlowNET.Examples | |||
| [model_y] = valid_y_batch, | |||
| [is_training] = false | |||
| }; | |||
| var result1 = sess.run(accuracy, valid_feed_dict); | |||
| float accuracy_value = result1[0]; | |||
| float accuracy_value = sess.run(accuracy, (model_x, valid_x_batch), (model_y, valid_y_batch), (is_training, false)); | |||
| sum_accuracy += accuracy_value; | |||
| cnt += 1; | |||
| } | |||
| @@ -80,17 +80,16 @@ namespace TensorFlowNET.Examples.Text.NER | |||
| private float run_epoch(Session sess, CoNLLDataset train, CoNLLDataset dev, int epoch) | |||
| { | |||
| NDArray[] results = null; | |||
| float accuracy = 0; | |||
| // iterate over dataset | |||
| var batches = minibatches(train, hp.batch_size); | |||
| foreach (var(words, labels) in batches) | |||
| { | |||
| var (fd, _) = get_feed_dict(words, labels, hp.lr, hp.dropout); | |||
| results = sess.run(new ITensorOrOperation[] { train_op, loss }, feed_dict: fd); | |||
| (_, accuracy) = sess.run((train_op, loss), feed_dict: fd); | |||
| } | |||
| return results[1]; | |||
| return accuracy; | |||
| } | |||
| private IEnumerable<((int[][], int[])[], int[][])> minibatches(CoNLLDataset data, int minibatch_size) | |||
| @@ -81,8 +81,8 @@ namespace TensorFlowNET.Examples | |||
| // Get a new batch of data | |||
| var (batch_x, batch_y) = next_batch(batch_size, num_skips, skip_window); | |||
| var result = sess.run(new ITensorOrOperation[] { train_op, loss_op }, new FeedItem(X, batch_x), new FeedItem(Y, batch_y)); | |||
| average_loss += result[1]; | |||
| (_, float loss) = sess.run((train_op, loss_op), (X, batch_x), (Y, batch_y)); | |||
| average_loss += loss; | |||
| if (step % display_step == 0 || step == 1) | |||
| { | |||
| @@ -97,7 +97,7 @@ namespace TensorFlowNET.Examples | |||
| if (step % eval_step == 0 || step == 1) | |||
| { | |||
| print("Evaluation..."); | |||
| var sim = sess.run(cosine_sim_op, new FeedItem(X, x_test))[0]; | |||
| var sim = sess.run(cosine_sim_op, (X, x_test)); | |||
| foreach(var i in range(len(eval_words))) | |||
| { | |||
| var nearest = (0f - sim[i]).argsort<float>() | |||