Browse Source

Be able to set Keras session options #712

tags/keras_v0.3.0
Oceania2018 4 years ago
parent
commit
734fe29c02
10 changed files with 141 additions and 27 deletions
  1. +25
    -0
      docs/RELEASE.md
  2. +14
    -9
      src/TensorFlowNET.Core/Contexts/Context.Config.cs
  3. +51
    -0
      src/TensorFlowNET.Core/Contexts/Context.Device.cs
  4. +6
    -4
      src/TensorFlowNET.Core/Contexts/Context.cs
  5. +4
    -6
      src/TensorFlowNET.Core/Contexts/FunctionCallOptions.cs
  6. +15
    -0
      src/TensorFlowNET.Core/Device/PhysicalDevice.cs
  7. +7
    -0
      src/TensorFlowNET.Core/Eager/c_api.eager.cs
  8. +16
    -0
      src/TensorFlowNET.Core/Framework/ConfigImpl.cs
  9. +1
    -3
      src/TensorFlowNET.Keras/Datasets/Cifar10.cs
  10. +2
    -5
      src/TensorFlowNet.Benchmarks/Leak/GpuLeakByCNN.cs

+ 25
- 0
docs/RELEASE.md View File

@@ -0,0 +1,25 @@
# Release Notes

**Thanks to our Contributors!**

This release contains contributions from many people at SciSharp as well as the external contributors.

**Release Date 01/09/2021**

### TensorFlow.Binding v0.32.0

* Fix input `dtype` for `MapDataset`.
* Fix `image_dataset_from_directory` function.
* Fix `tf.transpose`.
* Add `array_ops.where_v2`, `array_ops.select_v2`, `array_ops.softplus`.
* Add `dataset.dataset_cardinality`.

### TensorFlow.Keras v0.3.0

* Fix `weight` init value for `double` type in `compute_weighted_loss`.
* Add `MeanSquaredError `, `MeanAbsolutePercentageError `, `MeanAbsoluteError` and `MeanSquaredLogarithmicError` loss functions.
* `Sequential` model API works.
* Add `ShellProgressBar` to show training progress better.




+ 14
- 9
src/TensorFlowNET.Core/Contexts/Context.Config.cs View File

@@ -16,6 +16,7 @@

using System;
using System.Diagnostics;
using System.Linq;

namespace Tensorflow.Contexts
{
@@ -24,24 +25,28 @@ namespace Tensorflow.Contexts
/// </summary>
public sealed partial class Context
{
ConfigProto _config;

ConfigProto config()
public ConfigProto Config { get; set; } = new ConfigProto
{
var config = new ConfigProto()
GpuOptions = new GPUOptions
{
LogDevicePlacement = _log_device_placement,
GpuOptions = _compute_gpu_options()
};
}
};

return config;
ConfigProto MergeConfig()
{
Config.LogDevicePlacement = _log_device_placement;
// var gpu_options = _compute_gpu_options();
// Config.GpuOptions.AllowGrowth = gpu_options.AllowGrowth;
return Config;
}

GPUOptions _compute_gpu_options()
{
// By default, TensorFlow maps nearly all of the GPU memory of all GPUs
// https://www.tensorflow.org/guide/gpu
return new GPUOptions()
{
AllowGrowth = get_memory_growth("GPU")
};
}
}


+ 51
- 0
src/TensorFlowNET.Core/Contexts/Context.Device.cs View File

@@ -20,6 +20,8 @@ using System.Linq;
using Tensorflow.Eager;
using static Tensorflow.Binding;
using Google.Protobuf;
using Tensorflow.Device;
using System.Collections.Generic;

namespace Tensorflow.Contexts
{
@@ -30,6 +32,7 @@ namespace Tensorflow.Contexts
{
ContextDevicePlacementPolicy _device_policy;
bool _log_device_placement;
Dictionary<PhysicalDevice, bool> _memory_growth_map = new Dictionary<PhysicalDevice, bool>();

public void log_device_placement(bool enable)
{
@@ -38,5 +41,53 @@ namespace Tensorflow.Contexts
_log_device_placement = enable;
// _thread_local_data.function_call_options = null;
}

public bool get_memory_growth(string device_type)
{
foreach(var map in _memory_growth_map)
{
if (map.Key.DeviceType == device_type)
return map.Value;
}
return false;
}

public void set_memory_growth(PhysicalDevice device, bool enable)
{
_memory_growth_map[device] = enable;
}

public PhysicalDevice[] list_physical_devices(string device_type = null)
{
using var opts = c_api.TFE_NewContextOptions();
using var ctx = c_api.TFE_NewContext(opts, tf.Status.Handle);
using var devices = c_api.TFE_ContextListDevices(ctx, tf.Status.Handle);
tf.Status.Check(true);

int num_devices = c_api.TF_DeviceListCount(devices);
var results = new List<PhysicalDevice>();
for (int i = 0; i < num_devices; ++i)
{
var dev_type = c_api.StringPiece(c_api.TF_DeviceListType(devices, i, tf.Status.Handle));
tf.Status.Check(true);

if (dev_type.StartsWith("XLA"))
continue;

if (device_type == null || dev_type == device_type)
{
var dev_name = c_api.TF_DeviceListName(devices, i, tf.Status.Handle);
tf.Status.Check(true);

results.Add(new PhysicalDevice
{
DeviceName = dev_name,
DeviceType = dev_type
});
}
}

return results.ToArray();
}
}
}

+ 6
- 4
src/TensorFlowNET.Core/Contexts/Context.cs View File

@@ -57,9 +57,9 @@ namespace Tensorflow.Contexts
if (initialized)
return;

_config = config();
var config_str = _config.ToByteArray();
Config = MergeConfig();
FunctionCallOptions.Config = Config;
var config_str = Config.ToByteArray();
using var opts = new ContextOptions();
using var status = new Status();
c_api.TFE_ContextOptionsSetConfig(opts.Handle, config_str, (ulong)config_str.Length, status.Handle);
@@ -82,7 +82,9 @@ namespace Tensorflow.Contexts
/// <returns></returns>
[DebuggerStepThrough]
public bool executing_eagerly()
=> context_switches.Current().EagerMode;
{
return context_switches.Current().EagerMode;
}

public bool is_build_function()
=> context_switches.Current().IsBuildingFunction;


+ 4
- 6
src/TensorFlowNET.Core/Contexts/FunctionCallOptions.cs View File

@@ -2,19 +2,17 @@
using System.Collections.Generic;
using System.Text;
using Google.Protobuf;
using Google.Protobuf.Collections;
using static Tensorflow.Binding;

namespace Tensorflow.Contexts
{
public class FunctionCallOptions
{
public ConfigProto Config { get; set; }

public string config_proto_serialized()
{
var config = new ConfigProto
{
AllowSoftPlacement = true,
};
return config.ToByteString().ToStringUtf8();
return Config.ToByteString().ToStringUtf8();
}
}
}

+ 15
- 0
src/TensorFlowNET.Core/Device/PhysicalDevice.cs View File

@@ -0,0 +1,15 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Device
{
public class PhysicalDevice
{
public string DeviceName { get; set; }
public string DeviceType { get; set; }

public override string ToString()
=> $"{DeviceType}: {DeviceName}";
}
}

+ 7
- 0
src/TensorFlowNET.Core/Eager/c_api.eager.cs View File

@@ -380,6 +380,13 @@ namespace Tensorflow
[DllImport(TensorFlowLibName)]
public static extern SafeDeviceListHandle TFE_ContextListDevices(SafeContextHandle ctx, SafeStatusHandle status);

/// <summary>
/// Clears the internal caches in the TFE context. Useful when reseeding random ops.
/// </summary>
/// <param name="ctx">TFE_Context*</param>
[DllImport(TensorFlowLibName)]
public static extern void TFE_ContextClearCaches(SafeContextHandle ctx);

/// <summary>
///
/// </summary>


+ 16
- 0
src/TensorFlowNET.Core/Framework/ConfigImpl.cs View File

@@ -1,11 +1,27 @@
using System;
using System.Collections.Generic;
using System.Text;
using static Tensorflow.Binding;
using Tensorflow.Device;

namespace Tensorflow.Framework
{
public class ConfigImpl
{
/// <summary>
/// Return a list of physical devices visible to the host runtime.
/// </summary>
/// <param name="device_type">CPU, GPU, TPU</param>
/// <returns></returns>
public PhysicalDevice[] list_physical_devices(string device_type = null)
=> tf.Context.list_physical_devices(device_type: device_type);

public Experimental experimental => new Experimental();

public class Experimental
{
public void set_memory_growth(PhysicalDevice device, bool enable)
=> tf.Context.set_memory_growth(device, enable);
}
}
}

+ 1
- 3
src/TensorFlowNET.Keras/Datasets/Cifar10.cs View File

@@ -124,10 +124,8 @@ namespace Tensorflow.Keras.Datasets
string Download()
{
var dst = Path.Combine(Path.GetTempPath(), dest_folder);
Directory.CreateDirectory(dst);

Web.Download(origin_folder + file_name, dst, file_name);
Compress.ExtractTGZ(Path.Combine(Path.GetTempPath(), file_name), dst);
Compress.ExtractTGZ(Path.Combine(dst, file_name), dst);

return Path.Combine(dst, "cifar-10-batches-py");
}


+ 2
- 5
src/TensorFlowNet.Benchmarks/Leak/GpuLeakByCNN.cs View File

@@ -16,11 +16,8 @@ namespace Tensorflow.Benchmark.Leak
[Benchmark]
public void Run()
{
tf.debugging.set_log_device_placement(true);
var a = tf.constant(3.0);
var b = tf.constant(2.0);
var c = tf.multiply(a, b);
// tf.debugging.set_log_device_placement(true);
tf.Context.Config.GpuOptions.AllowGrowth = true;

int num = 50, width = 64, height = 64;
// if width = 128, height = 128, the exception occurs faster


Loading…
Cancel
Save