| @@ -18,7 +18,7 @@ namespace TensorFlowNET.Examples | |||||
| public string Name => "NN XOR"; | public string Name => "NN XOR"; | ||||
| public bool ImportGraph { get; set; } = false; | public bool ImportGraph { get; set; } = false; | ||||
| public int num_steps = 5000; | |||||
| public int num_steps = 10000; | |||||
| private NDArray data; | private NDArray data; | ||||
| @@ -55,7 +55,7 @@ namespace TensorFlowNET.Examples | |||||
| if (ImportGraph) | if (ImportGraph) | ||||
| loss_value = RunWithImportedGraph(); | loss_value = RunWithImportedGraph(); | ||||
| else | else | ||||
| loss_value=RunWithBuiltGraph(); | |||||
| loss_value = RunWithBuiltGraph(); | |||||
| return loss_value < 0.0628; | return loss_value < 0.0628; | ||||
| } | } | ||||
| @@ -96,6 +96,7 @@ namespace TensorFlowNET.Examples | |||||
| } | } | ||||
| Console.WriteLine($"Final loss: {loss_value}"); | Console.WriteLine($"Final loss: {loss_value}"); | ||||
| }); | }); | ||||
| return loss_value; | return loss_value; | ||||
| } | } | ||||
| @@ -120,11 +121,6 @@ namespace TensorFlowNET.Examples | |||||
| var y_ = np.array(new int[] { 1, 0, 0, 1 }, dtype: np.int32); | var y_ = np.array(new int[] { 1, 0, 0, 1 }, dtype: np.int32); | ||||
| while (step < num_steps) | while (step < num_steps) | ||||
| { | { | ||||
| // original python: | |||||
| //_, step, loss_value = sess.run( | |||||
| // [train_op, gs, loss], | |||||
| // feed_dict={features: xy, labels: y_} | |||||
| // ) | |||||
| var result = sess.run(new ITensorOrOperation[] { train_op, gs, loss }, new FeedItem(features, data), new FeedItem(labels, y_)); | var result = sess.run(new ITensorOrOperation[] { train_op, gs, loss }, new FeedItem(features, data), new FeedItem(labels, y_)); | ||||
| loss_value = result[2]; | loss_value = result[2]; | ||||
| //step = result[1]; | //step = result[1]; | ||||
| @@ -134,6 +130,7 @@ namespace TensorFlowNET.Examples | |||||
| } | } | ||||
| Console.WriteLine($"Final loss: {loss_value}"); | Console.WriteLine($"Final loss: {loss_value}"); | ||||
| }); | }); | ||||
| return loss_value; | return loss_value; | ||||
| } | } | ||||