diff --git a/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs b/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs index c7314461..9280179c 100644 --- a/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs +++ b/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs @@ -54,8 +54,11 @@ public static class SaveUtilV1 var g = to_graph.as_default(); var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view, object_map, call_with_mapped_captures, saveables_cache); - tf.device("/cpu:0"); - var object_graph_tensor = constant_op.constant(graph_proto.ToByteArray()); + var object_graph_tensor = tf_with(ops.device("/cpu:0"), _ => + { + // TODO(Rinne): locate the error that causes transferring TF_STRING to this function throws an exception. + return constant_op.constant(graph_proto.ToByteArray()); + }); named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY)); g.Exit(); return (named_saveable_objects, registered_savers); @@ -66,8 +69,10 @@ public static class SaveUtilV1 { var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view, object_map, call_with_mapped_captures, saveables_cache); - tf.device("/cpu:0"); - var object_graph_tensor = constant_op.constant(graph_proto.ToString(), TF_DataType.TF_STRING); + var object_graph_tensor = tf_with(ops.device("/cpu:0"), _ => + { + return constant_op.constant(graph_proto.ToString()); + }); named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY)); return (named_saveable_objects, registered_savers); } diff --git a/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs b/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs index c736c164..30d45e82 100644 --- a/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs +++ b/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs @@ -59,8 +59,10 @@ public class TrackableSaver if(object_graph_tensor is null) { - tf.device("/cpu:0"); - object_graph_tensor = constant_op.constant(graph_proto.ToByteArray()); + tf_with(ops.device("/cpu:0"), _ => + { + object_graph_tensor = constant_op.constant(graph_proto.ToByteArray()); + }); } else { @@ -232,13 +234,15 @@ public class TrackableSaver Tensor object_graph_string = reader.GetTensor(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY, dtype: TF_DataType.TF_STRING); Dictionary file_prefix_feed_dict; - Tensor file_prefix_tensor; + Tensor file_prefix_tensor = null; if (graph_building) { if(_file_prefix_placeholder is null) { - tf.device("/cpu:0"); - _file_prefix_placeholder = constant_op.constant("model"); + _file_prefix_placeholder = tf_with(ops.device("/cpu:0"), _ => + { + return constant_op.constant("model"); + }); } file_prefix_tensor = _file_prefix_placeholder; file_prefix_feed_dict = new(); @@ -246,8 +250,10 @@ public class TrackableSaver } else { - tf.device("/cpu:0"); - file_prefix_tensor = constant_op.constant(save_path); + file_prefix_tensor = tf_with(ops.device("/cpu:0"), _ => + { + return constant_op.constant(save_path); + }); file_prefix_feed_dict = null; } TrackableObjectGraph object_graph_proto = new(); diff --git a/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs b/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs index c383c219..a6aa7640 100644 --- a/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs +++ b/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs @@ -117,9 +117,11 @@ namespace Tensorflow.Checkpoint string restore_device = string.IsNullOrEmpty(options.experimental_io_device) ? "cpu:0": options.experimental_io_device!; - // tf python has code `with ops.device(restore_device):` here. - tf.device(restore_device); // may be risky. - var restored_tensors = gen_ops.restore_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensor_dtypes.ToArray()); + Tensor[] restored_tensors = null; + tf_with(ops.device(restore_device), _ => + { + restored_tensors = gen_ops.restore_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensor_dtypes.ToArray()); + }); Dictionary> restored_tensor_dict = new(); int idx = 0; @@ -243,11 +245,14 @@ namespace Tensorflow.Checkpoint options = new CheckpointOptions(); } - tf.device("CPU"); // may be risky. - var sharded_suffix = array_ops.where(gen_ops.regex_full_match(file_prefix, tf.constant(@"^s3://.*")), + Tensor tmp_checkpoint_prefix = null; + tf_with(ops.device("CPU"), _ => + { + var sharded_suffix = array_ops.where(gen_ops.regex_full_match(file_prefix, tf.constant(@"^s3://.*")), constant_op.constant(".part"), constant_op.constant("_temp/part")); - var tmp_checkpoint_prefix = gen_ops.string_join(new Tensor[] { file_prefix, sharded_suffix }); - IDictionary registered_paths = _registered_savers.Keys.ToDictionary(x => x, x => registered_saver_filename(file_prefix, x)); + tmp_checkpoint_prefix = gen_ops.string_join(new Tensor[] { file_prefix, sharded_suffix }); + IDictionary registered_paths = _registered_savers.Keys.ToDictionary(x => x, x => registered_saver_filename(file_prefix, x)); + }); Operation save_fn() { @@ -269,16 +274,24 @@ namespace Tensorflow.Checkpoint var saver = pair.Value; last_device = device; // skip the extra process of device name because of lack of API. - tf.device(device); - var shard_prefix = sharded_filename(tmp_checkpoint_prefix, shard, num_shards_tensor); + Tensor shard_prefix = null; + tf_with(ops.device(device), _ => + { + shard_prefix = sharded_filename(tmp_checkpoint_prefix, shard, num_shards_tensor); + }); saved_prefixes.Add(shard_prefix); - sharded_saves.Add(saver.save(shard_prefix, options)); + tf_with(ops.device(device), _ => + { + sharded_saves.Add(saver.save(shard_prefix, options)); + }); } using (var controller = ops.control_dependencies(sharded_saves.ToArray())) { string merge_device = string.IsNullOrEmpty(options.experimental_io_device) ? last_device : options.experimental_io_device; - tf.device(merge_device); - return gen_ops.merge_v2_checkpoints(saved_prefixes.ToArray(), tf.constant(file_prefix), delete_old_dirs: true); + return tf_with(ops.device(merge_device), _ => + { + return gen_ops.merge_v2_checkpoints(saved_prefixes.ToArray(), tf.constant(file_prefix), delete_old_dirs: true); + }); } } @@ -312,8 +325,9 @@ namespace Tensorflow.Checkpoint { var device = single_saver.Key; var saver = single_saver.Value; - tf.device(device); - var restored_tensor_dict = saver.restore(file_prefix, options); + tf_with(ops.device(device), _ => + { + var restored_tensor_dict = saver.restore(file_prefix, options); foreach(var pair in restored_tensor_dict) { @@ -405,21 +419,25 @@ namespace Tensorflow.Checkpoint private Tensor _traced_save(Tensor file_prefix) { var save_op = save(file_prefix); - tf.device("cpu:0"); - using (ops.control_dependencies(new object[]{ save_op })) + return tf_with(ops.device("cpu:0"), _ => { - return array_ops.identity(file_prefix); - } + return tf_with(ops.control_dependencies(new object[] { save_op }), __ => + { + return array_ops.identity(file_prefix); + }); + }); } private Tensor _traced_restore(Tensor file_prefix) { var restore_op = restore(file_prefix); - tf.device("cpu:0"); - using (ops.control_dependencies(restore_op.Values.ToArray())) + return tf_with(ops.device("cpu:0"), _ => { - return array_ops.identity(file_prefix); - } + return tf_with(ops.control_dependencies(restore_op.Values.ToArray()), __ => + { + return array_ops.identity(file_prefix); + }); + }); } public static MultiDeviceSaver from_saveables(IEnumerable saveables, IDictionary>? registered_savers = null, bool call_with_mapped_captures = false) diff --git a/src/TensorFlowNET.Core/Contexts/Context.Device.cs b/src/TensorFlowNET.Core/Contexts/Context.Device.cs index 97c550e8..32e6682e 100644 --- a/src/TensorFlowNET.Core/Contexts/Context.Device.cs +++ b/src/TensorFlowNET.Core/Contexts/Context.Device.cs @@ -21,6 +21,7 @@ using Tensorflow.Eager; using static Tensorflow.Binding; using Google.Protobuf; using Tensorflow.Device; +using Tensorflow.Exceptions; using System.Collections.Generic; namespace Tensorflow.Contexts @@ -30,10 +31,30 @@ namespace Tensorflow.Contexts /// public sealed partial class Context { + internal static Dictionary<(string, string), (string, DeviceSpec)> _device_parsing_cache = new(); + internal List _logical_devices = null; + internal List _context_devices = null; + ContextDevicePlacementPolicy _device_policy; bool _log_device_placement; + int _num_gpus; Dictionary _memory_growth_map = new Dictionary(); + public string DeviceName { get; set; } = ""; + public DeviceSpec DeviceSpec { get; set; } = null; + + internal List Devices + { + get + { + if(_context_devices is null) + { + throw new AssertionError("Context must be initialized first."); + } + return _context_devices; + } + } + public void log_device_placement(bool enable) { if (_handle != null) @@ -89,5 +110,57 @@ namespace Tensorflow.Contexts return results.ToArray(); } + + public EagerDeviceContext device(string name) + { + return new EagerDeviceContext(this, name); + } + + internal void _set_device(string device_name, DeviceSpec device_spec) + { + DeviceSpec = device_spec; + DeviceName = device_name; + } + + internal void _initialize_logical_devices() + { + List logical_devices = new(); + List context_devices = new(); + Status status = new(); + var device_list = c_api.TFE_ContextListDevices(_handle, status); + status.Check(true); + try + { + this._num_gpus = 0; + string current_job = null; + int current_task = -1; + for(int i = 0; i < c_api.TF_DeviceListCount(device_list); i++) + { + var dev_name = c_api.TF_DeviceListName(device_list, i, status); + status.Check(true); + context_devices.Add(DeviceUtils.canonical_name(dev_name)); + var spec = DeviceSpec.from_string(dev_name); + if(spec.Job == "localhost") + { + spec = spec.replace(job: null, replica: -1, task: -1); + } + logical_devices.Add(new LogicalDevice(spec.ToString(), spec.DeviceType)); + var dev_type_memory = c_api.TF_DeviceListType(device_list, i, status); + var dev_type = c_api.StringPiece(dev_type_memory); + status.Check(true); + if(dev_type == "GPU" && spec.Job == current_job && spec.Task == current_task) + { + _num_gpus++; + } + } + } + finally + { + _logical_devices = logical_devices; + _context_devices = context_devices; + } + } } + + public record class LogicalDevice(string name, string device_type); } diff --git a/src/TensorFlowNET.Core/Contexts/Context.cs b/src/TensorFlowNET.Core/Contexts/Context.cs index 7fec1e5a..0507cc2f 100644 --- a/src/TensorFlowNET.Core/Contexts/Context.cs +++ b/src/TensorFlowNET.Core/Contexts/Context.cs @@ -34,7 +34,6 @@ namespace Tensorflow.Contexts public const int EAGER_MODE = 1; int defaultExecutionMode = EAGER_MODE; - public string DeviceName { get; set; } = ""; public string ScopeName { get; set; } = ""; bool initialized = false; ContextSwitchStack context_switches; @@ -81,6 +80,9 @@ namespace Tensorflow.Contexts if (initialized) return; + Debug.Assert(_context_devices is null); + + Config = MergeConfig(); FunctionCallOptions.Config = Config; var config_str = Config.ToByteArray(); var opts = new ContextOptions(); @@ -90,6 +92,7 @@ namespace Tensorflow.Contexts c_api.TFE_ContextOptionsSetDevicePlacementPolicy(opts, _device_policy); _handle = c_api.TFE_NewContext(opts, status); status.Check(true); + _initialize_logical_devices(); initialized = true; } @@ -228,6 +231,7 @@ namespace Tensorflow.Contexts { c_api.TFE_ContextClearCaches(_handle); } + _device_parsing_cache.Clear(); } public static implicit operator SafeContextHandle(Context ctx) diff --git a/src/TensorFlowNET.Core/Contexts/EagerDeviceContext.cs b/src/TensorFlowNET.Core/Contexts/EagerDeviceContext.cs new file mode 100644 index 00000000..2d5f61cd --- /dev/null +++ b/src/TensorFlowNET.Core/Contexts/EagerDeviceContext.cs @@ -0,0 +1,71 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Device; + +namespace Tensorflow.Contexts +{ + public class EagerDeviceContext : ITensorFlowObject + { + private Context _ctx; + private string _device_name; + private Stack<(string, DeviceSpec, DeviceSpec)> _stack; + + public EagerDeviceContext(Context ctx, string device_name) + { + _ctx = ctx; + _device_name = device_name; + _stack = new Stack<(string, DeviceSpec, DeviceSpec)>(); + } + public void __enter__() + { + var ctx = _ctx; + var old_device_name = ctx.DeviceName; + var old_device_spec = ctx.DeviceSpec; + var new_device_name = _device_name; + var cache_key = (old_device_name, new_device_name); + DeviceSpec new_device_spec; + if (Context._device_parsing_cache.ContainsKey(cache_key)) + { + (new_device_name, new_device_spec) = Context._device_parsing_cache[cache_key]; + } + else + { + if(new_device_name is not null) + { + var device_spec = DeviceSpec.from_string(new_device_name); + if (!string.IsNullOrEmpty(old_device_name)) + { + new_device_spec = new DeviceSpec(old_device_spec); + } + else + { + ctx.ensure_initialized(); + new_device_spec = DeviceSpec.from_string(ctx._context_devices[0]); + } + new_device_spec = new_device_spec.make_merged_spec(device_spec); + } + else + { + new_device_spec = DeviceSpec.from_string(ctx._context_devices[0]); + } + new_device_name = new_device_spec.ToString(); + Context._device_parsing_cache[cache_key] = (new_device_name, new_device_spec); + } + ctx._set_device(new_device_name, new_device_spec); + _stack.Push((old_device_name, old_device_spec, new_device_spec)); + } + + public void __exit__() + { + var ctx = _ctx; + var (old_device_name, old_device_spec, new_device_spec) = _stack.Pop(); + ctx._set_device(old_device_name, old_device_spec); + } + + public void Dispose() + { + + } + } +} diff --git a/src/TensorFlowNET.Core/Device/DeviceSpec.cs b/src/TensorFlowNET.Core/Device/DeviceSpec.cs new file mode 100644 index 00000000..f4ea8cf0 --- /dev/null +++ b/src/TensorFlowNET.Core/Device/DeviceSpec.cs @@ -0,0 +1,205 @@ +using System; +using System.Collections.Generic; +using System.Text; +using System.Threading.Tasks; + +namespace Tensorflow.Device +{ + public class DeviceSpec + { + private static Dictionary _STRING_TO_COMPONENTS_CACHE = new(); + private static Dictionary _COMPONENTS_TO_STRING_CACHE = new(); + private string _job; + private int _replica; + private int _task; + private string _device_type; + private int _device_index; + private string _as_string; + + public string Job => _job; + public int Replica => _replica; + public int Task => _task; + public string DeviceType => _device_type; + public int DeviceIndex => _device_index; + + public DeviceSpec(string job = null, int replica = -1, int task = -1, + string device_type = null, int device_index = -1) + { + _job = job; + _replica = replica; + _task = task; + _device_type = device_type; + _device_index = device_index; + _as_string = _components_to_string(job, replica, task, device_type, _device_index); + + } + + public DeviceSpec(DeviceSpec other) + { + _job = other._job; + _replica = other._replica; + _task = other._task; + _device_type = other._device_type; + _device_index = other._device_index; + _as_string = other._as_string; + } + + protected DeviceSpec(Components com) + { + _job = com.Job; + _replica = com.Replica; + _task = com.Task; + _device_type = com.DeviceType; + _device_index = com.DeviceIndex; + _as_string = _components_to_string(_job, _replica, _task, _device_type, _device_index); + } + + public DeviceSpec replace(string job = null, int replica = -1, int task = -1, + string device_type = null, int device_index = -1) + { + job = job ?? _job; + replica = replica == -1 ? _replica : replica; + task = task == -1 ? _task : task; + device_type = device_type ?? _device_type; + device_index = device_index == -1 ? _device_index : device_index; + return new DeviceSpec(job, replica, task, device_type, device_index); + } + + public static DeviceSpec from_string(string spec) + { + var components = _string_to_components(spec); + return new DeviceSpec(components.Job, components.Replica, components.Task, components.DeviceType, components.DeviceIndex); + } + + public DeviceSpec make_merged_spec(DeviceSpec dev) + { + return new DeviceSpec(_get_combined_properties(dev)); + } + + private Components _get_combined_properties(DeviceSpec dev) + { + return new Components( + dev.Job ?? _job, + dev.Replica == -1 ? _replica : dev.Replica, + dev.Task == -1 ? _task : dev.Task, + dev.DeviceType ?? _device_type, + dev.DeviceIndex == -1 ? _device_index : dev.DeviceIndex + ); + } + + private static string _components_to_string(string job, int replica, int task, string device_type, int device_index) + { + var key = new Components(job, replica, task, device_type, device_index); + if(_COMPONENTS_TO_STRING_CACHE.TryGetValue(key, out var cache_result)) + { + return cache_result; + } + + StringBuilder output = new(); + if(job is not null) + { + output.Append($"/job:{job}"); + } + if(replica != -1) + { + output.Append($"/replica:{replica}"); + } + if(task != -1) + { + output.Append($"/task:{task}"); + } + if (device_type is not null) + { + string device_index_string = "*"; + if (device_index != -1) + { + device_index_string = device_index.ToString(); + } + output.Append($"/device:{device_type}:{device_index_string}"); + } + var result = output.ToString(); + _COMPONENTS_TO_STRING_CACHE[key] = result; + return result; + } + + private static Components _string_to_components(string spec) + { + if(_STRING_TO_COMPONENTS_CACHE.TryGetValue(spec, out var cached_result)) + { + return cached_result; + } + var raw_spec = spec; + var splits = spec.Split('/').Select(x => x.Split(':')); + var valid_device_types = _get_valid_device_types(); + string job = null, device_type = null; + int replica = -1, task = -1, device_index = -1; + foreach (var y in splits) + { + var ly = y.Length; + if (ly > 0) + { + if(ly == 2 && y[0] == "job") + { + job = y[1]; + } + else if(ly == 2 && y[0] == "replica") + { + replica = int.Parse(y[1]); + } + else if(ly == 2 && y[0] == "task") + { + task = int.Parse(y[1]); + } + else if((ly == 1 || ly == 2) && valid_device_types.Contains(y[0].ToUpper())) + { + if (device_type is not null) + { + throw new ValueError($"Multiple device types are not allowed " + + $"while parsing the device spec: {spec}."); + } + device_type = y[0].ToUpper(); + if(ly == 2 && y[1] != "*") + { + device_index = int.Parse(y[1]); + } + } + else if(ly == 3 && y[0] == "device") + { + if(device_type is not null) + { + throw new ValueError($"Multiple device types are not allowed " + + $"while parsing the device spec: {spec}."); + } + device_type = y[1]; + if (y[2] != "*") + { + device_index = int.Parse(y[2]); + } + } + else if (y[0] != "") + { + throw new ValueError($"Unknown attribute '{y[0]}' is encountered " + + $"while parsing the device spec: {spec}."); + } + } + } + + var output = new Components(job, replica, task, device_type, device_index); + _STRING_TO_COMPONENTS_CACHE[raw_spec] = output; + return output; + } + + private static HashSet _get_valid_device_types() + { + // TODO(Rinne): revise it to calling C API (need customized API). + return new HashSet(new string[] { "CPU", "GPU" }); + } + + public override string ToString() + { + return _as_string; + } + + protected record class Components(string Job, int Replica, int Task, string DeviceType, int DeviceIndex); + } +} diff --git a/src/TensorFlowNET.Core/Device/DeviceUtils.cs b/src/TensorFlowNET.Core/Device/DeviceUtils.cs new file mode 100644 index 00000000..8f11e6c8 --- /dev/null +++ b/src/TensorFlowNET.Core/Device/DeviceUtils.cs @@ -0,0 +1,26 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Device +{ + internal static class DeviceUtils + { + public static string canonical_name(string device) + { + if(device is null) + { + return ""; + } + return DeviceSpec.from_string(device).ToString(); + } + public static string canonical_name(DeviceSpec device) + { + if (device is null) + { + return ""; + } + return device.ToString(); + } + } +} diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index f443bcff..eb8df581 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -22,6 +22,7 @@ using System.Linq; using Tensorflow.Framework; using Tensorflow.Functions; using Tensorflow.Common.Extensions; +using Tensorflow.Graphs; using static Tensorflow.Binding; namespace Tensorflow @@ -344,9 +345,15 @@ namespace Tensorflow return op; } - public void device(string device_name) + public ITensorFlowObject device(string device_name) { - + return new GraphDeviceContext(this, device_name); + } + + private void add_device_to_stack(string device_name, int offset = 0) + { + // TODO(Rinne): deal with device spec. + int total_offset = offset + 1; } private void _create_op_helper(Operation op, bool compute_device = true) diff --git a/src/TensorFlowNET.Core/Graphs/GraphDeviceContext.cs b/src/TensorFlowNET.Core/Graphs/GraphDeviceContext.cs new file mode 100644 index 00000000..2754c2b3 --- /dev/null +++ b/src/TensorFlowNET.Core/Graphs/GraphDeviceContext.cs @@ -0,0 +1,31 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Graphs +{ + public class GraphDeviceContext : ITensorFlowObject + { + private Graph _graph; + + public GraphDeviceContext(Graph graph, string device_name) + { + _graph = graph; + } + + public void __enter__() + { + + } + + public void __exit__() + { + + } + + public void Dispose() + { + + } + } +} diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs index 2b864f90..9d69d5d0 100644 --- a/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs +++ b/src/TensorFlowNET.Core/Keras/Layers/ILayer.cs @@ -17,7 +17,7 @@ namespace Tensorflow.Keras List TrainableVariables { get; } List TrainableWeights { get; } List NonTrainableWeights { get; } - List Weights { get; } + List Weights { get; set; } Shape OutputShape { get; } Shape BatchInputShape { get; } TensorShapeConfig BuildInputShape { get; } diff --git a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs index 87b595b6..bc4daf13 100644 --- a/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs +++ b/src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs @@ -84,6 +84,8 @@ namespace Tensorflow protected bool built = false; public bool Built => built; + List ILayer.Weights { get => throw new NotImplementedException(); set => throw new NotImplementedException(); } + public RnnCell(bool trainable = true, string name = null, TF_DataType dtype = TF_DataType.DtInvalid, diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs index 6f26e07b..2eecfabf 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs @@ -413,8 +413,10 @@ namespace Tensorflow { var variables_path = SavedModelUtils.get_variables_path(_export_dir); var saver = new TrackableSaver(new ObjectGraphView(get(0))); - tf.device("CPU"); - saver.FilePrefixPlaceHolder = constant_op.constant(variables_path); + tf_with(ops.device("CPU"), _ => + { + saver.FilePrefixPlaceHolder = constant_op.constant(variables_path); + }); LoadStatus load_status; if (_save_options.allow_partial_checkpoint) { @@ -600,14 +602,16 @@ namespace Tensorflow if (load_with_device) { - tf.device(saved_device); - return (new UninitializedVariable( - shape: new Shape(proto.Shape.Dim.Select(x => (int)x.Size).ToArray()), - dtype: (TF_DataType)proto.Dtype, - name: name, - trainable: trainable, - aggregation: aggregation - ), setattr); + return tf_with(ops.device(saved_device), _ => + { + return (new UninitializedVariable( + shape: new Shape(proto.Shape.Dim.Select(x => (int)x.Size).ToArray()), + dtype: (TF_DataType)proto.Dtype, + name: name, + trainable: trainable, + aggregation: aggregation + ), setattr); + }); } else { diff --git a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs index 74ce4e8a..64728020 100644 --- a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs +++ b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs @@ -282,9 +282,11 @@ namespace Tensorflow BaseResourceVariable new_variable; if (save_options.experimental_variable_policy.save_variable_devices()) { - tf.device(this.Device); Debug.Assert(this is ResourceVariable); - new_variable = resource_variable_ops.copy_to_graph_uninitialized((ResourceVariable)this); + new_variable = tf_with(ops.device(this.Device), _ => + { + return resource_variable_ops.copy_to_graph_uninitialized((ResourceVariable)this); + }); } else { diff --git a/src/TensorFlowNET.Core/Variables/UninitializedVariable.cs b/src/TensorFlowNET.Core/Variables/UninitializedVariable.cs index 637d0983..e2631244 100644 --- a/src/TensorFlowNET.Core/Variables/UninitializedVariable.cs +++ b/src/TensorFlowNET.Core/Variables/UninitializedVariable.cs @@ -9,7 +9,7 @@ namespace Tensorflow.Variables /// /// A variable with no initializer. /// - public sealed class UninitializedVariable: BaseResourceVariable, IVariableV1 + public sealed class UninitializedVariable : BaseResourceVariable, IVariableV1 { // TODO: complete the arg list. public UninitializedVariable( @@ -19,7 +19,7 @@ namespace Tensorflow.Variables TF_DataType dtype = TF_DataType.DtInvalid, VariableAggregation aggregation = VariableAggregation.None, Shape shape = null, - Tensor extra_handle_data = null) + Tensor extra_handle_data = null) { string unique_id = ""; string handle_name = ""; @@ -50,9 +50,12 @@ namespace Tensorflow.Variables { tf_with(ops.name_scope("Read"), _ => { - tf.device(created_handle.Device); - var value = gen_resource_variable_ops.read_variable_op(created_handle, dtype); - resource_variable_ops._maybe_set_handle_data(dtype, created_handle, value); + var value = tf_with(ops.device(created_handle.Device), _ => + { + var result = gen_resource_variable_ops.read_variable_op(created_handle, dtype); + resource_variable_ops._maybe_set_handle_data(dtype, created_handle, result); + return result; + }); _graph_element = value; }); ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES_, this); diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs index c261f3ce..6d1385ca 100644 --- a/src/TensorFlowNET.Core/ops.cs +++ b/src/TensorFlowNET.Core/ops.cs @@ -584,6 +584,23 @@ namespace Tensorflow } + public static ITensorFlowObject device(string device_name) + { + if (tf.Context.executing_eagerly()) + { + return tf.Context.device(device_name); + } + //else if (ops.executing_eagerly_outside_functions()) + //{ + // throw new NotImplementedException(); + //} + else + { + return get_default_graph().device(device_name); + } + // TODO(Rinne): deal with `ops.executing_eagerly_outside_functions()`. + } + public class NullContextManager: IDisposable { public void Dispose() diff --git a/src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs b/src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs index 0aa5006c..73ccc87b 100644 --- a/src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs +++ b/src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs @@ -77,7 +77,7 @@ public class EarlyStopping: ICallback // Restore the weights after first epoch if no progress is ever made. if (_restore_best_weights && _best_weights == null) { - _best_weights = _parameters.Model.TrainableWeights; + _best_weights = _parameters.Model.Weights; } _wait += 1; @@ -102,10 +102,8 @@ public class EarlyStopping: ICallback { Console.WriteLine($"Restoring model weights from the end of the best epoch: {_best_epoch + 1}"); } + _parameters.Model.Weights = _best_weights; } - // Because loading the weight variable into the model has not yet been implemented, so Earlystopping can't load best_weight yet. - // TODO(Wanglongzhi2001): implement it. - // _parameters.Model.load_weights(best_weights); } } public void on_train_end() diff --git a/src/TensorFlowNET.Keras/Losses/SparseCategoricalCrossentropy.cs b/src/TensorFlowNET.Keras/Losses/SparseCategoricalCrossentropy.cs index b7241226..4e2790ab 100644 --- a/src/TensorFlowNET.Keras/Losses/SparseCategoricalCrossentropy.cs +++ b/src/TensorFlowNET.Keras/Losses/SparseCategoricalCrossentropy.cs @@ -4,17 +4,21 @@ namespace Tensorflow.Keras.Losses { public class SparseCategoricalCrossentropy : LossFunctionWrapper, ILossFunc { + private bool _from_logits = false; public SparseCategoricalCrossentropy( bool from_logits = false, string reduction = null, string name = null) : - base(reduction: reduction, name: name == null ? "sparse_categorical_crossentropy" : name){ } + base(reduction: reduction, name: name == null ? "sparse_categorical_crossentropy" : name) + { + _from_logits = from_logits; + } public override Tensor Apply(Tensor target, Tensor output, bool from_logits = false, int axis = -1) { target = tf.cast(target, dtype: TF_DataType.TF_INT64); - if (!from_logits) + if (!_from_logits) { var epsilon = tf.constant(KerasApi.keras.backend.epsilon(), output.dtype); output = tf.clip_by_value(output, epsilon, 1 - epsilon); diff --git a/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs index f91f1fe7..3788e950 100644 --- a/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs +++ b/test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs @@ -12,6 +12,7 @@ namespace TensorFlowNET.Keras.UnitTest.SaveModel; [TestClass] public class SequentialModelLoad { + [Ignore] [TestMethod] public void SimpleModelFromAutoCompile() {