|
|
|
@@ -22,8 +22,9 @@ namespace TensorFlowNET.Examples |
|
|
|
|
|
|
|
private float learning_rate = 0.01f; |
|
|
|
public int TrainingEpochs = 10; |
|
|
|
public int DataSize = 5000; |
|
|
|
public int TestSize = 5000; |
|
|
|
public int? TrainSize = null; |
|
|
|
public int ValidationSize = 5000; |
|
|
|
public int? TestSize = null; |
|
|
|
public int BatchSize = 100; |
|
|
|
private int display_step = 1; |
|
|
|
|
|
|
|
@@ -98,7 +99,7 @@ namespace TensorFlowNET.Examples |
|
|
|
|
|
|
|
public void PrepareData() |
|
|
|
{ |
|
|
|
mnist = MnistDataSet.read_data_sets("mnist", one_hot: true, validation_size: DataSize, test_size: TestSize); |
|
|
|
mnist = MnistDataSet.read_data_sets("mnist", one_hot: true, train_size: TrainSize, validation_size: ValidationSize, test_size: TestSize); |
|
|
|
} |
|
|
|
|
|
|
|
public void SaveModel(Session sess) |
|
|
|
@@ -141,7 +142,7 @@ namespace TensorFlowNET.Examples |
|
|
|
if (results.argmax() == (batch_ys[0] as NDArray).argmax()) |
|
|
|
print("predicted OK!"); |
|
|
|
else |
|
|
|
throw new ValueError("predict error, maybe 90% accuracy"); |
|
|
|
throw new ValueError("predict error, should be 90% accuracy"); |
|
|
|
}); |
|
|
|
} |
|
|
|
} |
|
|
|
|