diff --git a/src/TensorFlowNET.Core/Contrib/Learn/Preprocessing/VocabularyProcessor.cs b/src/TensorFlowNET.Core/APIs/tf.config.cs
similarity index 62%
rename from src/TensorFlowNET.Core/Contrib/Learn/Preprocessing/VocabularyProcessor.cs
rename to src/TensorFlowNET.Core/APIs/tf.config.cs
index 024190df..3c30ffb4 100644
--- a/src/TensorFlowNET.Core/Contrib/Learn/Preprocessing/VocabularyProcessor.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.config.cs
@@ -14,18 +14,19 @@
limitations under the License.
******************************************************************************/
-namespace Tensorflow.Contrib.Learn.Preprocessing
+using Tensorflow.Contexts;
+using Tensorflow.Framework;
+
+namespace Tensorflow
{
- public class VocabularyProcessor
+ public partial class tensorflow
{
- private int _max_document_length;
- private int _min_frequency;
-
- public VocabularyProcessor(int max_document_length,
- int min_frequency)
- {
- _max_document_length = max_document_length;
- _min_frequency = min_frequency;
- }
+ ///
+ /// Public API for tf.debugging namespace
+ /// https://www.tensorflow.org/api_docs/python/tf/debugging
+ /// More debugging instructions
+ /// https://developer.ibm.com/technologies/artificial-intelligence/tutorials/debug-tensorflow/
+ ///
+ public ConfigImpl config => new ConfigImpl();
}
}
diff --git a/src/TensorFlowNET.Core/APIs/tf.debugging.cs b/src/TensorFlowNET.Core/APIs/tf.debugging.cs
index 1a9c7b46..9d129b20 100644
--- a/src/TensorFlowNET.Core/APIs/tf.debugging.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.debugging.cs
@@ -14,41 +14,18 @@
limitations under the License.
******************************************************************************/
+using Tensorflow.Debugging;
+
namespace Tensorflow
{
public partial class tensorflow
{
///
- /// Assert the condition `x == y` holds element-wise.
+ /// Public API for tf.debugging namespace
+ /// https://www.tensorflow.org/api_docs/python/tf/debugging
+ /// More debugging instructions
+ /// https://developer.ibm.com/technologies/artificial-intelligence/tutorials/debug-tensorflow/
///
- ///
- ///
- ///
- ///
- ///
- ///
- ///
- ///
- public Tensor assert_equal(T1 t1,
- T2 t2,
- object[] data = null,
- string message = null,
- string name = null)
- => check_ops.assert_equal(t1,
- t2,
- data: data,
- message: message,
- name: name);
-
- public Tensor assert_greater_equal(Tensor x,
- Tensor y,
- object[] data = null,
- string message = null,
- string name = null)
- => check_ops.assert_greater_equal(x,
- y,
- data: data,
- message: message,
- name: name);
+ public DebugImpl debugging => new DebugImpl();
}
}
diff --git a/src/TensorFlowNET.Core/Contexts/Context.AutoMode.cs b/src/TensorFlowNET.Core/Contexts/Context.AutoMode.cs
new file mode 100644
index 00000000..3626e9df
--- /dev/null
+++ b/src/TensorFlowNET.Core/Contexts/Context.AutoMode.cs
@@ -0,0 +1,85 @@
+/*****************************************************************************
+ Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+******************************************************************************/
+
+using System;
+using System.Diagnostics;
+using System.Linq;
+using Tensorflow.Eager;
+using static Tensorflow.Binding;
+using Google.Protobuf;
+
+namespace Tensorflow.Contexts
+{
+ ///
+ /// Environment in which eager operations execute.
+ ///
+ public sealed partial class Context
+ {
+ // [DebuggerStepThrough]
+ public T RunInAutoMode(Func graphAction, Func eagerAction, params Tensor[] tensors)
+ {
+ var shouldRunInEager = executing_eagerly()
+ && tensors.Count(x => x.IsEagerTensor) == tensors.Length;
+
+ if (shouldRunInEager)
+ return eagerAction();
+ else
+ {
+ if (executing_eagerly())
+ {
+ graph_mode();
+ var result = graphAction();
+ restore_mode();
+ return result;
+ }
+ else
+ {
+ return graphAction();
+ }
+ }
+ }
+
+ // [DebuggerStepThrough]
+ public Tensors RunInAutoMode2(Func graphAction,
+ Func eagerAction,
+ Action recordGradient,
+ Tensors tensors)
+ {
+ var shouldRunInEager = executing_eagerly()
+ && tensors.Count(x => x.IsEagerTensor) == tensors.Length;
+
+ if (shouldRunInEager)
+ return eagerAction();
+ else
+ {
+ if (executing_eagerly())
+ {
+ graph_mode();
+ var result = graphAction();
+ restore_mode();
+ return result;
+ }
+ else
+ {
+ var result = graphAction();
+ if (tf.Runner.MustRecordGradient())
+ recordGradient(result[0].op);
+ return result;
+ }
+ }
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Contexts/Context.Config.cs b/src/TensorFlowNET.Core/Contexts/Context.Config.cs
new file mode 100644
index 00000000..8f6be1cf
--- /dev/null
+++ b/src/TensorFlowNET.Core/Contexts/Context.Config.cs
@@ -0,0 +1,48 @@
+/*****************************************************************************
+ Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+******************************************************************************/
+
+using System;
+using System.Diagnostics;
+
+namespace Tensorflow.Contexts
+{
+ ///
+ /// Environment in which eager operations execute.
+ ///
+ public sealed partial class Context
+ {
+ ConfigProto _config;
+
+ ConfigProto config()
+ {
+ var config = new ConfigProto()
+ {
+ LogDevicePlacement = _log_device_placement,
+ GpuOptions = _compute_gpu_options()
+ };
+
+ return config;
+ }
+
+ GPUOptions _compute_gpu_options()
+ {
+ return new GPUOptions()
+ {
+
+ };
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Contexts/Context.Device.cs b/src/TensorFlowNET.Core/Contexts/Context.Device.cs
new file mode 100644
index 00000000..9485d3b4
--- /dev/null
+++ b/src/TensorFlowNET.Core/Contexts/Context.Device.cs
@@ -0,0 +1,42 @@
+/*****************************************************************************
+ Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+******************************************************************************/
+
+using System;
+using System.Diagnostics;
+using System.Linq;
+using Tensorflow.Eager;
+using static Tensorflow.Binding;
+using Google.Protobuf;
+
+namespace Tensorflow.Contexts
+{
+ ///
+ /// Environment in which eager operations execute.
+ ///
+ public sealed partial class Context
+ {
+ ContextDevicePlacementPolicy _device_policy;
+ bool _log_device_placement;
+
+ public void log_device_placement(bool enable)
+ {
+ if (_handle != null)
+ c_api.TFE_ContextSetLogDevicePlacement(_handle, enable, tf.Status.Handle);
+ _log_device_placement = enable;
+ // _thread_local_data.function_call_options = null;
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Contexts/Context.cs b/src/TensorFlowNET.Core/Contexts/Context.cs
index 226625b7..4e386f09 100644
--- a/src/TensorFlowNET.Core/Contexts/Context.cs
+++ b/src/TensorFlowNET.Core/Contexts/Context.cs
@@ -19,13 +19,14 @@ using System.Diagnostics;
using System.Linq;
using Tensorflow.Eager;
using static Tensorflow.Binding;
+using Google.Protobuf;
namespace Tensorflow.Contexts
{
///
/// Environment in which eager operations execute.
///
- public sealed class Context : IDisposable
+ public sealed partial class Context : IDisposable
{
public const int GRAPH_MODE = 0;
public const int EAGER_MODE = 1;
@@ -37,14 +38,14 @@ namespace Tensorflow.Contexts
ContextSwitchStack context_switches;
public FunctionCallOptions FunctionCallOptions { get; }
- public SafeContextHandle Handle { get; }
+ SafeContextHandle _handle;
+ public SafeContextHandle Handle => _handle;
- public Context(ContextOptions opts, Status status)
+ public Context()
{
- Handle = c_api.TFE_NewContext(opts.Handle, status.Handle);
- status.Check(true);
+ _device_policy = ContextDevicePlacementPolicy.DEVICE_PLACEMENT_SILENT;
context_switches = new ContextSwitchStack(defaultExecutionMode == EAGER_MODE, false);
- initialized = true;
+ initialized = false;
FunctionCallOptions = new FunctionCallOptions();
}
@@ -55,14 +56,25 @@ namespace Tensorflow.Contexts
{
if (initialized)
return;
+
+ _config = config();
+ var config_str = _config.ToByteArray();
+
+ using var opts = new ContextOptions();
+ using var status = new Status();
+ c_api.TFE_ContextOptionsSetConfig(opts.Handle, config_str, (ulong)config_str.Length, status.Handle);
+ status.Check(true);
+ c_api.TFE_ContextOptionsSetDevicePlacementPolicy(opts.Handle, _device_policy);
+ _handle = c_api.TFE_NewContext(opts.Handle, status.Handle);
+ status.Check(true);
initialized = true;
}
public void start_step()
- => c_api.TFE_ContextStartStep(Handle);
+ => c_api.TFE_ContextStartStep(_handle);
public void end_step()
- => c_api.TFE_ContextEndStep(Handle);
+ => c_api.TFE_ContextEndStep(_handle);
///
/// Checks whether the current thread has eager execution enabled.
@@ -91,61 +103,7 @@ namespace Tensorflow.Contexts
context_switches.Pop();
}
- // [DebuggerStepThrough]
- public T RunInAutoMode(Func graphAction, Func eagerAction, params Tensor[] tensors)
- {
- var shouldRunInEager = executing_eagerly()
- && tensors.Count(x => x.IsEagerTensor) == tensors.Length;
-
- if (shouldRunInEager)
- return eagerAction();
- else
- {
- if (executing_eagerly())
- {
- graph_mode();
- var result = graphAction();
- restore_mode();
- return result;
- }
- else
- {
- return graphAction();
- }
- }
- }
-
- // [DebuggerStepThrough]
- public Tensors RunInAutoMode2(Func graphAction,
- Func eagerAction,
- Action recordGradient,
- Tensors tensors)
- {
- var shouldRunInEager = executing_eagerly()
- && tensors.Count(x => x.IsEagerTensor) == tensors.Length;
-
- if (shouldRunInEager)
- return eagerAction();
- else
- {
- if (executing_eagerly())
- {
- graph_mode();
- var result = graphAction();
- restore_mode();
- return result;
- }
- else
- {
- var result = graphAction();
- if (tf.Runner.MustRecordGradient())
- recordGradient(result[0].op);
- return result;
- }
- }
- }
-
public void Dispose()
- => Handle.Dispose();
+ => _handle.Dispose();
}
}
diff --git a/src/TensorFlowNET.Core/Contexts/ContextDevicePlacementPolicy.cs b/src/TensorFlowNET.Core/Contexts/ContextDevicePlacementPolicy.cs
new file mode 100644
index 00000000..96836a2f
--- /dev/null
+++ b/src/TensorFlowNET.Core/Contexts/ContextDevicePlacementPolicy.cs
@@ -0,0 +1,20 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Contexts
+{
+ public enum ContextDevicePlacementPolicy
+ {
+ // Running operations with input tensors on the wrong device will fail.
+ DEVICE_PLACEMENT_EXPLICIT = 0,
+ // Copy the tensor to the right device but log a warning.
+ DEVICE_PLACEMENT_WARN = 1,
+ // Silently copy the tensor, which has a performance cost since the operation
+ // will be blocked till the copy completes. This is the default placement
+ // policy.
+ DEVICE_PLACEMENT_SILENT = 2,
+ // Placement policy which silently copies int32 tensors but not other dtypes.
+ DEVICE_PLACEMENT_SILENT_FOR_INT32 = 3,
+ }
+}
diff --git a/src/TensorFlowNET.Core/Contrib/Learn/Estimators/tensor_signature.cs b/src/TensorFlowNET.Core/Contrib/Learn/Estimators/tensor_signature.cs
deleted file mode 100644
index b2e7dd75..00000000
--- a/src/TensorFlowNET.Core/Contrib/Learn/Estimators/tensor_signature.cs
+++ /dev/null
@@ -1,39 +0,0 @@
-using NumSharp;
-using System.Linq;
-using Tensorflow.Framework;
-
-namespace Tensorflow.Contrib.Learn.Estimators
-{
- public static class tensor_signature
- {
- public static bool is_compatible_with(this Tensor self, Tensor other)
- {
- bool _shape_is_compatible_0dim(Shape _this, Shape _other)
- {
- var __other = tensor_shape.as_shape(_other);
- if (_this.Dimensions == null || __other.dims == null)
- return true;
-
- if (_this.NDim != __other.ndim)
- return false;
-
- foreach (var (x_dim, y_dim) in _this.Dimensions.Zip(__other.dims, (x_dim, y_dim) => (x_dim, y_dim)))
- {
- if (x_dim != y_dim)
- return false;
- }
-
- return true;
- }
-
- if (other.is_sparse())
- {
- return self.dtype.is_compatible_with(other.dtype);
- }
-
- return self.dtype.is_compatible_with(other.dtype) &&
- _shape_is_compatible_0dim(self.shape, other.shape) &&
- !self.is_sparse();
- }
- }
-}
diff --git a/src/TensorFlowNET.Core/Contrib/Train/HParams.cs b/src/TensorFlowNET.Core/Contrib/Train/HParams.cs
deleted file mode 100644
index 77eb5923..00000000
--- a/src/TensorFlowNET.Core/Contrib/Train/HParams.cs
+++ /dev/null
@@ -1,15 +0,0 @@
-namespace Tensorflow.Contrib.Train
-{
- ///
- /// Class to hold a set of hyperparameters as name-value pairs.
- ///
- public class HParams
- {
- public bool load_pretrained { get; set; }
-
- public HParams(bool load_pretrained)
- {
- this.load_pretrained = load_pretrained;
- }
- }
-}
diff --git a/src/TensorFlowNET.Core/Debugging/DebugImpl.cs b/src/TensorFlowNET.Core/Debugging/DebugImpl.cs
new file mode 100644
index 00000000..81627351
--- /dev/null
+++ b/src/TensorFlowNET.Core/Debugging/DebugImpl.cs
@@ -0,0 +1,50 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using static Tensorflow.Binding;
+
+namespace Tensorflow.Debugging
+{
+ public class DebugImpl
+ {
+ ///
+ /// Set if device placements should be logged.
+ ///
+ /// Whether to enabled device placement logging.
+ public void set_log_device_placement(bool enabled)
+ => tf.Context.log_device_placement(enabled);
+
+ ///
+ /// Assert the condition `x == y` holds element-wise.
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ public Tensor assert_equal(T1 t1,
+ T2 t2,
+ object[] data = null,
+ string message = null,
+ string name = null)
+ => check_ops.assert_equal(t1,
+ t2,
+ data: data,
+ message: message,
+ name: name);
+
+ public Tensor assert_greater_equal(Tensor x,
+ Tensor y,
+ object[] data = null,
+ string message = null,
+ string name = null)
+ => check_ops.assert_greater_equal(x,
+ y,
+ data: data,
+ message: message,
+ name: name);
+ }
+}
diff --git a/src/TensorFlowNET.Core/Eager/c_api.eager.cs b/src/TensorFlowNET.Core/Eager/c_api.eager.cs
index f707804e..9d4706ca 100644
--- a/src/TensorFlowNET.Core/Eager/c_api.eager.cs
+++ b/src/TensorFlowNET.Core/Eager/c_api.eager.cs
@@ -1,6 +1,7 @@
using Google.Protobuf;
using System;
using System.Runtime.InteropServices;
+using Tensorflow.Contexts;
using Tensorflow.Device;
using Tensorflow.Eager;
using Tensorflow.Util;
@@ -16,6 +17,22 @@ namespace Tensorflow
[DllImport(TensorFlowLibName)]
public static extern SafeContextOptionsHandle TFE_NewContextOptions();
+ ///
+ /// Set the config in TF_ContextOptions.options.
+ /// config should be a serialized tensorflow.ConfigProto proto.
+ /// If config was not parsed successfully as a ConfigProto, record the
+ /// error information in *status.
+ ///
+ /// TFE_ContextOptions*
+ ///
+ /// size_t
+ /// SafeStatusHandle
+ [DllImport(TensorFlowLibName)]
+ public static extern void TFE_ContextOptionsSetConfig(SafeContextOptionsHandle opts, byte[] proto, ulong proto_len, SafeStatusHandle status);
+
+ [DllImport(TensorFlowLibName)]
+ public static extern void TFE_ContextOptionsSetDevicePlacementPolicy(SafeContextOptionsHandle opts, ContextDevicePlacementPolicy device_policy);
+
///
/// Destroy an options object.
///
@@ -23,6 +40,16 @@ namespace Tensorflow
[DllImport(TensorFlowLibName)]
public static extern void TFE_DeleteContextOptions(IntPtr options);
+ ///
+ /// Configure device placement policy logging for the eager executor. Note this
+ /// policy is applied to any subsequent op executions.
+ ///
+ ///
+ ///
+ ///
+ [DllImport(TensorFlowLibName)]
+ public static extern void TFE_ContextSetLogDevicePlacement(SafeContextHandle ctx, bool enable, SafeStatusHandle status);
+
///
///
///
diff --git a/src/TensorFlowNET.Core/Framework/ConfigImpl.cs b/src/TensorFlowNET.Core/Framework/ConfigImpl.cs
new file mode 100644
index 00000000..7c774755
--- /dev/null
+++ b/src/TensorFlowNET.Core/Framework/ConfigImpl.cs
@@ -0,0 +1,11 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Framework
+{
+ public class ConfigImpl
+ {
+
+ }
+}
diff --git a/src/TensorFlowNET.Core/Framework/tensor_shape.cs b/src/TensorFlowNET.Core/Framework/tensor_shape.cs
index 96265758..35557701 100644
--- a/src/TensorFlowNET.Core/Framework/tensor_shape.cs
+++ b/src/TensorFlowNET.Core/Framework/tensor_shape.cs
@@ -2,7 +2,6 @@
using System;
using System.Linq;
using System.Text;
-using Tensorflow.Contrib.Learn.Estimators;
namespace Tensorflow.Framework
{
@@ -24,6 +23,36 @@ namespace Tensorflow.Framework
}
}
+ public static bool is_compatible_with(this Tensor self, Tensor other)
+ {
+ bool _shape_is_compatible_0dim(Shape _this, Shape _other)
+ {
+ var __other = tensor_shape.as_shape(_other);
+ if (_this.Dimensions == null || __other.dims == null)
+ return true;
+
+ if (_this.NDim != __other.ndim)
+ return false;
+
+ foreach (var (x_dim, y_dim) in _this.Dimensions.Zip(__other.dims, (x_dim, y_dim) => (x_dim, y_dim)))
+ {
+ if (x_dim != y_dim)
+ return false;
+ }
+
+ return true;
+ }
+
+ if (other.is_sparse())
+ {
+ return self.dtype.is_compatible_with(other.dtype);
+ }
+
+ return self.dtype.is_compatible_with(other.dtype) &&
+ _shape_is_compatible_0dim(self.shape, other.shape) &&
+ !self.is_sparse();
+ }
+
public static Dimension dimension_at_index(TensorShape shape, int index)
{
return shape.rank < 0 ?
diff --git a/src/TensorFlowNET.Core/Tensors/constant_op.cs b/src/TensorFlowNET.Core/Tensors/constant_op.cs
index 212b7ebd..45a809ca 100644
--- a/src/TensorFlowNET.Core/Tensors/constant_op.cs
+++ b/src/TensorFlowNET.Core/Tensors/constant_op.cs
@@ -122,6 +122,7 @@ namespace Tensorflow
private static EagerTensor convert_to_eager_tensor(object value, Context ctx, TF_DataType dtype = TF_DataType.DtInvalid)
{
+ ctx.ensure_initialized();
// convert data type
if (dtype != TF_DataType.DtInvalid &&
value.GetType().Name != "NDArray" &&
diff --git a/src/TensorFlowNET.Core/tensorflow.cs b/src/TensorFlowNET.Core/tensorflow.cs
index 3d707215..60b22f71 100644
--- a/src/TensorFlowNET.Core/tensorflow.cs
+++ b/src/TensorFlowNET.Core/tensorflow.cs
@@ -53,7 +53,7 @@ namespace Tensorflow
.CreateLogger();
Status = new Status();
- Context = new Context(new ContextOptions(), Status);
+ Context = new Context();
OpDefLib = new OpDefLibrary();
ConstructThreadingObjects();
InitGradientEnvironment();
diff --git a/src/TensorFlowNET.Core/tensorflow.threading.cs b/src/TensorFlowNET.Core/tensorflow.threading.cs
index 7ab9e93d..c1be5d90 100644
--- a/src/TensorFlowNET.Core/tensorflow.threading.cs
+++ b/src/TensorFlowNET.Core/tensorflow.threading.cs
@@ -19,7 +19,7 @@ using System.Threading;
namespace Tensorflow
{
- public partial class tensorflow : ITensorFlowObject
+ public partial class tensorflow
{
protected ThreadLocal defaultSessionFactory;
diff --git a/src/TensorFlowNet.Benchmarks/Leak/GpuLeakByCNN.cs b/src/TensorFlowNet.Benchmarks/Leak/GpuLeakByCNN.cs
index b581e590..0189eaf9 100644
--- a/src/TensorFlowNet.Benchmarks/Leak/GpuLeakByCNN.cs
+++ b/src/TensorFlowNet.Benchmarks/Leak/GpuLeakByCNN.cs
@@ -16,6 +16,12 @@ namespace Tensorflow.Benchmark.Leak
[Benchmark]
public void Run()
{
+ tf.debugging.set_log_device_placement(true);
+
+ var a = tf.constant(3.0);
+ var b = tf.constant(2.0);
+ var c = tf.multiply(a, b);
+
int num = 50, width = 64, height = 64;
// if width = 128, height = 128, the exception occurs faster
@@ -47,7 +53,7 @@ namespace Tensorflow.Benchmark.Leak
optimizer: keras.optimizers.RMSprop(),
metrics: new[] { "accuracy" });
- model.fit(inputImages, outLables, batch_size: 1, epochs: 200);
+ model.fit(inputImages, outLables, batch_size: 32, epochs: 200);
}
}
}