Browse Source

Be able to set Keras session options #712

tags/keras_v0.3.0
Oceania2018 4 years ago
parent
commit
734fe29c02
10 changed files with 141 additions and 27 deletions
  1. +25
    -0
      docs/RELEASE.md
  2. +14
    -9
      src/TensorFlowNET.Core/Contexts/Context.Config.cs
  3. +51
    -0
      src/TensorFlowNET.Core/Contexts/Context.Device.cs
  4. +6
    -4
      src/TensorFlowNET.Core/Contexts/Context.cs
  5. +4
    -6
      src/TensorFlowNET.Core/Contexts/FunctionCallOptions.cs
  6. +15
    -0
      src/TensorFlowNET.Core/Device/PhysicalDevice.cs
  7. +7
    -0
      src/TensorFlowNET.Core/Eager/c_api.eager.cs
  8. +16
    -0
      src/TensorFlowNET.Core/Framework/ConfigImpl.cs
  9. +1
    -3
      src/TensorFlowNET.Keras/Datasets/Cifar10.cs
  10. +2
    -5
      src/TensorFlowNet.Benchmarks/Leak/GpuLeakByCNN.cs

+ 25
- 0
docs/RELEASE.md View File

@@ -0,0 +1,25 @@
# Release Notes

**Thanks to our Contributors!**

This release contains contributions from many people at SciSharp as well as the external contributors.

**Release Date 01/09/2021**

### TensorFlow.Binding v0.32.0

* Fix input `dtype` for `MapDataset`.
* Fix `image_dataset_from_directory` function.
* Fix `tf.transpose`.
* Add `array_ops.where_v2`, `array_ops.select_v2`, `array_ops.softplus`.
* Add `dataset.dataset_cardinality`.

### TensorFlow.Keras v0.3.0

* Fix `weight` init value for `double` type in `compute_weighted_loss`.
* Add `MeanSquaredError `, `MeanAbsolutePercentageError `, `MeanAbsoluteError` and `MeanSquaredLogarithmicError` loss functions.
* `Sequential` model API works.
* Add `ShellProgressBar` to show training progress better.




+ 14
- 9
src/TensorFlowNET.Core/Contexts/Context.Config.cs View File

@@ -16,6 +16,7 @@


using System; using System;
using System.Diagnostics; using System.Diagnostics;
using System.Linq;


namespace Tensorflow.Contexts namespace Tensorflow.Contexts
{ {
@@ -24,24 +25,28 @@ namespace Tensorflow.Contexts
/// </summary> /// </summary>
public sealed partial class Context public sealed partial class Context
{ {
ConfigProto _config;

ConfigProto config()
public ConfigProto Config { get; set; } = new ConfigProto
{ {
var config = new ConfigProto()
GpuOptions = new GPUOptions
{ {
LogDevicePlacement = _log_device_placement,
GpuOptions = _compute_gpu_options()
};
}
};


return config;
ConfigProto MergeConfig()
{
Config.LogDevicePlacement = _log_device_placement;
// var gpu_options = _compute_gpu_options();
// Config.GpuOptions.AllowGrowth = gpu_options.AllowGrowth;
return Config;
} }


GPUOptions _compute_gpu_options() GPUOptions _compute_gpu_options()
{ {
// By default, TensorFlow maps nearly all of the GPU memory of all GPUs
// https://www.tensorflow.org/guide/gpu
return new GPUOptions() return new GPUOptions()
{ {
AllowGrowth = get_memory_growth("GPU")
}; };
} }
} }


+ 51
- 0
src/TensorFlowNET.Core/Contexts/Context.Device.cs View File

@@ -20,6 +20,8 @@ using System.Linq;
using Tensorflow.Eager; using Tensorflow.Eager;
using static Tensorflow.Binding; using static Tensorflow.Binding;
using Google.Protobuf; using Google.Protobuf;
using Tensorflow.Device;
using System.Collections.Generic;


namespace Tensorflow.Contexts namespace Tensorflow.Contexts
{ {
@@ -30,6 +32,7 @@ namespace Tensorflow.Contexts
{ {
ContextDevicePlacementPolicy _device_policy; ContextDevicePlacementPolicy _device_policy;
bool _log_device_placement; bool _log_device_placement;
Dictionary<PhysicalDevice, bool> _memory_growth_map = new Dictionary<PhysicalDevice, bool>();


public void log_device_placement(bool enable) public void log_device_placement(bool enable)
{ {
@@ -38,5 +41,53 @@ namespace Tensorflow.Contexts
_log_device_placement = enable; _log_device_placement = enable;
// _thread_local_data.function_call_options = null; // _thread_local_data.function_call_options = null;
} }

public bool get_memory_growth(string device_type)
{
foreach(var map in _memory_growth_map)
{
if (map.Key.DeviceType == device_type)
return map.Value;
}
return false;
}

public void set_memory_growth(PhysicalDevice device, bool enable)
{
_memory_growth_map[device] = enable;
}

public PhysicalDevice[] list_physical_devices(string device_type = null)
{
using var opts = c_api.TFE_NewContextOptions();
using var ctx = c_api.TFE_NewContext(opts, tf.Status.Handle);
using var devices = c_api.TFE_ContextListDevices(ctx, tf.Status.Handle);
tf.Status.Check(true);

int num_devices = c_api.TF_DeviceListCount(devices);
var results = new List<PhysicalDevice>();
for (int i = 0; i < num_devices; ++i)
{
var dev_type = c_api.StringPiece(c_api.TF_DeviceListType(devices, i, tf.Status.Handle));
tf.Status.Check(true);

if (dev_type.StartsWith("XLA"))
continue;

if (device_type == null || dev_type == device_type)
{
var dev_name = c_api.TF_DeviceListName(devices, i, tf.Status.Handle);
tf.Status.Check(true);

results.Add(new PhysicalDevice
{
DeviceName = dev_name,
DeviceType = dev_type
});
}
}

return results.ToArray();
}
} }
} }

+ 6
- 4
src/TensorFlowNET.Core/Contexts/Context.cs View File

@@ -57,9 +57,9 @@ namespace Tensorflow.Contexts
if (initialized) if (initialized)
return; return;


_config = config();
var config_str = _config.ToByteArray();
Config = MergeConfig();
FunctionCallOptions.Config = Config;
var config_str = Config.ToByteArray();
using var opts = new ContextOptions(); using var opts = new ContextOptions();
using var status = new Status(); using var status = new Status();
c_api.TFE_ContextOptionsSetConfig(opts.Handle, config_str, (ulong)config_str.Length, status.Handle); c_api.TFE_ContextOptionsSetConfig(opts.Handle, config_str, (ulong)config_str.Length, status.Handle);
@@ -82,7 +82,9 @@ namespace Tensorflow.Contexts
/// <returns></returns> /// <returns></returns>
[DebuggerStepThrough] [DebuggerStepThrough]
public bool executing_eagerly() public bool executing_eagerly()
=> context_switches.Current().EagerMode;
{
return context_switches.Current().EagerMode;
}


public bool is_build_function() public bool is_build_function()
=> context_switches.Current().IsBuildingFunction; => context_switches.Current().IsBuildingFunction;


+ 4
- 6
src/TensorFlowNET.Core/Contexts/FunctionCallOptions.cs View File

@@ -2,19 +2,17 @@
using System.Collections.Generic; using System.Collections.Generic;
using System.Text; using System.Text;
using Google.Protobuf; using Google.Protobuf;
using Google.Protobuf.Collections;
using static Tensorflow.Binding;


namespace Tensorflow.Contexts namespace Tensorflow.Contexts
{ {
public class FunctionCallOptions public class FunctionCallOptions
{ {
public ConfigProto Config { get; set; }

public string config_proto_serialized() public string config_proto_serialized()
{ {
var config = new ConfigProto
{
AllowSoftPlacement = true,
};
return config.ToByteString().ToStringUtf8();
return Config.ToByteString().ToStringUtf8();
} }
} }
} }

+ 15
- 0
src/TensorFlowNET.Core/Device/PhysicalDevice.cs View File

@@ -0,0 +1,15 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Device
{
public class PhysicalDevice
{
public string DeviceName { get; set; }
public string DeviceType { get; set; }

public override string ToString()
=> $"{DeviceType}: {DeviceName}";
}
}

+ 7
- 0
src/TensorFlowNET.Core/Eager/c_api.eager.cs View File

@@ -380,6 +380,13 @@ namespace Tensorflow
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
public static extern SafeDeviceListHandle TFE_ContextListDevices(SafeContextHandle ctx, SafeStatusHandle status); public static extern SafeDeviceListHandle TFE_ContextListDevices(SafeContextHandle ctx, SafeStatusHandle status);


/// <summary>
/// Clears the internal caches in the TFE context. Useful when reseeding random ops.
/// </summary>
/// <param name="ctx">TFE_Context*</param>
[DllImport(TensorFlowLibName)]
public static extern void TFE_ContextClearCaches(SafeContextHandle ctx);

/// <summary> /// <summary>
/// ///
/// </summary> /// </summary>


+ 16
- 0
src/TensorFlowNET.Core/Framework/ConfigImpl.cs View File

@@ -1,11 +1,27 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Text; using System.Text;
using static Tensorflow.Binding;
using Tensorflow.Device;


namespace Tensorflow.Framework namespace Tensorflow.Framework
{ {
public class ConfigImpl public class ConfigImpl
{ {
/// <summary>
/// Return a list of physical devices visible to the host runtime.
/// </summary>
/// <param name="device_type">CPU, GPU, TPU</param>
/// <returns></returns>
public PhysicalDevice[] list_physical_devices(string device_type = null)
=> tf.Context.list_physical_devices(device_type: device_type);


public Experimental experimental => new Experimental();

public class Experimental
{
public void set_memory_growth(PhysicalDevice device, bool enable)
=> tf.Context.set_memory_growth(device, enable);
}
} }
} }

+ 1
- 3
src/TensorFlowNET.Keras/Datasets/Cifar10.cs View File

@@ -124,10 +124,8 @@ namespace Tensorflow.Keras.Datasets
string Download() string Download()
{ {
var dst = Path.Combine(Path.GetTempPath(), dest_folder); var dst = Path.Combine(Path.GetTempPath(), dest_folder);
Directory.CreateDirectory(dst);

Web.Download(origin_folder + file_name, dst, file_name); Web.Download(origin_folder + file_name, dst, file_name);
Compress.ExtractTGZ(Path.Combine(Path.GetTempPath(), file_name), dst);
Compress.ExtractTGZ(Path.Combine(dst, file_name), dst);


return Path.Combine(dst, "cifar-10-batches-py"); return Path.Combine(dst, "cifar-10-batches-py");
} }


+ 2
- 5
src/TensorFlowNet.Benchmarks/Leak/GpuLeakByCNN.cs View File

@@ -16,11 +16,8 @@ namespace Tensorflow.Benchmark.Leak
[Benchmark] [Benchmark]
public void Run() public void Run()
{ {
tf.debugging.set_log_device_placement(true);
var a = tf.constant(3.0);
var b = tf.constant(2.0);
var c = tf.multiply(a, b);
// tf.debugging.set_log_device_placement(true);
tf.Context.Config.GpuOptions.AllowGrowth = true;


int num = 50, width = 64, height = 64; int num = 50, width = 64, height = 64;
// if width = 128, height = 128, the exception occurs faster // if width = 128, height = 128, the exception occurs faster


Loading…
Cancel
Save