| @@ -1,8 +0,0 @@ | |||||
| using System; | |||||
| namespace TensorFlowHub | |||||
| { | |||||
| public class Class1 | |||||
| { | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,13 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using NumSharp; | |||||
| namespace Tensorflow.Hub | |||||
| { | |||||
| public abstract class DataSetBase : IDataSet | |||||
| { | |||||
| public NDArray Data { get; protected set; } | |||||
| public NDArray Labels { get; protected set; } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,46 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using NumSharp; | |||||
| namespace Tensorflow.Hub | |||||
| { | |||||
| public class Datasets<TDataSet> where TDataSet : IDataSet | |||||
| { | |||||
| public TDataSet Train { get; private set; } | |||||
| public TDataSet Validation { get; private set; } | |||||
| public TDataSet Test { get; private set; } | |||||
| public Datasets(TDataSet train, TDataSet validation, TDataSet test) | |||||
| { | |||||
| Train = train; | |||||
| Validation = validation; | |||||
| Test = test; | |||||
| } | |||||
| public (NDArray, NDArray) Randomize(NDArray x, NDArray y) | |||||
| { | |||||
| var perm = np.random.permutation(y.shape[0]); | |||||
| np.random.shuffle(perm); | |||||
| return (x[perm], y[perm]); | |||||
| } | |||||
| /// <summary> | |||||
| /// selects a few number of images determined by the batch_size variable (if you don't know why, read about Stochastic Gradient Method) | |||||
| /// </summary> | |||||
| /// <param name="x"></param> | |||||
| /// <param name="y"></param> | |||||
| /// <param name="start"></param> | |||||
| /// <param name="end"></param> | |||||
| /// <returns></returns> | |||||
| public (NDArray, NDArray) GetNextBatch(NDArray x, NDArray y, int start, int end) | |||||
| { | |||||
| var slice = new Slice(start, end); | |||||
| var x_batch = x[slice]; | |||||
| var y_batch = y[slice]; | |||||
| return (x_batch, y_batch); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,13 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using NumSharp; | |||||
| namespace Tensorflow.Hub | |||||
| { | |||||
| public interface IDataSet | |||||
| { | |||||
| NDArray Data { get; } | |||||
| NDArray Labels { get; } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,14 @@ | |||||
| using System; | |||||
| using System.Threading.Tasks; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using NumSharp; | |||||
| namespace Tensorflow.Hub | |||||
| { | |||||
| public interface IModelLoader<TDataSet> | |||||
| where TDataSet : IDataSet | |||||
| { | |||||
| Task<Datasets<TDataSet>> LoadAsync(ModelLoadSetting setting); | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,31 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using NumSharp; | |||||
| using Tensorflow; | |||||
| namespace Tensorflow.Hub | |||||
| { | |||||
| public class MnistDataSet : DataSetBase | |||||
| { | |||||
| public int NumOfExamples { get; private set; } | |||||
| public int EpochsCompleted { get; private set; } | |||||
| public int IndexInEpoch { get; private set; } | |||||
| public MnistDataSet(NDArray images, NDArray labels, TF_DataType dtype, bool reshape) | |||||
| { | |||||
| EpochsCompleted = 0; | |||||
| IndexInEpoch = 0; | |||||
| NumOfExamples = images.shape[0]; | |||||
| images = images.reshape(images.shape[0], images.shape[1] * images.shape[2]); | |||||
| images.astype(dtype.as_numpy_datatype()); | |||||
| images = np.multiply(images, 1.0f / 255.0f); | |||||
| Data = images; | |||||
| labels.astype(dtype.as_numpy_datatype()); | |||||
| Labels = labels; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,16 @@ | |||||
| using System; | |||||
| using System.Threading.Tasks; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using NumSharp; | |||||
| namespace Tensorflow.Hub | |||||
| { | |||||
| public class MnistModelLoader : IModelLoader<MnistDataSet> | |||||
| { | |||||
| public Task<Datasets<MnistDataSet>> LoadAsync(ModelLoadSetting setting) | |||||
| { | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,19 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using NumSharp; | |||||
| namespace Tensorflow.Hub | |||||
| { | |||||
| public class ModelLoadSetting | |||||
| { | |||||
| public string TrainDir { get; set; } | |||||
| public bool OneHot { get; set; } | |||||
| public TF_DataType DtType { get; set; } = TF_DataType.TF_FLOAT; | |||||
| public bool ReShape { get; set; } | |||||
| public int ValidationSize { get; set; } = 5000; | |||||
| public int? TrainSize { get; set; } | |||||
| public int? TestSize { get; set; } | |||||
| public string SourceUrl { get; set; } | |||||
| } | |||||
| } | |||||
| @@ -1,7 +1,13 @@ | |||||
| <Project Sdk="Microsoft.NET.Sdk"> | <Project Sdk="Microsoft.NET.Sdk"> | ||||
| <PropertyGroup> | <PropertyGroup> | ||||
| <AssemblyName>TensorFlow.Net.Hub</AssemblyName> | |||||
| <RootNamespace>Tensorflow.Hub</RootNamespace> | |||||
| <TargetFramework>netstandard2.0</TargetFramework> | <TargetFramework>netstandard2.0</TargetFramework> | ||||
| </PropertyGroup> | </PropertyGroup> | ||||
| <ItemGroup> | |||||
| <ProjectReference Include="..\..\src\TensorFlowNET.Core\TensorFlowNET.Core.csproj" /> | |||||
| </ItemGroup> | |||||
| <ItemGroup> | |||||
| <PackageReference Include="NumSharp" Version="0.10.4" /> | |||||
| </ItemGroup> | |||||
| </Project> | </Project> | ||||