diff --git a/.gitignore b/.gitignore
index e4aba715..eee1dc7b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -62,7 +62,6 @@ StyleCopReport.xml
*_p.c
*_i.h
*.ilk
-*.meta
*.obj
*.iobj
*.pch
diff --git a/README.md b/README.md
index 3ac68eed..a2a205fd 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,4 @@
# TensorFlow.NET
-TensorFlow.NET provides a .NET Standard binding for [TensorFlow](https://www.tensorflow.org/). It aims to implement the complete Tensorflow API in CSharp which allows .NET developers to develop, train and deploy Machine Learning models with the cross-platform .NET Standard framework.
TensorFlow.NET (TF.NET) provides a .NET Standard binding for [TensorFlow](https://www.tensorflow.org/). It aims to implement the complete Tensorflow API in CSharp which allows .NET developers to develop, train and deploy Machine Learning models with the cross-platform .NET Standard framework.
[](https://gitter.im/sci-sharp/community)
diff --git a/TensorFlow.NET.sln b/TensorFlow.NET.sln
index b96f8203..bef96f5e 100644
--- a/TensorFlow.NET.sln
+++ b/TensorFlow.NET.sln
@@ -9,6 +9,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Examples", "t
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "TensorFlowNET.Core", "src\TensorFlowNET.Core\TensorFlowNET.Core.csproj", "{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}"
EndProject
+Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "NumSharp.Core", "..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj", "{DE97EAD7-B92C-4112-9690-91C40A97179E}"
+EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
@@ -27,6 +29,10 @@ Global
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Debug|Any CPU.Build.0 = Debug|Any CPU
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|Any CPU.ActiveCfg = Release|Any CPU
{FD682AC0-7B2D-45D3-8B0D-C6D678B04144}.Release|Any CPU.Build.0 = Release|Any CPU
+ {DE97EAD7-B92C-4112-9690-91C40A97179E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {DE97EAD7-B92C-4112-9690-91C40A97179E}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {DE97EAD7-B92C-4112-9690-91C40A97179E}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {DE97EAD7-B92C-4112-9690-91C40A97179E}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
diff --git a/graph/README.md b/graph/README.md
new file mode 100644
index 00000000..e69de29b
diff --git a/graph/cond_test.meta b/graph/cond_test.meta
new file mode 100644
index 00000000..2110d577
Binary files /dev/null and b/graph/cond_test.meta differ
diff --git a/graph/kmeans.meta b/graph/kmeans.meta
new file mode 100644
index 00000000..0ad4f03f
Binary files /dev/null and b/graph/kmeans.meta differ
diff --git a/src/TensorFlowNET.Core/Framework/op_def_registry.py.cs b/src/TensorFlowNET.Core/Framework/op_def_registry.py.cs
index 11a89af1..270b9056 100644
--- a/src/TensorFlowNET.Core/Framework/op_def_registry.py.cs
+++ b/src/TensorFlowNET.Core/Framework/op_def_registry.py.cs
@@ -20,32 +20,9 @@ namespace Tensorflow
foreach (var op_def in op_list.Op)
_registered_ops[op_def.Name] = op_def;
-
- if (!_registered_ops.ContainsKey("NearestNeighbors"))
- _registered_ops["NearestNeighbors"] = op_NearestNeighbors();
}
return _registered_ops;
}
-
- ///
- /// Doesn't work because the op can't be found on binary
- ///
- ///
- private static OpDef op_NearestNeighbors()
- {
- var def = new OpDef
- {
- Name = "NearestNeighbors"
- };
-
- def.InputArg.Add(new ArgDef { Name = "points", Type = DataType.DtFloat });
- def.InputArg.Add(new ArgDef { Name = "centers", Type = DataType.DtFloat });
- def.InputArg.Add(new ArgDef { Name = "k", Type = DataType.DtInt64 });
- def.OutputArg.Add(new ArgDef { Name = "nearest_center_indices", Type = DataType.DtInt64 });
- def.OutputArg.Add(new ArgDef { Name = "nearest_center_distances", Type = DataType.DtFloat });
-
- return def;
- }
}
}
diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs
index 254df0cf..40238dce 100644
--- a/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs
+++ b/src/TensorFlowNET.Core/Operations/ControlFlows/CondContext.cs
@@ -335,7 +335,7 @@ namespace Tensorflow.Operations
ret.Enter();
foreach (var nested_def in proto.NestedContexts)
- throw new NotImplementedException("");
+ from_control_flow_context_def(nested_def, import_scope: import_scope);
ret.Exit();
return ret;
}
diff --git a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs
index 48a519db..a36602f7 100644
--- a/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs
+++ b/src/TensorFlowNET.Core/Operations/ControlFlows/ControlFlowContext.cs
@@ -3,7 +3,8 @@ using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Operations.ControlFlows;
-
+using static Tensorflow.ControlFlowContextDef;
+
namespace Tensorflow.Operations
{
///
@@ -184,6 +185,23 @@ namespace Tensorflow.Operations
return null;
}
+ ///
+ /// Deserializes `context_def` into the appropriate ControlFlowContext.
+ ///
+ /// ControlFlowContextDef proto
+ /// Name scope to add
+ /// A ControlFlowContext subclass
+ protected ControlFlowContext from_control_flow_context_def(ControlFlowContextDef context_def, string import_scope = "")
+ {
+ switch (context_def.CtxtCase)
+ {
+ case CtxtOneofCase.CondCtxt:
+ return new CondContext().from_proto(context_def.CondCtxt, import_scope: import_scope);
+ }
+
+ throw new NotImplementedException($"Unknown ControlFlowContextDef field: {context_def.CtxtCase}");
+ }
+
public object to_proto()
{
throw new NotImplementedException();
diff --git a/src/TensorFlowNET.Core/Operations/Operation.Implicit.cs b/src/TensorFlowNET.Core/Operations/Operation.Implicit.cs
new file mode 100644
index 00000000..013c193c
--- /dev/null
+++ b/src/TensorFlowNET.Core/Operations/Operation.Implicit.cs
@@ -0,0 +1,34 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow
+{
+ ///
+ /// Convert to other datatype implicitly
+ ///
+ public partial class Operation
+ {
+ public static implicit operator Operation(IntPtr handle) => new Operation(handle);
+ public static implicit operator IntPtr(Operation op) => op._handle;
+ public static implicit operator Tensor(Operation op) => op.output;
+
+ public override string ToString()
+ {
+ return _handle == IntPtr.Zero ? "tf.Operation Undefined" : $"tf.Operation '{name}' type={OpType}";
+ }
+
+ public override bool Equals(object obj)
+ {
+ switch (obj)
+ {
+ case IntPtr val:
+ return val == _handle;
+ case Operation val:
+ return val._handle == _handle;
+ }
+
+ return base.Equals(obj);
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Operations/Operation.cs b/src/TensorFlowNET.Core/Operations/Operation.cs
index 5915c216..bbb62be2 100644
--- a/src/TensorFlowNET.Core/Operations/Operation.cs
+++ b/src/TensorFlowNET.Core/Operations/Operation.cs
@@ -248,27 +248,6 @@ namespace Tensorflow
s.Check();
return NodeDef.Parser.ParseFrom(buffer);
}
- }
-
- public override string ToString()
- {
- return _handle == IntPtr.Zero ? "tf.Operation Undefined" : $"tf.Operation '{name}' type={OpType}";
- }
-
- public static implicit operator Operation(IntPtr handle) => new Operation(handle);
- public static implicit operator IntPtr(Operation op) => op._handle;
-
- public override bool Equals(object obj)
- {
- switch (obj)
- {
- case IntPtr val:
- return val == _handle;
- case Operation val:
- return val._handle == _handle;
- }
-
- return base.Equals(obj);
}
///
diff --git a/src/TensorFlowNET.Core/Python.cs b/src/TensorFlowNET.Core/Python.cs
index d216dae3..e183ef43 100644
--- a/src/TensorFlowNET.Core/Python.cs
+++ b/src/TensorFlowNET.Core/Python.cs
@@ -27,6 +27,11 @@ namespace Tensorflow
return Enumerable.Range(0, end);
}
+ protected IEnumerable range(int start, int end)
+ {
+ return Enumerable.Range(start, end);
+ }
+
public static T New(object args) where T : IPyClass
{
var instance = Activator.CreateInstance();
diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs
index 06862cdb..3c0c7486 100644
--- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs
+++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs
@@ -204,6 +204,12 @@ namespace Tensorflow
switch (tensor.dtype)
{
+ case TF_DataType.TF_BOOL:
+ var bools = new bool[tensor.size];
+ for (ulong i = 0; i < tensor.size; i++)
+ bools[i] = *(bool*)(offset + (int)(tensor.itemsize * i));
+ nd = np.array(bools).reshape(ndims);
+ break;
case TF_DataType.TF_STRING:
var bytes = tensor.Data();
// wired, don't know why we have to start from offset 9.
@@ -211,12 +217,6 @@ namespace Tensorflow
var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]);
nd = np.array(str).reshape();
break;
- case TF_DataType.TF_UINT8:
- var _bytes = new byte[tensor.size];
- for (ulong i = 0; i < tensor.size; i++)
- _bytes[i] = *(byte*)(offset + (int)(tensor.itemsize * i));
- nd = np.array(_bytes).reshape(ndims);
- break;
case TF_DataType.TF_INT16:
var shorts = new short[tensor.size];
for (ulong i = 0; i < tensor.size; i++)
diff --git a/src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs b/src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs
index e06f9f2e..e18eebc0 100644
--- a/src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs
+++ b/src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs
@@ -10,9 +10,9 @@ namespace Tensorflow
///
public class _ElementFetchMapper : _FetchMapper
{
- private Func, object> _contraction_fn;
+ private Func, object> _contraction_fn;
- public _ElementFetchMapper(object[] fetches, Func, object> contraction_fn)
+ public _ElementFetchMapper(object[] fetches, Func, object> contraction_fn)
{
var g = ops.get_default_graph();
ITensorOrOperation el = null;
@@ -31,7 +31,7 @@ namespace Tensorflow
///
///
///
- public override NDArray build_results(List