| @@ -14,18 +14,19 @@ | |||||
| limitations under the License. | limitations under the License. | ||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| namespace Tensorflow.Contrib.Learn.Preprocessing | |||||
| using Tensorflow.Contexts; | |||||
| using Tensorflow.Framework; | |||||
| namespace Tensorflow | |||||
| { | { | ||||
| public class VocabularyProcessor | |||||
| public partial class tensorflow | |||||
| { | { | ||||
| private int _max_document_length; | |||||
| private int _min_frequency; | |||||
| public VocabularyProcessor(int max_document_length, | |||||
| int min_frequency) | |||||
| { | |||||
| _max_document_length = max_document_length; | |||||
| _min_frequency = min_frequency; | |||||
| } | |||||
| /// <summary> | |||||
| /// Public API for tf.debugging namespace | |||||
| /// https://www.tensorflow.org/api_docs/python/tf/debugging | |||||
| /// More debugging instructions | |||||
| /// https://developer.ibm.com/technologies/artificial-intelligence/tutorials/debug-tensorflow/ | |||||
| /// </summary> | |||||
| public ConfigImpl config => new ConfigImpl(); | |||||
| } | } | ||||
| } | } | ||||
| @@ -14,41 +14,18 @@ | |||||
| limitations under the License. | limitations under the License. | ||||
| ******************************************************************************/ | ******************************************************************************/ | ||||
| using Tensorflow.Debugging; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| public partial class tensorflow | public partial class tensorflow | ||||
| { | { | ||||
| /// <summary> | /// <summary> | ||||
| /// Assert the condition `x == y` holds element-wise. | |||||
| /// Public API for tf.debugging namespace | |||||
| /// https://www.tensorflow.org/api_docs/python/tf/debugging | |||||
| /// More debugging instructions | |||||
| /// https://developer.ibm.com/technologies/artificial-intelligence/tutorials/debug-tensorflow/ | |||||
| /// </summary> | /// </summary> | ||||
| /// <typeparam name="T1"></typeparam> | |||||
| /// <typeparam name="T2"></typeparam> | |||||
| /// <param name="t1"></param> | |||||
| /// <param name="t2"></param> | |||||
| /// <param name="data"></param> | |||||
| /// <param name="message"></param> | |||||
| /// <param name="name"></param> | |||||
| /// <returns></returns> | |||||
| public Tensor assert_equal<T1, T2>(T1 t1, | |||||
| T2 t2, | |||||
| object[] data = null, | |||||
| string message = null, | |||||
| string name = null) | |||||
| => check_ops.assert_equal(t1, | |||||
| t2, | |||||
| data: data, | |||||
| message: message, | |||||
| name: name); | |||||
| public Tensor assert_greater_equal<T1, T2>(Tensor x, | |||||
| Tensor y, | |||||
| object[] data = null, | |||||
| string message = null, | |||||
| string name = null) | |||||
| => check_ops.assert_greater_equal(x, | |||||
| y, | |||||
| data: data, | |||||
| message: message, | |||||
| name: name); | |||||
| public DebugImpl debugging => new DebugImpl(); | |||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,85 @@ | |||||
| /***************************************************************************** | |||||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| you may not use this file except in compliance with the License. | |||||
| You may obtain a copy of the License at | |||||
| http://www.apache.org/licenses/LICENSE-2.0 | |||||
| Unless required by applicable law or agreed to in writing, software | |||||
| distributed under the License is distributed on an "AS IS" BASIS, | |||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| See the License for the specific language governing permissions and | |||||
| limitations under the License. | |||||
| ******************************************************************************/ | |||||
| using System; | |||||
| using System.Diagnostics; | |||||
| using System.Linq; | |||||
| using Tensorflow.Eager; | |||||
| using static Tensorflow.Binding; | |||||
| using Google.Protobuf; | |||||
| namespace Tensorflow.Contexts | |||||
| { | |||||
| /// <summary> | |||||
| /// Environment in which eager operations execute. | |||||
| /// </summary> | |||||
| public sealed partial class Context | |||||
| { | |||||
| // [DebuggerStepThrough] | |||||
| public T RunInAutoMode<T>(Func<T> graphAction, Func<T> eagerAction, params Tensor[] tensors) | |||||
| { | |||||
| var shouldRunInEager = executing_eagerly() | |||||
| && tensors.Count(x => x.IsEagerTensor) == tensors.Length; | |||||
| if (shouldRunInEager) | |||||
| return eagerAction(); | |||||
| else | |||||
| { | |||||
| if (executing_eagerly()) | |||||
| { | |||||
| graph_mode(); | |||||
| var result = graphAction(); | |||||
| restore_mode(); | |||||
| return result; | |||||
| } | |||||
| else | |||||
| { | |||||
| return graphAction(); | |||||
| } | |||||
| } | |||||
| } | |||||
| // [DebuggerStepThrough] | |||||
| public Tensors RunInAutoMode2(Func<Tensors> graphAction, | |||||
| Func<Tensors> eagerAction, | |||||
| Action<Operation> recordGradient, | |||||
| Tensors tensors) | |||||
| { | |||||
| var shouldRunInEager = executing_eagerly() | |||||
| && tensors.Count(x => x.IsEagerTensor) == tensors.Length; | |||||
| if (shouldRunInEager) | |||||
| return eagerAction(); | |||||
| else | |||||
| { | |||||
| if (executing_eagerly()) | |||||
| { | |||||
| graph_mode(); | |||||
| var result = graphAction(); | |||||
| restore_mode(); | |||||
| return result; | |||||
| } | |||||
| else | |||||
| { | |||||
| var result = graphAction(); | |||||
| if (tf.Runner.MustRecordGradient()) | |||||
| recordGradient(result[0].op); | |||||
| return result; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,48 @@ | |||||
| /***************************************************************************** | |||||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| you may not use this file except in compliance with the License. | |||||
| You may obtain a copy of the License at | |||||
| http://www.apache.org/licenses/LICENSE-2.0 | |||||
| Unless required by applicable law or agreed to in writing, software | |||||
| distributed under the License is distributed on an "AS IS" BASIS, | |||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| See the License for the specific language governing permissions and | |||||
| limitations under the License. | |||||
| ******************************************************************************/ | |||||
| using System; | |||||
| using System.Diagnostics; | |||||
| namespace Tensorflow.Contexts | |||||
| { | |||||
| /// <summary> | |||||
| /// Environment in which eager operations execute. | |||||
| /// </summary> | |||||
| public sealed partial class Context | |||||
| { | |||||
| ConfigProto _config; | |||||
| ConfigProto config() | |||||
| { | |||||
| var config = new ConfigProto() | |||||
| { | |||||
| LogDevicePlacement = _log_device_placement, | |||||
| GpuOptions = _compute_gpu_options() | |||||
| }; | |||||
| return config; | |||||
| } | |||||
| GPUOptions _compute_gpu_options() | |||||
| { | |||||
| return new GPUOptions() | |||||
| { | |||||
| }; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,42 @@ | |||||
| /***************************************************************************** | |||||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| you may not use this file except in compliance with the License. | |||||
| You may obtain a copy of the License at | |||||
| http://www.apache.org/licenses/LICENSE-2.0 | |||||
| Unless required by applicable law or agreed to in writing, software | |||||
| distributed under the License is distributed on an "AS IS" BASIS, | |||||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| See the License for the specific language governing permissions and | |||||
| limitations under the License. | |||||
| ******************************************************************************/ | |||||
| using System; | |||||
| using System.Diagnostics; | |||||
| using System.Linq; | |||||
| using Tensorflow.Eager; | |||||
| using static Tensorflow.Binding; | |||||
| using Google.Protobuf; | |||||
| namespace Tensorflow.Contexts | |||||
| { | |||||
| /// <summary> | |||||
| /// Environment in which eager operations execute. | |||||
| /// </summary> | |||||
| public sealed partial class Context | |||||
| { | |||||
| ContextDevicePlacementPolicy _device_policy; | |||||
| bool _log_device_placement; | |||||
| public void log_device_placement(bool enable) | |||||
| { | |||||
| if (_handle != null) | |||||
| c_api.TFE_ContextSetLogDevicePlacement(_handle, enable, tf.Status.Handle); | |||||
| _log_device_placement = enable; | |||||
| // _thread_local_data.function_call_options = null; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -19,13 +19,14 @@ using System.Diagnostics; | |||||
| using System.Linq; | using System.Linq; | ||||
| using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| using Google.Protobuf; | |||||
| namespace Tensorflow.Contexts | namespace Tensorflow.Contexts | ||||
| { | { | ||||
| /// <summary> | /// <summary> | ||||
| /// Environment in which eager operations execute. | /// Environment in which eager operations execute. | ||||
| /// </summary> | /// </summary> | ||||
| public sealed class Context : IDisposable | |||||
| public sealed partial class Context : IDisposable | |||||
| { | { | ||||
| public const int GRAPH_MODE = 0; | public const int GRAPH_MODE = 0; | ||||
| public const int EAGER_MODE = 1; | public const int EAGER_MODE = 1; | ||||
| @@ -37,14 +38,14 @@ namespace Tensorflow.Contexts | |||||
| ContextSwitchStack context_switches; | ContextSwitchStack context_switches; | ||||
| public FunctionCallOptions FunctionCallOptions { get; } | public FunctionCallOptions FunctionCallOptions { get; } | ||||
| public SafeContextHandle Handle { get; } | |||||
| SafeContextHandle _handle; | |||||
| public SafeContextHandle Handle => _handle; | |||||
| public Context(ContextOptions opts, Status status) | |||||
| public Context() | |||||
| { | { | ||||
| Handle = c_api.TFE_NewContext(opts.Handle, status.Handle); | |||||
| status.Check(true); | |||||
| _device_policy = ContextDevicePlacementPolicy.DEVICE_PLACEMENT_SILENT; | |||||
| context_switches = new ContextSwitchStack(defaultExecutionMode == EAGER_MODE, false); | context_switches = new ContextSwitchStack(defaultExecutionMode == EAGER_MODE, false); | ||||
| initialized = true; | |||||
| initialized = false; | |||||
| FunctionCallOptions = new FunctionCallOptions(); | FunctionCallOptions = new FunctionCallOptions(); | ||||
| } | } | ||||
| @@ -55,14 +56,25 @@ namespace Tensorflow.Contexts | |||||
| { | { | ||||
| if (initialized) | if (initialized) | ||||
| return; | return; | ||||
| _config = config(); | |||||
| var config_str = _config.ToByteArray(); | |||||
| using var opts = new ContextOptions(); | |||||
| using var status = new Status(); | |||||
| c_api.TFE_ContextOptionsSetConfig(opts.Handle, config_str, (ulong)config_str.Length, status.Handle); | |||||
| status.Check(true); | |||||
| c_api.TFE_ContextOptionsSetDevicePlacementPolicy(opts.Handle, _device_policy); | |||||
| _handle = c_api.TFE_NewContext(opts.Handle, status.Handle); | |||||
| status.Check(true); | |||||
| initialized = true; | initialized = true; | ||||
| } | } | ||||
| public void start_step() | public void start_step() | ||||
| => c_api.TFE_ContextStartStep(Handle); | |||||
| => c_api.TFE_ContextStartStep(_handle); | |||||
| public void end_step() | public void end_step() | ||||
| => c_api.TFE_ContextEndStep(Handle); | |||||
| => c_api.TFE_ContextEndStep(_handle); | |||||
| /// <summary> | /// <summary> | ||||
| /// Checks whether the current thread has eager execution enabled. | /// Checks whether the current thread has eager execution enabled. | ||||
| @@ -91,61 +103,7 @@ namespace Tensorflow.Contexts | |||||
| context_switches.Pop(); | context_switches.Pop(); | ||||
| } | } | ||||
| // [DebuggerStepThrough] | |||||
| public T RunInAutoMode<T>(Func<T> graphAction, Func<T> eagerAction, params Tensor[] tensors) | |||||
| { | |||||
| var shouldRunInEager = executing_eagerly() | |||||
| && tensors.Count(x => x.IsEagerTensor) == tensors.Length; | |||||
| if (shouldRunInEager) | |||||
| return eagerAction(); | |||||
| else | |||||
| { | |||||
| if (executing_eagerly()) | |||||
| { | |||||
| graph_mode(); | |||||
| var result = graphAction(); | |||||
| restore_mode(); | |||||
| return result; | |||||
| } | |||||
| else | |||||
| { | |||||
| return graphAction(); | |||||
| } | |||||
| } | |||||
| } | |||||
| // [DebuggerStepThrough] | |||||
| public Tensors RunInAutoMode2(Func<Tensors> graphAction, | |||||
| Func<Tensors> eagerAction, | |||||
| Action<Operation> recordGradient, | |||||
| Tensors tensors) | |||||
| { | |||||
| var shouldRunInEager = executing_eagerly() | |||||
| && tensors.Count(x => x.IsEagerTensor) == tensors.Length; | |||||
| if (shouldRunInEager) | |||||
| return eagerAction(); | |||||
| else | |||||
| { | |||||
| if (executing_eagerly()) | |||||
| { | |||||
| graph_mode(); | |||||
| var result = graphAction(); | |||||
| restore_mode(); | |||||
| return result; | |||||
| } | |||||
| else | |||||
| { | |||||
| var result = graphAction(); | |||||
| if (tf.Runner.MustRecordGradient()) | |||||
| recordGradient(result[0].op); | |||||
| return result; | |||||
| } | |||||
| } | |||||
| } | |||||
| public void Dispose() | public void Dispose() | ||||
| => Handle.Dispose(); | |||||
| => _handle.Dispose(); | |||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,20 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Contexts | |||||
| { | |||||
| public enum ContextDevicePlacementPolicy | |||||
| { | |||||
| // Running operations with input tensors on the wrong device will fail. | |||||
| DEVICE_PLACEMENT_EXPLICIT = 0, | |||||
| // Copy the tensor to the right device but log a warning. | |||||
| DEVICE_PLACEMENT_WARN = 1, | |||||
| // Silently copy the tensor, which has a performance cost since the operation | |||||
| // will be blocked till the copy completes. This is the default placement | |||||
| // policy. | |||||
| DEVICE_PLACEMENT_SILENT = 2, | |||||
| // Placement policy which silently copies int32 tensors but not other dtypes. | |||||
| DEVICE_PLACEMENT_SILENT_FOR_INT32 = 3, | |||||
| } | |||||
| } | |||||
| @@ -1,39 +0,0 @@ | |||||
| using NumSharp; | |||||
| using System.Linq; | |||||
| using Tensorflow.Framework; | |||||
| namespace Tensorflow.Contrib.Learn.Estimators | |||||
| { | |||||
| public static class tensor_signature | |||||
| { | |||||
| public static bool is_compatible_with(this Tensor self, Tensor other) | |||||
| { | |||||
| bool _shape_is_compatible_0dim(Shape _this, Shape _other) | |||||
| { | |||||
| var __other = tensor_shape.as_shape(_other); | |||||
| if (_this.Dimensions == null || __other.dims == null) | |||||
| return true; | |||||
| if (_this.NDim != __other.ndim) | |||||
| return false; | |||||
| foreach (var (x_dim, y_dim) in _this.Dimensions.Zip(__other.dims, (x_dim, y_dim) => (x_dim, y_dim))) | |||||
| { | |||||
| if (x_dim != y_dim) | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| if (other.is_sparse()) | |||||
| { | |||||
| return self.dtype.is_compatible_with(other.dtype); | |||||
| } | |||||
| return self.dtype.is_compatible_with(other.dtype) && | |||||
| _shape_is_compatible_0dim(self.shape, other.shape) && | |||||
| !self.is_sparse(); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -1,15 +0,0 @@ | |||||
| namespace Tensorflow.Contrib.Train | |||||
| { | |||||
| /// <summary> | |||||
| /// Class to hold a set of hyperparameters as name-value pairs. | |||||
| /// </summary> | |||||
| public class HParams | |||||
| { | |||||
| public bool load_pretrained { get; set; } | |||||
| public HParams(bool load_pretrained) | |||||
| { | |||||
| this.load_pretrained = load_pretrained; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -0,0 +1,50 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using static Tensorflow.Binding; | |||||
| namespace Tensorflow.Debugging | |||||
| { | |||||
| public class DebugImpl | |||||
| { | |||||
| /// <summary> | |||||
| /// Set if device placements should be logged. | |||||
| /// </summary> | |||||
| /// <param name="enabled"> Whether to enabled device placement logging.</param> | |||||
| public void set_log_device_placement(bool enabled) | |||||
| => tf.Context.log_device_placement(enabled); | |||||
| /// <summary> | |||||
| /// Assert the condition `x == y` holds element-wise. | |||||
| /// </summary> | |||||
| /// <typeparam name="T1"></typeparam> | |||||
| /// <typeparam name="T2"></typeparam> | |||||
| /// <param name="t1"></param> | |||||
| /// <param name="t2"></param> | |||||
| /// <param name="data"></param> | |||||
| /// <param name="message"></param> | |||||
| /// <param name="name"></param> | |||||
| /// <returns></returns> | |||||
| public Tensor assert_equal<T1, T2>(T1 t1, | |||||
| T2 t2, | |||||
| object[] data = null, | |||||
| string message = null, | |||||
| string name = null) | |||||
| => check_ops.assert_equal(t1, | |||||
| t2, | |||||
| data: data, | |||||
| message: message, | |||||
| name: name); | |||||
| public Tensor assert_greater_equal<T1, T2>(Tensor x, | |||||
| Tensor y, | |||||
| object[] data = null, | |||||
| string message = null, | |||||
| string name = null) | |||||
| => check_ops.assert_greater_equal(x, | |||||
| y, | |||||
| data: data, | |||||
| message: message, | |||||
| name: name); | |||||
| } | |||||
| } | |||||
| @@ -1,6 +1,7 @@ | |||||
| using Google.Protobuf; | using Google.Protobuf; | ||||
| using System; | using System; | ||||
| using System.Runtime.InteropServices; | using System.Runtime.InteropServices; | ||||
| using Tensorflow.Contexts; | |||||
| using Tensorflow.Device; | using Tensorflow.Device; | ||||
| using Tensorflow.Eager; | using Tensorflow.Eager; | ||||
| using Tensorflow.Util; | using Tensorflow.Util; | ||||
| @@ -16,6 +17,22 @@ namespace Tensorflow | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern SafeContextOptionsHandle TFE_NewContextOptions(); | public static extern SafeContextOptionsHandle TFE_NewContextOptions(); | ||||
| /// <summary> | |||||
| /// Set the config in TF_ContextOptions.options. | |||||
| /// config should be a serialized tensorflow.ConfigProto proto. | |||||
| /// If config was not parsed successfully as a ConfigProto, record the | |||||
| /// error information in *status. | |||||
| /// </summary> | |||||
| /// <param name="options">TFE_ContextOptions*</param> | |||||
| /// <param name="proto"></param> | |||||
| /// <param name="proto_len">size_t</param> | |||||
| /// <param name="status">SafeStatusHandle</param> | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern void TFE_ContextOptionsSetConfig(SafeContextOptionsHandle opts, byte[] proto, ulong proto_len, SafeStatusHandle status); | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern void TFE_ContextOptionsSetDevicePlacementPolicy(SafeContextOptionsHandle opts, ContextDevicePlacementPolicy device_policy); | |||||
| /// <summary> | /// <summary> | ||||
| /// Destroy an options object. | /// Destroy an options object. | ||||
| /// </summary> | /// </summary> | ||||
| @@ -23,6 +40,16 @@ namespace Tensorflow | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TFE_DeleteContextOptions(IntPtr options); | public static extern void TFE_DeleteContextOptions(IntPtr options); | ||||
| /// <summary> | |||||
| /// Configure device placement policy logging for the eager executor. Note this | |||||
| /// policy is applied to any subsequent op executions. | |||||
| /// </summary> | |||||
| /// <param name="ctx"></param> | |||||
| /// <param name="enable"></param> | |||||
| /// <param name="status"></param> | |||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern void TFE_ContextSetLogDevicePlacement(SafeContextHandle ctx, bool enable, SafeStatusHandle status); | |||||
| /// <summary> | /// <summary> | ||||
| /// | /// | ||||
| /// </summary> | /// </summary> | ||||
| @@ -0,0 +1,11 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow.Framework | |||||
| { | |||||
| public class ConfigImpl | |||||
| { | |||||
| } | |||||
| } | |||||
| @@ -2,7 +2,6 @@ | |||||
| using System; | using System; | ||||
| using System.Linq; | using System.Linq; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Contrib.Learn.Estimators; | |||||
| namespace Tensorflow.Framework | namespace Tensorflow.Framework | ||||
| { | { | ||||
| @@ -24,6 +23,36 @@ namespace Tensorflow.Framework | |||||
| } | } | ||||
| } | } | ||||
| public static bool is_compatible_with(this Tensor self, Tensor other) | |||||
| { | |||||
| bool _shape_is_compatible_0dim(Shape _this, Shape _other) | |||||
| { | |||||
| var __other = tensor_shape.as_shape(_other); | |||||
| if (_this.Dimensions == null || __other.dims == null) | |||||
| return true; | |||||
| if (_this.NDim != __other.ndim) | |||||
| return false; | |||||
| foreach (var (x_dim, y_dim) in _this.Dimensions.Zip(__other.dims, (x_dim, y_dim) => (x_dim, y_dim))) | |||||
| { | |||||
| if (x_dim != y_dim) | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| if (other.is_sparse()) | |||||
| { | |||||
| return self.dtype.is_compatible_with(other.dtype); | |||||
| } | |||||
| return self.dtype.is_compatible_with(other.dtype) && | |||||
| _shape_is_compatible_0dim(self.shape, other.shape) && | |||||
| !self.is_sparse(); | |||||
| } | |||||
| public static Dimension dimension_at_index(TensorShape shape, int index) | public static Dimension dimension_at_index(TensorShape shape, int index) | ||||
| { | { | ||||
| return shape.rank < 0 ? | return shape.rank < 0 ? | ||||
| @@ -122,6 +122,7 @@ namespace Tensorflow | |||||
| private static EagerTensor convert_to_eager_tensor(object value, Context ctx, TF_DataType dtype = TF_DataType.DtInvalid) | private static EagerTensor convert_to_eager_tensor(object value, Context ctx, TF_DataType dtype = TF_DataType.DtInvalid) | ||||
| { | { | ||||
| ctx.ensure_initialized(); | |||||
| // convert data type | // convert data type | ||||
| if (dtype != TF_DataType.DtInvalid && | if (dtype != TF_DataType.DtInvalid && | ||||
| value.GetType().Name != "NDArray" && | value.GetType().Name != "NDArray" && | ||||
| @@ -53,7 +53,7 @@ namespace Tensorflow | |||||
| .CreateLogger(); | .CreateLogger(); | ||||
| Status = new Status(); | Status = new Status(); | ||||
| Context = new Context(new ContextOptions(), Status); | |||||
| Context = new Context(); | |||||
| OpDefLib = new OpDefLibrary(); | OpDefLib = new OpDefLibrary(); | ||||
| ConstructThreadingObjects(); | ConstructThreadingObjects(); | ||||
| InitGradientEnvironment(); | InitGradientEnvironment(); | ||||
| @@ -19,7 +19,7 @@ using System.Threading; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| public partial class tensorflow : ITensorFlowObject | |||||
| public partial class tensorflow | |||||
| { | { | ||||
| protected ThreadLocal<Session> defaultSessionFactory; | protected ThreadLocal<Session> defaultSessionFactory; | ||||
| @@ -16,6 +16,12 @@ namespace Tensorflow.Benchmark.Leak | |||||
| [Benchmark] | [Benchmark] | ||||
| public void Run() | public void Run() | ||||
| { | { | ||||
| tf.debugging.set_log_device_placement(true); | |||||
| var a = tf.constant(3.0); | |||||
| var b = tf.constant(2.0); | |||||
| var c = tf.multiply(a, b); | |||||
| int num = 50, width = 64, height = 64; | int num = 50, width = 64, height = 64; | ||||
| // if width = 128, height = 128, the exception occurs faster | // if width = 128, height = 128, the exception occurs faster | ||||
| @@ -47,7 +53,7 @@ namespace Tensorflow.Benchmark.Leak | |||||
| optimizer: keras.optimizers.RMSprop(), | optimizer: keras.optimizers.RMSprop(), | ||||
| metrics: new[] { "accuracy" }); | metrics: new[] { "accuracy" }); | ||||
| model.fit(inputImages, outLables, batch_size: 1, epochs: 200); | |||||
| model.fit(inputImages, outLables, batch_size: 32, epochs: 200); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||