@@ -53,8 +53,11 @@ public static class SaveUtilV1
         var g = to_graph.as_default();
         var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view,
             object_map, call_with_mapped_captures, saveables_cache);
-        tf.device("/cpu:0");
-        var object_graph_tensor = constant_op.constant(graph_proto.ToByteArray());
+        var object_graph_tensor = tf_with(ops.device("/cpu:0"), _ =>
+        {
+            // TODO(Rinne): locate the error that causes an exception when TF_STRING is passed to this function.
+            return constant_op.constant(graph_proto.ToByteArray());
+        });
         named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY));
         g.Exit();
         return (named_saveable_objects, registered_savers);
@@ -65,8 +68,10 @@ public static class SaveUtilV1
     {
         var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view,
             object_map, call_with_mapped_captures, saveables_cache);
-        tf.device("/cpu:0");
-        var object_graph_tensor = constant_op.constant(graph_proto.ToString(), TF_DataType.TF_STRING);
+        var object_graph_tensor = tf_with(ops.device("/cpu:0"), _ =>
+        {
+            return constant_op.constant(graph_proto.ToString());
+        });
         named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY));
         return (named_saveable_objects, registered_savers);
     }
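
The two hunks above show the pattern used throughout this diff: the old code called tf.device(...) as a bare statement, so the device setting never went out of scope, while the new code wraps the placement-sensitive ops in tf_with(ops.device(...), ...) so the previous device is restored when the callback returns. A minimal sketch of the two styles (illustrative only, not part of the diff; `data` stands in for whatever value is being wrapped):

    // Before: the device setting leaks past the statement it was meant for.
    tf.device("/cpu:0");
    var tensor = constant_op.constant(data);

    // After: the placement only applies inside the callback and is undone on exit.
    var tensor = tf_with(ops.device("/cpu:0"), _ =>
    {
        return constant_op.constant(data);
    });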
@@ -58,8 +58,10 @@ public class TrackableSaver
             if(object_graph_tensor is null)
             {
-                tf.device("/cpu:0");
-                object_graph_tensor = constant_op.constant(graph_proto.ToByteArray());
+                tf_with(ops.device("/cpu:0"), _ =>
+                {
+                    object_graph_tensor = constant_op.constant(graph_proto.ToByteArray());
+                });
             }
             else
             {
@@ -230,13 +232,15 @@ public class TrackableSaver
             Tensor object_graph_string = reader.GetTensor(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY, dtype: TF_DataType.TF_STRING);

             Dictionary<Tensor, string> file_prefix_feed_dict;
-            Tensor file_prefix_tensor;
+            Tensor file_prefix_tensor = null;
             if (graph_building)
             {
                 if(_file_prefix_placeholder is null)
                 {
-                    tf.device("/cpu:0");
-                    _file_prefix_placeholder = constant_op.constant("model");
+                    _file_prefix_placeholder = tf_with(ops.device("/cpu:0"), _ =>
+                    {
+                        return constant_op.constant("model");
+                    });
                 }
                 file_prefix_tensor = _file_prefix_placeholder;
                 file_prefix_feed_dict = new();
@@ -244,8 +248,10 @@ public class TrackableSaver
             }
             else
             {
-                tf.device("/cpu:0");
-                file_prefix_tensor = constant_op.constant(save_path);
+                file_prefix_tensor = tf_with(ops.device("/cpu:0"), _ =>
+                {
+                    return constant_op.constant(save_path);
+                });
                 file_prefix_feed_dict = null;
             }
             TrackableObjectGraph object_graph_proto = new();
@@ -211,9 +211,11 @@ namespace Tensorflow.Checkpoint
            string restore_device = string.IsNullOrEmpty(options.experimental_io_device) ? "cpu:0": options.experimental_io_device!;

-            // tf python has code `with ops.device(restore_device):` here.
-            tf.device(restore_device); // may be risky.
-            var restored_tensors = gen_ops.restore_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensor_dtypes.ToArray());
+            Tensor[] restored_tensors = null;
+            tf_with(ops.device(restore_device), _ =>
+            {
+                restored_tensors = gen_ops.restore_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensor_dtypes.ToArray());
+            });

            Dictionary<string, IDictionary<string, Tensor>> restored_tensor_dict = new();
            int idx = 0;
@@ -338,11 +340,14 @@ namespace Tensorflow.Checkpoint
                 options = new CheckpointOptions();
             }
-            tf.device("CPU"); // may be risky.
-            var sharded_suffix = array_ops.where(gen_ops.regex_full_match(file_prefix, tf.constant(@"^s3://.*")),
+            Tensor tmp_checkpoint_prefix = null;
+            tf_with(ops.device("CPU"), _ =>
+            {
+                var sharded_suffix = array_ops.where(gen_ops.regex_full_match(file_prefix, tf.constant(@"^s3://.*")),
                     constant_op.constant(".part"), constant_op.constant("_temp/part"));
-            var tmp_checkpoint_prefix = gen_ops.string_join(new Tensor[] { file_prefix, sharded_suffix });
-            IDictionary<string, Tensor> registered_paths = _registered_savers.Keys.ToDictionary(x => x, x => registered_saver_filename(file_prefix, x));
+                tmp_checkpoint_prefix = gen_ops.string_join(new Tensor[] { file_prefix, sharded_suffix });
+                IDictionary<string, Tensor> registered_paths = _registered_savers.Keys.ToDictionary(x => x, x => registered_saver_filename(file_prefix, x));
+            });

             Operation save_fn()
             {
@@ -364,16 +369,24 @@ namespace Tensorflow.Checkpoint
                     var saver = pair.Value;
                     last_device = device;
                     // skip the extra process of device name because of lack of API.
-                    tf.device(device);
-                    var shard_prefix = sharded_filename(tmp_checkpoint_prefix, shard, num_shards_tensor);
+                    Tensor shard_prefix = null;
+                    tf_with(ops.device(device), _ =>
+                    {
+                        shard_prefix = sharded_filename(tmp_checkpoint_prefix, shard, num_shards_tensor);
+                    });
                     saved_prefixes.Add(shard_prefix);
-                    sharded_saves.Add(saver.save(shard_prefix, options));
+                    tf_with(ops.device(device), _ =>
+                    {
+                        sharded_saves.Add(saver.save(shard_prefix, options));
+                    });
                 }

                 using (var controller = ops.control_dependencies(sharded_saves.ToArray()))
                 {
                     string merge_device = string.IsNullOrEmpty(options.experimental_io_device) ? last_device : options.experimental_io_device;
-                    tf.device(merge_device);
-                    return gen_ops.merge_v2_checkpoints(saved_prefixes.ToArray(), tf.constant(file_prefix), delete_old_dirs: true);
+                    return tf_with(ops.device(merge_device), _ =>
+                    {
+                        return gen_ops.merge_v2_checkpoints(saved_prefixes.ToArray(), tf.constant(file_prefix), delete_old_dirs: true);
+                    });
                 }
             }
@@ -407,54 +420,56 @@ namespace Tensorflow.Checkpoint
             {
                 var device = single_saver.Key;
                 var saver = single_saver.Value;
-                tf.device(device);
-                var restored_tensor_dict = saver.restore(file_prefix, options);
-
-                foreach(var pair in restored_tensor_dict)
+                tf_with(ops.device(device), _ =>
                 {
-                    var checkpoint_key = pair.Key;
-                    var slice_and_tensor = pair.Value;
-                    foreach(var item in slice_and_tensor)
+                    var restored_tensor_dict = saver.restore(file_prefix, options);
+                    foreach (var pair in restored_tensor_dict)
                     {
-                        var slice_spec = item.Key;
-                        var tensor = item.Value;
-                        var restore_fn = _keys_to_restore_fn[(checkpoint_key, slice_spec)];
-                        var internal_dict = restore_fn_inputs.SetDefault(restore_fn, new Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>());
-                        if (!string.IsNullOrEmpty(slice_spec))
+                        var checkpoint_key = pair.Key;
+                        var slice_and_tensor = pair.Value;
+                        foreach (var item in slice_and_tensor)
                         {
-                            if (!internal_dict.ContainsKey(checkpoint_key))
+                            var slice_spec = item.Key;
+                            var tensor = item.Value;
+                            var restore_fn = _keys_to_restore_fn[(checkpoint_key, slice_spec)];
+                            var internal_dict = restore_fn_inputs.SetDefault(restore_fn, new Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>());
+                            if (!string.IsNullOrEmpty(slice_spec))
                             {
-                                Dictionary<string, Tensor> dict = new();
-                                dict[slice_spec] = tensor;
-                                internal_dict[checkpoint_key] = new Maybe<Tensor, IDictionary<string, Tensor>>(dict);
+                                if (!internal_dict.ContainsKey(checkpoint_key))
+                                {
+                                    Dictionary<string, Tensor> dict = new();
+                                    dict[slice_spec] = tensor;
+                                    internal_dict[checkpoint_key] = new Maybe<Tensor, IDictionary<string, Tensor>>(dict);
+                                }
+                                else
+                                {
+                                    internal_dict[checkpoint_key].GetValue<IDictionary<string, Tensor>>()[slice_spec] = tensor;
+                                }
                             }
                             else
                             {
-                                internal_dict[checkpoint_key].GetValue<IDictionary<string, Tensor>>()[slice_spec] = tensor;
+                                internal_dict[checkpoint_key] = new Maybe<Tensor, IDictionary<string, Tensor>>(tensor);
                             }
-                        }
-                        else
-                        {
-                            internal_dict[checkpoint_key] = new Maybe<Tensor, IDictionary<string, Tensor>>(tensor);
-                        }
-                        restore_fn_input_count[restore_fn]--;
+                            restore_fn_input_count[restore_fn]--;

-                        if (restore_fn_input_count[restore_fn] == 0)
-                        {
-                            Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> restored_tensors = new();
-                            foreach(var input in restore_fn_inputs[restore_fn])
+                            if (restore_fn_input_count[restore_fn] == 0)
                             {
-                                restored_tensors[TrackableUtils.extract_local_name(input.Key)] = input.Value;
-                            }
-                            var ret = restore_fn.DynamicInvoke(restored_tensors);
-                            if(ret is IDictionary<string, Operation>)
-                            {
-                                var dict = (IDictionary<string, Operation>)ret;
-                                restore_ops = restore_ops.Concat(dict).ToDictionary(x => x.Key, x => x.Value);
+                                Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> restored_tensors = new();
+                                foreach (var input in restore_fn_inputs[restore_fn])
+                                {
+                                    restored_tensors[TrackableUtils.extract_local_name(input.Key)] = input.Value;
+                                }
+                                var ret = restore_fn.DynamicInvoke(restored_tensors);
+                                if (ret is IDictionary<string, Operation>)
+                                {
+                                    var dict = (IDictionary<string, Operation>)ret;
+                                    restore_ops = restore_ops.Concat(dict).ToDictionary(x => x.Key, x => x.Value);
+                                }
                             }
                         }
                     }
-                }
+                });
             }

             foreach(var item in _registered_savers)
@@ -500,21 +515,25 @@ namespace Tensorflow.Checkpoint
         private Tensor _traced_save(Tensor file_prefix)
         {
             var save_op = save(file_prefix);
-            tf.device("cpu:0");
-            using (ops.control_dependencies(new object[]{ save_op }))
+            return tf_with(ops.device("cpu:0"), _ =>
             {
-                return array_ops.identity(file_prefix);
-            }
+                return tf_with(ops.control_dependencies(new object[] { save_op }), __ =>
+                {
+                    return array_ops.identity(file_prefix);
+                });
+            });
         }

         private Tensor _traced_restore(Tensor file_prefix)
         {
             var restore_op = restore(file_prefix);
-            tf.device("cpu:0");
-            using (ops.control_dependencies(restore_op.Values.ToArray()))
+            return tf_with(ops.device("cpu:0"), _ =>
             {
-                return array_ops.identity(file_prefix);
-            }
+                return tf_with(ops.control_dependencies(restore_op.Values.ToArray()), __ =>
+                {
+                    return array_ops.identity(file_prefix);
+                });
+            });
         }

         public static MultiDeviceSaver from_saveables(IEnumerable<MySaveableObject> saveables, IDictionary<string, IDictionary<string, Trackable>>? registered_savers = null, bool call_with_mapped_captures = false)
@@ -21,6 +21,7 @@ using Tensorflow.Eager;
 using static Tensorflow.Binding;
 using Google.Protobuf;
 using Tensorflow.Device;
+using Tensorflow.Exceptions;
 using System.Collections.Generic;

 namespace Tensorflow.Contexts
@@ -30,10 +31,30 @@ namespace Tensorflow.Contexts
     /// </summary>
     public sealed partial class Context
     {
+        internal static Dictionary<(string, string), (string, DeviceSpec)> _device_parsing_cache = new();
+        internal List<LogicalDevice> _logical_devices = null;
+        internal List<string> _context_devices = null;
         ContextDevicePlacementPolicy _device_policy;
         bool _log_device_placement;
+        int _num_gpus;
         Dictionary<PhysicalDevice, bool> _memory_growth_map = new Dictionary<PhysicalDevice, bool>();

+        public string DeviceName { get; set; } = "";
+        public DeviceSpec DeviceSpec { get; set; } = null;
+
+        internal List<string> Devices
+        {
+            get
+            {
+                if(_context_devices is null)
+                {
+                    throw new AssertionError("Context must be initialized first.");
+                }
+                return _context_devices;
+            }
+        }
+
         public void log_device_placement(bool enable)
         {
             if (_handle != null)
@@ -89,5 +110,57 @@ namespace Tensorflow.Contexts
             return results.ToArray();
         }

+        public EagerDeviceContext device(string name)
+        {
+            return new EagerDeviceContext(this, name);
+        }
+
+        internal void _set_device(string device_name, DeviceSpec device_spec)
+        {
+            DeviceSpec = device_spec;
+            DeviceName = device_name;
+        }
+
+        internal void _initialize_logical_devices()
+        {
+            List<LogicalDevice> logical_devices = new();
+            List<string> context_devices = new();
+            Status status = new();
+            var device_list = c_api.TFE_ContextListDevices(_handle, status);
+            status.Check(true);
+            try
+            {
+                this._num_gpus = 0;
+                string current_job = null;
+                int current_task = -1;
+                for(int i = 0; i < c_api.TF_DeviceListCount(device_list); i++)
+                {
+                    var dev_name = c_api.TF_DeviceListName(device_list, i, status);
+                    status.Check(true);
+                    context_devices.Add(DeviceUtils.canonical_name(dev_name));
+                    var spec = DeviceSpec.from_string(dev_name);
+                    if(spec.Job == "localhost")
+                    {
+                        spec = spec.replace(job: null, replica: -1, task: -1);
+                    }
+                    logical_devices.Add(new LogicalDevice(spec.ToString(), spec.DeviceType));
+                    var dev_type_memory = c_api.TF_DeviceListType(device_list, i, status);
+                    var dev_type = c_api.StringPiece(dev_type_memory);
+                    status.Check(true);
+                    if(dev_type == "GPU" && spec.Job == current_job && spec.Task == current_task)
+                    {
+                        _num_gpus++;
+                    }
+                }
+            }
+            finally
+            {
+                _logical_devices = logical_devices;
+                _context_devices = context_devices;
+            }
+        }
     }
+
+    public record class LogicalDevice(string name, string device_type);
 }
@@ -34,7 +34,6 @@ namespace Tensorflow.Contexts
         public const int EAGER_MODE = 1;

         int defaultExecutionMode = EAGER_MODE;
-        public string DeviceName { get; set; } = "";
         public string ScopeName { get; set; } = "";
         bool initialized = false;
         ContextSwitchStack context_switches;
@@ -62,6 +61,8 @@ namespace Tensorflow.Contexts
             if (initialized)
                 return;

+            Debug.Assert(_context_devices is null);
+
             Config = MergeConfig();
             FunctionCallOptions.Config = Config;
             var config_str = Config.ToByteArray();
@@ -72,6 +73,7 @@ namespace Tensorflow.Contexts
             c_api.TFE_ContextOptionsSetDevicePlacementPolicy(opts, _device_policy);
             _handle = c_api.TFE_NewContext(opts, status);
             status.Check(true);
+            _initialize_logical_devices();
             initialized = true;
         }
@@ -174,6 +176,7 @@ namespace Tensorflow.Contexts
             {
                 c_api.TFE_ContextClearCaches(_handle);
             }
+            _device_parsing_cache.Clear();
         }

         public static implicit operator SafeContextHandle(Context ctx)
@@ -0,0 +1,71 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow.Device;
+
+namespace Tensorflow.Contexts
+{
+    public class EagerDeviceContext : ITensorFlowObject
+    {
+        private Context _ctx;
+        private string _device_name;
+        private Stack<(string, DeviceSpec, DeviceSpec)> _stack;
+
+        public EagerDeviceContext(Context ctx, string device_name)
+        {
+            _ctx = ctx;
+            _device_name = device_name;
+            _stack = new Stack<(string, DeviceSpec, DeviceSpec)>();
+        }
+
+        public void __enter__()
+        {
+            var ctx = _ctx;
+            var old_device_name = ctx.DeviceName;
+            var old_device_spec = ctx.DeviceSpec;
+            var new_device_name = _device_name;
+            var cache_key = (old_device_name, new_device_name);
+            DeviceSpec new_device_spec;
+            if (Context._device_parsing_cache.ContainsKey(cache_key))
+            {
+                (new_device_name, new_device_spec) = Context._device_parsing_cache[cache_key];
+            }
+            else
+            {
+                if(new_device_name is not null)
+                {
+                    var device_spec = DeviceSpec.from_string(new_device_name);
+                    if (!string.IsNullOrEmpty(old_device_name))
+                    {
+                        new_device_spec = new DeviceSpec(old_device_spec);
+                    }
+                    else
+                    {
+                        ctx.ensure_initialized();
+                        new_device_spec = DeviceSpec.from_string(ctx._context_devices[0]);
+                    }
+                    new_device_spec = new_device_spec.make_merged_spec(device_spec);
+                }
+                else
+                {
+                    new_device_spec = DeviceSpec.from_string(ctx._context_devices[0]);
+                }
+                new_device_name = new_device_spec.ToString();
+                Context._device_parsing_cache[cache_key] = (new_device_name, new_device_spec);
+            }
+            ctx._set_device(new_device_name, new_device_spec);
+            _stack.Push((old_device_name, old_device_spec, new_device_spec));
+        }
+
+        public void __exit__()
+        {
+            var ctx = _ctx;
+            var (old_device_name, old_device_spec, new_device_spec) = _stack.Pop();
+            ctx._set_device(old_device_name, old_device_spec);
+        }
+
+        public void Dispose()
+        {
+
+        }
+    }
+}
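
For orientation, this is roughly how the eager device scope above is meant to be used through the helpers added in this diff (an illustrative sketch, not part of the diff; it assumes tf_with(...) invokes __enter__/__exit__ on the ITensorFlowObject it receives, as it does for the other context objects in this codebase):

    // In eager mode, ops.device(...) returns an EagerDeviceContext (see the ops.device hunk below).
    tf_with(ops.device("/device:CPU:0"), _ =>
    {
        // Inside the scope the context reports the merged, fully qualified device,
        // e.g. "/job:localhost/replica:0/task:0/device:CPU:0".
        var current = tf.Context.DeviceName;
        var t = tf.constant(1.0f); // intended to be placed on that device
    });
    // On __exit__ the previous DeviceName/DeviceSpec pair is popped from the stack and restored.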
@@ -0,0 +1,205 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;   // required for the Enumerable.Select call in _string_to_components
+using System.Text;
+using System.Threading.Tasks;
+
+namespace Tensorflow.Device
+{
+    public class DeviceSpec
+    {
+        private static Dictionary<string, Components> _STRING_TO_COMPONENTS_CACHE = new();
+        private static Dictionary<Components, string> _COMPONENTS_TO_STRING_CACHE = new();
+        private string _job;
+        private int _replica;
+        private int _task;
+        private string _device_type;
+        private int _device_index;
+        private string _as_string;
+
+        public string Job => _job;
+        public int Replica => _replica;
+        public int Task => _task;
+        public string DeviceType => _device_type;
+        public int DeviceIndex => _device_index;
+
+        public DeviceSpec(string job = null, int replica = -1, int task = -1,
+            string device_type = null, int device_index = -1)
+        {
+            _job = job;
+            _replica = replica;
+            _task = task;
+            _device_type = device_type;
+            _device_index = device_index;
+            _as_string = _components_to_string(job, replica, task, device_type, _device_index);
+        }
+
+        public DeviceSpec(DeviceSpec other)
+        {
+            _job = other._job;
+            _replica = other._replica;
+            _task = other._task;
+            _device_type = other._device_type;
+            _device_index = other._device_index;
+            _as_string = other._as_string;
+        }
+
+        protected DeviceSpec(Components com)
+        {
+            _job = com.Job;
+            _replica = com.Replica;
+            _task = com.Task;
+            _device_type = com.DeviceType;
+            _device_index = com.DeviceIndex;
+            _as_string = _components_to_string(_job, _replica, _task, _device_type, _device_index);
+        }
+
+        public DeviceSpec replace(string job = null, int replica = -1, int task = -1,
+            string device_type = null, int device_index = -1)
+        {
+            job = job ?? _job;
+            replica = replica == -1 ? _replica : replica;
+            task = task == -1 ? _task : task;
+            device_type = device_type ?? _device_type;
+            device_index = device_index == -1 ? _device_index : device_index;
+            return new DeviceSpec(job, replica, task, device_type, device_index);
+        }
+
+        public static DeviceSpec from_string(string spec)
+        {
+            var components = _string_to_components(spec);
+            return new DeviceSpec(components.Job, components.Replica, components.Task, components.DeviceType, components.DeviceIndex);
+        }
+
+        public DeviceSpec make_merged_spec(DeviceSpec dev)
+        {
+            return new DeviceSpec(_get_combined_properties(dev));
+        }
+
+        private Components _get_combined_properties(DeviceSpec dev)
+        {
+            return new Components(
+                dev.Job ?? _job,
+                dev.Replica == -1 ? _replica : dev.Replica,
+                dev.Task == -1 ? _task : dev.Task,
+                dev.DeviceType ?? _device_type,
+                dev.DeviceIndex == -1 ? _device_index : dev.DeviceIndex
+            );
+        }
+
+        private static string _components_to_string(string job, int replica, int task, string device_type, int device_index)
+        {
+            var key = new Components(job, replica, task, device_type, device_index);
+            if(_COMPONENTS_TO_STRING_CACHE.TryGetValue(key, out var cache_result))
+            {
+                return cache_result;
+            }
+
+            StringBuilder output = new();
+            if(job is not null)
+            {
+                output.Append($"/job:{job}");
+            }
+            if(replica != -1)
+            {
+                output.Append($"/replica:{replica}");
+            }
+            if(task != -1)
+            {
+                output.Append($"/task:{task}");
+            }
+            if (device_type is not null)
+            {
+                string device_index_string = "*";
+                if (device_index != -1)
+                {
+                    device_index_string = device_index.ToString();
+                }
+                output.Append($"/device:{device_type}:{device_index_string}");
+            }
+
+            var result = output.ToString();
+            _COMPONENTS_TO_STRING_CACHE[key] = result;
+            return result;
+        }
+
+        private static Components _string_to_components(string spec)
+        {
+            if(_STRING_TO_COMPONENTS_CACHE.TryGetValue(spec, out var cached_result))
+            {
+                return cached_result;
+            }
+
+            var raw_spec = spec;
+            var splits = spec.Split('/').Select(x => x.Split(':'));
+            var valid_device_types = _get_valid_device_types();
+            string job = null, device_type = null;
+            int replica = -1, task = -1, device_index = -1;
+
+            foreach (var y in splits)
+            {
+                var ly = y.Length;
+                if (ly > 0)
+                {
+                    if(ly == 2 && y[0] == "job")
+                    {
+                        job = y[1];
+                    }
+                    else if(ly == 2 && y[0] == "replica")
+                    {
+                        replica = int.Parse(y[1]);
+                    }
+                    else if(ly == 2 && y[0] == "task")
+                    {
+                        task = int.Parse(y[1]);
+                    }
+                    else if((ly == 1 || ly == 2) && valid_device_types.Contains(y[0].ToUpper()))
+                    {
+                        if (device_type is not null)
+                        {
+                            throw new ValueError($"Multiple device types are not allowed " +
+                                $"while parsing the device spec: {spec}.");
+                        }
+                        device_type = y[0].ToUpper();
+                        if(ly == 2 && y[1] != "*")
+                        {
+                            device_index = int.Parse(y[1]);
+                        }
+                    }
+                    else if(ly == 3 && y[0] == "device")
+                    {
+                        if(device_type is not null)
+                        {
+                            throw new ValueError($"Multiple device types are not allowed " +
+                                $"while parsing the device spec: {spec}.");
+                        }
+                        device_type = y[1];
+                        if (y[2] != "*")
+                        {
+                            device_index = int.Parse(y[2]);
+                        }
+                    }
+                    else if (y[0] != "")
+                    {
+                        throw new ValueError($"Unknown attribute '{y[0]}' is encountered " +
+                            $"while parsing the device spec: {spec}.");
+                    }
+                }
+            }
+            var output = new Components(job, replica, task, device_type, device_index);
+            _STRING_TO_COMPONENTS_CACHE[raw_spec] = output;
+            return output;
+        }
+
+        private static HashSet<string> _get_valid_device_types()
+        {
+            // TODO(Rinne): revise it to calling C API (need customized API).
+            return new HashSet<string>(new string[] { "CPU", "GPU" });
+        }
+
+        public override string ToString()
+        {
+            return _as_string;
+        }
+
+        protected record class Components(string Job, int Replica, int Task, string DeviceType, int DeviceIndex);
+    }
+}
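
Given the parsing and merging code above, DeviceSpec mirrors Python's tf.DeviceSpec. A small illustrative sketch of its behavior (not part of the diff):

    var spec = DeviceSpec.from_string("/job:localhost/replica:0/task:0/device:GPU:1");
    // spec.Job == "localhost", spec.Replica == 0, spec.Task == 0,
    // spec.DeviceType == "GPU", spec.DeviceIndex == 1
    // spec.ToString() == "/job:localhost/replica:0/task:0/device:GPU:1"

    // make_merged_spec: fields set on the argument win; unset fields (null / -1) fall back to `this`.
    var merged = DeviceSpec.from_string("/device:CPU:0")
        .make_merged_spec(DeviceSpec.from_string("/job:worker"));
    // merged.ToString() == "/job:worker/device:CPU:0"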
@@ -0,0 +1,26 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Device
+{
+    internal static class DeviceUtils
+    {
+        public static string canonical_name(string device)
+        {
+            if(device is null)
+            {
+                return "";
+            }
+            return DeviceSpec.from_string(device).ToString();
+        }
+        public static string canonical_name(DeviceSpec device)
+        {
+            if (device is null)
+            {
+                return "";
+            }
+            return device.ToString();
+        }
+    }
+}
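
For reference, canonical_name simply round-trips a device string through the DeviceSpec parser shown above, e.g. (illustrative):

    DeviceUtils.canonical_name("GPU:0")      // "/device:GPU:0"
    DeviceUtils.canonical_name((string)null) // ""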
@@ -19,6 +19,7 @@ using System.Collections;
 using System.Collections.Generic;
 using System.Collections.Specialized;
 using System.Linq;
+using Tensorflow.Graphs;
 using static Tensorflow.Binding;

 namespace Tensorflow
@@ -294,9 +295,15 @@ namespace Tensorflow
             return op;
         }

-        public void device(string device_name)
+        public ITensorFlowObject device(string device_name)
         {
+            return new GraphDeviceContext(this, device_name);
+        }
+
+        private void add_device_to_stack(string device_name, int offset = 0)
+        {
+            // TODO(Rinne): deal with device spec.
+            int total_offset = offset + 1;
         }

         private void _create_op_helper(Operation op, bool compute_device = true)
@@ -0,0 +1,31 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Graphs
+{
+    public class GraphDeviceContext : ITensorFlowObject
+    {
+        private Graph _graph;
+
+        public GraphDeviceContext(Graph graph, string device_name)
+        {
+            _graph = graph;
+        }
+
+        public void __enter__()
+        {
+
+        }
+
+        public void __exit__()
+        {
+
+        }
+
+        public void Dispose()
+        {
+
+        }
+    }
+}
@@ -42,16 +42,20 @@ namespace Tensorflow
             _var_device = var.Device;
             _var_shape = var.shape;

-            Tensor _read_variable_closure(BaseResourceVariable v)
+            Tensor? _read_variable_closure(BaseResourceVariable v)
             {
-                tf.device(v.Device);
-                if(tf.Context.executing_eagerly() && !((bool)v.is_initialized().numpy()))
+                return tf_with(ops.device(v.Device), _ =>
                 {
-                    return null;
-                }
-                var x = v.read_value_no_copy();
-                tf.device("/device:CPU:0");
-                return array_ops.identity(x);
+                    if (tf.Context.executing_eagerly() && !((bool)v.is_initialized().numpy()))
+                    {
+                        return null;
+                    }
+                    var x = v.read_value_no_copy();
+                    return tf_with(ops.device("/device:CPU:0"), __ =>
+                    {
+                        return array_ops.identity(x);
+                    });
+                });
             }

             this.handle_op = var.Handle;
@@ -412,8 +412,10 @@ namespace Tensorflow
         {
             var variables_path = SavedModelUtils.get_variables_path(_export_dir);
             var saver = new TrackableSaver(new ObjectGraphView(get(0)));
-            tf.device("CPU");
-            saver.FilePrefixPlaceHolder = constant_op.constant(variables_path);
+            tf_with(ops.device("CPU"), _ =>
+            {
+                saver.FilePrefixPlaceHolder = constant_op.constant(variables_path);
+            });
             LoadStatus load_status;
             if (_save_options.allow_partial_checkpoint)
             {
@@ -598,14 +600,16 @@ namespace Tensorflow
             if (load_with_device)
             {
-                tf.device(saved_device);
-                return (new UninitializedVariable(
-                    shape: new Shape(proto.Shape.Dim.Select(x => (int)x.Size).ToArray()),
-                    dtype: (TF_DataType)proto.Dtype,
-                    name: name,
-                    trainable: trainable,
-                    aggregation: aggregation
-                ), setattr);
+                return tf_with(ops.device(saved_device), _ =>
+                {
+                    return (new UninitializedVariable(
+                        shape: new Shape(proto.Shape.Dim.Select(x => (int)x.Size).ToArray()),
+                        dtype: (TF_DataType)proto.Dtype,
+                        name: name,
+                        trainable: trainable,
+                        aggregation: aggregation
+                    ), setattr);
+                });
             }
             else
             {
@@ -266,9 +266,11 @@ namespace Tensorflow
             BaseResourceVariable new_variable;
             if (save_options.experimental_variable_policy.save_variable_devices())
             {
-                tf.device(this.Device);
                 Debug.Assert(this is ResourceVariable);
-                new_variable = resource_variable_ops.copy_to_graph_uninitialized((ResourceVariable)this);
+                new_variable = tf_with(ops.device(this.Device), _ =>
+                {
+                    return resource_variable_ops.copy_to_graph_uninitialized((ResourceVariable)this);
+                });
             }
             else
             {
@@ -49,9 +49,12 @@ namespace Tensorflow.Variables
             {
                 tf_with(ops.name_scope("Read"), _ =>
                 {
-                    tf.device(handle.Device);
-                    var value = gen_resource_variable_ops.read_variable_op(handle, dtype);
-                    // _maybe_set_handle_data(dtype, handle, value)
+                    var value = tf_with(ops.device(handle.Device), __ =>
+                    {
+                        var result = gen_resource_variable_ops.read_variable_op(handle, dtype);
+                        // TODO(Rinne): _maybe_set_handle_data(dtype, handle, result)
+                        return result;
+                    });
                     _graph_element = value;
                 });
                 ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES_, this);
@@ -577,6 +577,23 @@ namespace Tensorflow
         }

+        public static ITensorFlowObject device(string device_name)
+        {
+            if (tf.Context.executing_eagerly())
+            {
+                return tf.Context.device(device_name);
+            }
+            //else if (ops.executing_eagerly_outside_functions())
+            //{
+            //    throw new NotImplementedException();
+            //}
+            else
+            {
+                return get_default_graph().device(device_name);
+            }
+            // TODO(Rinne): deal with `ops.executing_eagerly_outside_functions()`.
+        }
+
         public class NullContextManager: IDisposable
         {
             public void Dispose()