Fix the error when saving a model with GPU. (tags/v0.100.5-BERT-load)
| @@ -53,8 +53,11 @@ public static class SaveUtilV1 | |||
| var g = to_graph.as_default(); | |||
| var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view, | |||
| object_map, call_with_mapped_captures, saveables_cache); | |||
| tf.device("/cpu:0"); | |||
| var object_graph_tensor = constant_op.constant(graph_proto.ToByteArray()); | |||
| var object_graph_tensor = tf_with(ops.device("/cpu:0"), _ => | |||
| { | |||
| // TODO(Rinne): locate the error that causes transferring TF_STRING to this function throws an exception. | |||
| return constant_op.constant(graph_proto.ToByteArray()); | |||
| }); | |||
| named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY)); | |||
| g.Exit(); | |||
| return (named_saveable_objects, registered_savers); | |||
| @@ -65,8 +68,10 @@ public static class SaveUtilV1 | |||
| { | |||
| var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view, | |||
| object_map, call_with_mapped_captures, saveables_cache); | |||
| tf.device("/cpu:0"); | |||
| var object_graph_tensor = constant_op.constant(graph_proto.ToString(), TF_DataType.TF_STRING); | |||
| var object_graph_tensor = tf_with(ops.device("/cpu:0"), _ => | |||
| { | |||
| return constant_op.constant(graph_proto.ToString()); | |||
| }); | |||
| named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY)); | |||
| return (named_saveable_objects, registered_savers); | |||
| } | |||
| @@ -58,8 +58,10 @@ public class TrackableSaver | |||
| if(object_graph_tensor is null) | |||
| { | |||
| tf.device("/cpu:0"); | |||
| object_graph_tensor = constant_op.constant(graph_proto.ToByteArray()); | |||
| tf_with(ops.device("/cpu:0"), _ => | |||
| { | |||
| object_graph_tensor = constant_op.constant(graph_proto.ToByteArray()); | |||
| }); | |||
| } | |||
| else | |||
| { | |||
| @@ -230,13 +232,15 @@ public class TrackableSaver | |||
| Tensor object_graph_string = reader.GetTensor(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY, dtype: TF_DataType.TF_STRING); | |||
| Dictionary<Tensor, string> file_prefix_feed_dict; | |||
| Tensor file_prefix_tensor; | |||
| Tensor file_prefix_tensor = null; | |||
| if (graph_building) | |||
| { | |||
| if(_file_prefix_placeholder is null) | |||
| { | |||
| tf.device("/cpu:0"); | |||
| _file_prefix_placeholder = constant_op.constant("model"); | |||
| _file_prefix_placeholder = tf_with(ops.device("/cpu:0"), _ => | |||
| { | |||
| return constant_op.constant("model"); | |||
| }); | |||
| } | |||
| file_prefix_tensor = _file_prefix_placeholder; | |||
| file_prefix_feed_dict = new(); | |||
| @@ -244,8 +248,10 @@ public class TrackableSaver | |||
| } | |||
| else | |||
| { | |||
| tf.device("/cpu:0"); | |||
| file_prefix_tensor = constant_op.constant(save_path); | |||
| file_prefix_tensor = tf_with(ops.device("/cpu:0"), _ => | |||
| { | |||
| return constant_op.constant(save_path); | |||
| }); | |||
| file_prefix_feed_dict = null; | |||
| } | |||
| TrackableObjectGraph object_graph_proto = new(); | |||
| @@ -211,9 +211,11 @@ namespace Tensorflow.Checkpoint | |||
| string restore_device = string.IsNullOrEmpty(options.experimental_io_device) ? "cpu:0": options.experimental_io_device!; | |||
| // tf python has code `with ops.device(restore_device):` here. | |||
| tf.device(restore_device); // may be risky. | |||
| var restored_tensors = gen_ops.restore_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensor_dtypes.ToArray()); | |||
| Tensor[] restored_tensors = null; | |||
| tf_with(ops.device(restore_device), _ => | |||
| { | |||
| restored_tensors = gen_ops.restore_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensor_dtypes.ToArray()); | |||
| }); | |||
| Dictionary<string, IDictionary<string, Tensor>> restored_tensor_dict = new(); | |||
| int idx = 0; | |||
| @@ -338,11 +340,14 @@ namespace Tensorflow.Checkpoint | |||
| options = new CheckpointOptions(); | |||
| } | |||
| tf.device("CPU"); // may be risky. | |||
| var sharded_suffix = array_ops.where(gen_ops.regex_full_match(file_prefix, tf.constant(@"^s3://.*")), | |||
| Tensor tmp_checkpoint_prefix = null; | |||
| tf_with(ops.device("CPU"), _ => | |||
| { | |||
| var sharded_suffix = array_ops.where(gen_ops.regex_full_match(file_prefix, tf.constant(@"^s3://.*")), | |||
| constant_op.constant(".part"), constant_op.constant("_temp/part")); | |||
| var tmp_checkpoint_prefix = gen_ops.string_join(new Tensor[] { file_prefix, sharded_suffix }); | |||
| IDictionary<string, Tensor> registered_paths = _registered_savers.Keys.ToDictionary(x => x, x => registered_saver_filename(file_prefix, x)); | |||
| tmp_checkpoint_prefix = gen_ops.string_join(new Tensor[] { file_prefix, sharded_suffix }); | |||
| IDictionary<string, Tensor> registered_paths = _registered_savers.Keys.ToDictionary(x => x, x => registered_saver_filename(file_prefix, x)); | |||
| }); | |||
| Operation save_fn() | |||
| { | |||
| @@ -364,16 +369,24 @@ namespace Tensorflow.Checkpoint | |||
| var saver = pair.Value; | |||
| last_device = device; | |||
| // skip the extra process of device name because of lack of API. | |||
| tf.device(device); | |||
| var shard_prefix = sharded_filename(tmp_checkpoint_prefix, shard, num_shards_tensor); | |||
| Tensor shard_prefix = null; | |||
| tf_with(ops.device(device), _ => | |||
| { | |||
| shard_prefix = sharded_filename(tmp_checkpoint_prefix, shard, num_shards_tensor); | |||
| }); | |||
| saved_prefixes.Add(shard_prefix); | |||
| sharded_saves.Add(saver.save(shard_prefix, options)); | |||
| tf_with(ops.device(device), _ => | |||
| { | |||
| sharded_saves.Add(saver.save(shard_prefix, options)); | |||
| }); | |||
| } | |||
| using (var controller = ops.control_dependencies(sharded_saves.ToArray())) | |||
| { | |||
| string merge_device = string.IsNullOrEmpty(options.experimental_io_device) ? last_device : options.experimental_io_device; | |||
| tf.device(merge_device); | |||
| return gen_ops.merge_v2_checkpoints(saved_prefixes.ToArray(), tf.constant(file_prefix), delete_old_dirs: true); | |||
| return tf_with(ops.device(merge_device), _ => | |||
| { | |||
| return gen_ops.merge_v2_checkpoints(saved_prefixes.ToArray(), tf.constant(file_prefix), delete_old_dirs: true); | |||
| }); | |||
| } | |||
| } | |||
| @@ -407,54 +420,56 @@ namespace Tensorflow.Checkpoint | |||
| { | |||
| var device = single_saver.Key; | |||
| var saver = single_saver.Value; | |||
| tf.device(device); | |||
| var restored_tensor_dict = saver.restore(file_prefix, options); | |||
| foreach(var pair in restored_tensor_dict) | |||
| tf_with(ops.device(device), _ => | |||
| { | |||
| var checkpoint_key = pair.Key; | |||
| var slice_and_tensor = pair.Value; | |||
| foreach(var item in slice_and_tensor) | |||
| var restored_tensor_dict = saver.restore(file_prefix, options); | |||
| foreach (var pair in restored_tensor_dict) | |||
| { | |||
| var slice_spec = item.Key; | |||
| var tensor = item.Value; | |||
| var restore_fn = _keys_to_restore_fn[(checkpoint_key, slice_spec)]; | |||
| var internal_dict = restore_fn_inputs.SetDefault(restore_fn, new Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>()); | |||
| if (!string.IsNullOrEmpty(slice_spec)) | |||
| var checkpoint_key = pair.Key; | |||
| var slice_and_tensor = pair.Value; | |||
| foreach (var item in slice_and_tensor) | |||
| { | |||
| if (!internal_dict.ContainsKey(checkpoint_key)) | |||
| var slice_spec = item.Key; | |||
| var tensor = item.Value; | |||
| var restore_fn = _keys_to_restore_fn[(checkpoint_key, slice_spec)]; | |||
| var internal_dict = restore_fn_inputs.SetDefault(restore_fn, new Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>()); | |||
| if (!string.IsNullOrEmpty(slice_spec)) | |||
| { | |||
| Dictionary<string, Tensor> dict = new(); | |||
| dict[slice_spec] = tensor; | |||
| internal_dict[checkpoint_key] = new Maybe<Tensor, IDictionary<string, Tensor>>(dict); | |||
| if (!internal_dict.ContainsKey(checkpoint_key)) | |||
| { | |||
| Dictionary<string, Tensor> dict = new(); | |||
| dict[slice_spec] = tensor; | |||
| internal_dict[checkpoint_key] = new Maybe<Tensor, IDictionary<string, Tensor>>(dict); | |||
| } | |||
| else | |||
| { | |||
| internal_dict[checkpoint_key].GetValue<IDictionary<string, Tensor>>()[slice_spec] = tensor; | |||
| } | |||
| } | |||
| else | |||
| { | |||
| internal_dict[checkpoint_key].GetValue<IDictionary<string, Tensor>>()[slice_spec] = tensor; | |||
| internal_dict[checkpoint_key] = new Maybe<Tensor, IDictionary<string, Tensor>>(tensor); | |||
| } | |||
| } | |||
| else | |||
| { | |||
| internal_dict[checkpoint_key] = new Maybe<Tensor, IDictionary<string, Tensor>>(tensor); | |||
| } | |||
| restore_fn_input_count[restore_fn]--; | |||
| restore_fn_input_count[restore_fn]--; | |||
| if (restore_fn_input_count[restore_fn] == 0) | |||
| { | |||
| Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> restored_tensors = new(); | |||
| foreach(var input in restore_fn_inputs[restore_fn]) | |||
| if (restore_fn_input_count[restore_fn] == 0) | |||
| { | |||
| restored_tensors[TrackableUtils.extract_local_name(input.Key)] = input.Value; | |||
| } | |||
| var ret = restore_fn.DynamicInvoke(restored_tensors); | |||
| if(ret is IDictionary<string, Operation>) | |||
| { | |||
| var dict = (IDictionary<string, Operation>)ret; | |||
| restore_ops = restore_ops.Concat(dict).ToDictionary(x => x.Key, x => x.Value); | |||
| Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> restored_tensors = new(); | |||
| foreach (var input in restore_fn_inputs[restore_fn]) | |||
| { | |||
| restored_tensors[TrackableUtils.extract_local_name(input.Key)] = input.Value; | |||
| } | |||
| var ret = restore_fn.DynamicInvoke(restored_tensors); | |||
| if (ret is IDictionary<string, Operation>) | |||
| { | |||
| var dict = (IDictionary<string, Operation>)ret; | |||
| restore_ops = restore_ops.Concat(dict).ToDictionary(x => x.Key, x => x.Value); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| }); | |||
| } | |||
| foreach(var item in _registered_savers) | |||
| @@ -500,21 +515,25 @@ namespace Tensorflow.Checkpoint | |||
| private Tensor _traced_save(Tensor file_prefix) | |||
| { | |||
| var save_op = save(file_prefix); | |||
| tf.device("cpu:0"); | |||
| using (ops.control_dependencies(new object[]{ save_op })) | |||
| return tf_with(ops.device("cpu:0"), _ => | |||
| { | |||
| return array_ops.identity(file_prefix); | |||
| } | |||
| return tf_with(ops.control_dependencies(new object[] { save_op }), __ => | |||
| { | |||
| return array_ops.identity(file_prefix); | |||
| }); | |||
| }); | |||
| } | |||
| private Tensor _traced_restore(Tensor file_prefix) | |||
| { | |||
| var restore_op = restore(file_prefix); | |||
| tf.device("cpu:0"); | |||
| using (ops.control_dependencies(restore_op.Values.ToArray())) | |||
| return tf_with(ops.device("cpu:0"), _ => | |||
| { | |||
| return array_ops.identity(file_prefix); | |||
| } | |||
| return tf_with(ops.control_dependencies(restore_op.Values.ToArray()), __ => | |||
| { | |||
| return array_ops.identity(file_prefix); | |||
| }); | |||
| }); | |||
| } | |||
| public static MultiDeviceSaver from_saveables(IEnumerable<MySaveableObject> saveables, IDictionary<string, IDictionary<string, Trackable>>? registered_savers = null, bool call_with_mapped_captures = false) | |||
| @@ -21,6 +21,7 @@ using Tensorflow.Eager; | |||
| using static Tensorflow.Binding; | |||
| using Google.Protobuf; | |||
| using Tensorflow.Device; | |||
| using Tensorflow.Exceptions; | |||
| using System.Collections.Generic; | |||
| namespace Tensorflow.Contexts | |||
| @@ -30,10 +31,30 @@ namespace Tensorflow.Contexts | |||
| /// </summary> | |||
| public sealed partial class Context | |||
| { | |||
| internal static Dictionary<(string, string), (string, DeviceSpec)> _device_parsing_cache = new(); | |||
| internal List<LogicalDevice> _logical_devices = null; | |||
| internal List<string> _context_devices = null; | |||
| ContextDevicePlacementPolicy _device_policy; | |||
| bool _log_device_placement; | |||
| int _num_gpus; | |||
| Dictionary<PhysicalDevice, bool> _memory_growth_map = new Dictionary<PhysicalDevice, bool>(); | |||
| public string DeviceName { get; set; } = ""; | |||
| public DeviceSpec DeviceSpec { get; set; } = null; | |||
/// <summary>
/// Canonical names of the devices known to this context.
/// </summary>
/// <exception cref="AssertionError">Thrown when the device list has not been populated yet.</exception>
internal List<string> Devices
{
    get
    {
        // The backing list is null until _initialize_logical_devices() has assigned it.
        if (_context_devices is not null)
        {
            return _context_devices;
        }
        throw new AssertionError("Context must be initialized first.");
    }
}
| public void log_device_placement(bool enable) | |||
| { | |||
| if (_handle != null) | |||
| @@ -89,5 +110,57 @@ namespace Tensorflow.Contexts | |||
| return results.ToArray(); | |||
| } | |||
/// <summary>
/// Creates a scope object bound to this context and the requested device name.
/// </summary>
/// <param name="name">Device name handed to the new <see cref="EagerDeviceContext"/>.</param>
/// <returns>A new <see cref="EagerDeviceContext"/> for this context.</returns>
public EagerDeviceContext device(string name)
    => new EagerDeviceContext(this, name);
/// <summary>
/// Records the current device of this context as both its name and its parsed spec.
/// </summary>
/// <param name="device_name">Value stored in <see cref="DeviceName"/>.</param>
/// <param name="device_spec">Value stored in <see cref="DeviceSpec"/>.</param>
internal void _set_device(string device_name, DeviceSpec device_spec)
{
    DeviceName = device_name;
    DeviceSpec = device_spec;
}
/// <summary>
/// Queries the eager context for its devices and fills <c>_logical_devices</c>,
/// <c>_context_devices</c> and <c>_num_gpus</c>.
/// </summary>
/// <remarks>
/// The two lists are assigned in a <c>finally</c> block, so they are non-null after this call
/// even when enumeration fails part-way (they then hold whatever was collected so far).
/// </remarks>
internal void _initialize_logical_devices()
{
    List<LogicalDevice> logical_devices = new();
    List<string> context_devices = new();
    Status status = new();
    var device_list = c_api.TFE_ContextListDevices(_handle, status);
    status.Check(true);
    // NOTE(review): the Python counterpart deletes the device list in its finally block;
    // confirm that the handle type returned here releases the native list on its own.
    try
    {
        this._num_gpus = 0;
        string current_job = null;
        int current_task = -1;
        int device_count = c_api.TF_DeviceListCount(device_list);
        for (int i = 0; i < device_count; i++)
        {
            var dev_name = c_api.TF_DeviceListName(device_list, i, status);
            status.Check(true);
            context_devices.Add(DeviceUtils.canonical_name(dev_name));

            var spec = DeviceSpec.from_string(dev_name);
            if (spec.Job == "localhost")
            {
                // Drop job/replica/task for local devices. DeviceSpec.replace treats
                // null / -1 as "keep the current value", so it cannot clear these fields;
                // build the stripped spec directly instead.
                spec = new DeviceSpec(device_type: spec.DeviceType, device_index: spec.DeviceIndex);
            }
            logical_devices.Add(new LogicalDevice(spec.ToString(), spec.DeviceType));

            var dev_type_memory = c_api.TF_DeviceListType(device_list, i, status);
            // Validate the call before reading the memory it returned.
            status.Check(true);
            var dev_type = c_api.StringPiece(dev_type_memory);
            if (dev_type == "GPU" && spec.Job == current_job && spec.Task == current_task)
            {
                _num_gpus++;
            }
        }
    }
    finally
    {
        _logical_devices = logical_devices;
        _context_devices = context_devices;
    }
}
| } | |||
| public record class LogicalDevice(string name, string device_type); | |||
| } | |||
| @@ -34,7 +34,6 @@ namespace Tensorflow.Contexts | |||
| public const int EAGER_MODE = 1; | |||
| int defaultExecutionMode = EAGER_MODE; | |||
| public string DeviceName { get; set; } = ""; | |||
| public string ScopeName { get; set; } = ""; | |||
| bool initialized = false; | |||
| ContextSwitchStack context_switches; | |||
| @@ -62,6 +61,8 @@ namespace Tensorflow.Contexts | |||
| if (initialized) | |||
| return; | |||
| Debug.Assert(_context_devices is null); | |||
| Config = MergeConfig(); | |||
| FunctionCallOptions.Config = Config; | |||
| var config_str = Config.ToByteArray(); | |||
| @@ -72,6 +73,7 @@ namespace Tensorflow.Contexts | |||
| c_api.TFE_ContextOptionsSetDevicePlacementPolicy(opts, _device_policy); | |||
| _handle = c_api.TFE_NewContext(opts, status); | |||
| status.Check(true); | |||
| _initialize_logical_devices(); | |||
| initialized = true; | |||
| } | |||
| @@ -174,6 +176,7 @@ namespace Tensorflow.Contexts | |||
| { | |||
| c_api.TFE_ContextClearCaches(_handle); | |||
| } | |||
| _device_parsing_cache.Clear(); | |||
| } | |||
| public static implicit operator SafeContextHandle(Context ctx) | |||
| @@ -0,0 +1,71 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Device; | |||
| namespace Tensorflow.Contexts | |||
| { | |||
/// <summary>
/// Scope object that switches the eager <see cref="Context"/> to a device on
/// <see cref="__enter__"/> and restores the previous device on <see cref="__exit__"/>.
/// </summary>
public class EagerDeviceContext : ITensorFlowObject
{
    private Context _ctx;
    private string _device_name;
    // One entry per __enter__: (previous name, previous spec, spec that was installed).
    private Stack<(string, DeviceSpec, DeviceSpec)> _stack;

    /// <param name="ctx">Context whose current device is changed.</param>
    /// <param name="device_name">Requested device; null selects the context's first device.</param>
    public EagerDeviceContext(Context ctx, string device_name)
    {
        _ctx = ctx;
        _device_name = device_name;
        _stack = new Stack<(string, DeviceSpec, DeviceSpec)>();
    }

    /// <summary>
    /// Resolves the requested device against the context's current device, installs the
    /// result on the context and remembers the previous device for <see cref="__exit__"/>.
    /// Resolved (old name, new name) pairs are memoized in <c>Context._device_parsing_cache</c>.
    /// </summary>
    public void __enter__()
    {
        var ctx = _ctx;
        var old_device_name = ctx.DeviceName;
        var old_device_spec = ctx.DeviceSpec;
        var new_device_name = _device_name;
        var cache_key = (old_device_name, new_device_name);
        DeviceSpec new_device_spec;
        if (Context._device_parsing_cache.TryGetValue(cache_key, out var cached))
        {
            (new_device_name, new_device_spec) = cached;
        }
        else
        {
            if (new_device_name is not null)
            {
                var device_spec = DeviceSpec.from_string(new_device_name);
                if (!string.IsNullOrEmpty(old_device_name))
                {
                    // DeviceName has a public setter, so a name may be present without a
                    // parsed spec; fall back to parsing the name in that case.
                    new_device_spec = old_device_spec is not null
                        ? new DeviceSpec(old_device_spec)
                        : DeviceSpec.from_string(old_device_name);
                }
                else
                {
                    ctx.ensure_initialized();
                    new_device_spec = DeviceSpec.from_string(ctx._context_devices[0]);
                }
                // Fields set in the requested spec override those of the base spec.
                new_device_spec = new_device_spec.make_merged_spec(device_spec);
            }
            else
            {
                // _context_devices is only populated during initialization, so make sure
                // it has happened before indexing into the list.
                ctx.ensure_initialized();
                new_device_spec = DeviceSpec.from_string(ctx._context_devices[0]);
            }
            new_device_name = new_device_spec.ToString();
            Context._device_parsing_cache[cache_key] = (new_device_name, new_device_spec);
        }
        ctx._set_device(new_device_name, new_device_spec);
        _stack.Push((old_device_name, old_device_spec, new_device_spec));
    }

    /// <summary>
    /// Restores the device that was active before the matching <see cref="__enter__"/>.
    /// </summary>
    public void __exit__()
    {
        var (old_device_name, old_device_spec, _) = _stack.Pop();
        _ctx._set_device(old_device_name, old_device_spec);
    }

    /// <summary>No resources are held; nothing to release.</summary>
    public void Dispose()
    {
    }
}
| } | |||
| @@ -0,0 +1,205 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using System.Threading.Tasks; | |||
| namespace Tensorflow.Device | |||
| { | |||
/// <summary>
/// Describes a device placement as job, replica, task, device type and device index.
/// Unset string parts are null and unset numeric parts are -1.
/// </summary>
public class DeviceSpec
{
    // NOTE(review): these are plain static dictionaries shared by every caller;
    // confirm that specs are never parsed or formatted concurrently.
    private static Dictionary<string, Components> _STRING_TO_COMPONENTS_CACHE = new();
    private static Dictionary<Components, string> _COMPONENTS_TO_STRING_CACHE = new();
    // Built once instead of on every parse.
    private static readonly HashSet<string> _VALID_DEVICE_TYPES = new HashSet<string>(new string[] { "CPU", "GPU" });

    private string _job;
    private int _replica;
    private int _task;
    private string _device_type;
    private int _device_index;
    private string _as_string;

    public string Job => _job;
    public int Replica => _replica;
    public int Task => _task;
    public string DeviceType => _device_type;
    public int DeviceIndex => _device_index;

    /// <summary>Creates a spec from individual parts; null / -1 leave a part unset.</summary>
    public DeviceSpec(string job = null, int replica = -1, int task = -1,
        string device_type = null, int device_index = -1)
    {
        _job = job;
        _replica = replica;
        _task = task;
        _device_type = device_type;
        _device_index = device_index;
        _as_string = _components_to_string(job, replica, task, device_type, _device_index);
    }

    /// <summary>Copies every part, including the cached string form, from <paramref name="other"/>.</summary>
    public DeviceSpec(DeviceSpec other)
    {
        _job = other._job;
        _replica = other._replica;
        _task = other._task;
        _device_type = other._device_type;
        _device_index = other._device_index;
        _as_string = other._as_string;
    }

    /// <summary>Creates a spec from an already parsed component record.</summary>
    protected DeviceSpec(Components com)
    {
        _job = com.Job;
        _replica = com.Replica;
        _task = com.Task;
        _device_type = com.DeviceType;
        _device_index = com.DeviceIndex;
        _as_string = _components_to_string(_job, _replica, _task, _device_type, _device_index);
    }

    /// <summary>
    /// Returns a new spec in which every explicitly supplied part replaces the current one.
    /// </summary>
    /// <remarks>
    /// null and -1 mean "keep the current value", so this method cannot be used to clear a
    /// part; construct a new <see cref="DeviceSpec"/> to do that.
    /// </remarks>
    public DeviceSpec replace(string job = null, int replica = -1, int task = -1,
        string device_type = null, int device_index = -1)
    {
        job = job ?? _job;
        replica = replica == -1 ? _replica : replica;
        task = task == -1 ? _task : task;
        device_type = device_type ?? _device_type;
        device_index = device_index == -1 ? _device_index : device_index;
        return new DeviceSpec(job, replica, task, device_type, device_index);
    }

    /// <summary>
    /// Parses a device string such as <c>/job:worker/task:0/device:GPU:1</c>, <c>/cpu:0</c> or <c>GPU</c>.
    /// A null string is parsed like the empty string and yields a spec with no part set.
    /// </summary>
    /// <exception cref="ValueError">The string names several device types or an unknown attribute.</exception>
    public static DeviceSpec from_string(string spec)
    {
        return new DeviceSpec(_string_to_components(spec));
    }

    /// <summary>
    /// Returns a new spec where each part set in <paramref name="dev"/> overrides this spec's part.
    /// </summary>
    public DeviceSpec make_merged_spec(DeviceSpec dev)
    {
        return new DeviceSpec(_get_combined_properties(dev));
    }

    // Part-wise merge: take dev's value when it is set, otherwise keep this spec's value.
    private Components _get_combined_properties(DeviceSpec dev)
    {
        return new Components(
            dev.Job ?? _job,
            dev.Replica == -1 ? _replica : dev.Replica,
            dev.Task == -1 ? _task : dev.Task,
            dev.DeviceType ?? _device_type,
            dev.DeviceIndex == -1 ? _device_index : dev.DeviceIndex
            );
    }

    // Formats the parts as "/job:…/replica:…/task:…/device:TYPE:INDEX", omitting unset parts.
    // An unset index with a set device type is rendered as "*". Results are memoized.
    private static string _components_to_string(string job, int replica, int task, string device_type, int device_index)
    {
        var key = new Components(job, replica, task, device_type, device_index);
        if (_COMPONENTS_TO_STRING_CACHE.TryGetValue(key, out var cache_result))
        {
            return cache_result;
        }

        StringBuilder output = new();
        if (job is not null)
        {
            output.Append($"/job:{job}");
        }
        if (replica != -1)
        {
            output.Append($"/replica:{replica}");
        }
        if (task != -1)
        {
            output.Append($"/task:{task}");
        }
        if (device_type is not null)
        {
            string device_index_string = device_index != -1 ? device_index.ToString() : "*";
            output.Append($"/device:{device_type}:{device_index_string}");
        }

        var result = output.ToString();
        _COMPONENTS_TO_STRING_CACHE[key] = result;
        return result;
    }

    // Splits the spec on '/' and each segment on ':' and interprets the pieces. Results are memoized.
    private static Components _string_to_components(string spec)
    {
        // A null spec would make the dictionary lookup throw; treat it as the empty spec.
        spec ??= "";
        if (_STRING_TO_COMPONENTS_CACHE.TryGetValue(spec, out var cached_result))
        {
            return cached_result;
        }

        var valid_device_types = _get_valid_device_types();
        string job = null, device_type = null;
        int replica = -1, task = -1, device_index = -1;
        foreach (var segment in spec.Split('/'))
        {
            var y = segment.Split(':');
            var ly = y.Length;
            if (ly == 2 && y[0] == "job")
            {
                job = y[1];
            }
            else if (ly == 2 && y[0] == "replica")
            {
                replica = int.Parse(y[1]);
            }
            else if (ly == 2 && y[0] == "task")
            {
                task = int.Parse(y[1]);
            }
            // Short forms: "cpu", "GPU:0", "gpu:*". Invariant casing keeps the match
            // independent of the current culture.
            else if ((ly == 1 || ly == 2) && valid_device_types.Contains(y[0].ToUpperInvariant()))
            {
                if (device_type is not null)
                {
                    throw new ValueError($"Multiple device types are not allowed " +
                        $"while parsing the device spec: {spec}.");
                }
                device_type = y[0].ToUpperInvariant();
                if (ly == 2 && y[1] != "*")
                {
                    device_index = int.Parse(y[1]);
                }
            }
            // Long form: "device:TYPE:INDEX"; the type is taken verbatim.
            else if (ly == 3 && y[0] == "device")
            {
                if (device_type is not null)
                {
                    throw new ValueError($"Multiple device types are not allowed " +
                        $"while parsing the device spec: {spec}.");
                }
                device_type = y[1];
                if (y[2] != "*")
                {
                    device_index = int.Parse(y[2]);
                }
            }
            // Empty segments (leading '/', "//") are ignored; anything else is an error.
            else if (y[0] != "")
            {
                throw new ValueError($"Unknown attribute '{y[0]}' is encountered " +
                    $"while parsing the device spec: {spec}.");
            }
        }

        var output = new Components(job, replica, task, device_type, device_index);
        _STRING_TO_COMPONENTS_CACHE[spec] = output;
        return output;
    }

    private static HashSet<string> _get_valid_device_types()
    {
        // TODO(Rinne): revise it to calling C API (need customized API).
        return _VALID_DEVICE_TYPES;
    }

    /// <summary>Returns the string form computed when the spec was constructed.</summary>
    public override string ToString()
    {
        return _as_string;
    }

    protected record class Components(string Job, int Replica, int Task, string DeviceType, int DeviceIndex);
}
| } | |||
| @@ -0,0 +1,26 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Device | |||
| { | |||
/// <summary>
/// Helpers for turning device names and specs into their canonical string form.
/// </summary>
internal static class DeviceUtils
{
    /// <summary>
    /// Parses <paramref name="device"/> and returns the string form of the resulting spec;
    /// a null name yields the empty string.
    /// </summary>
    public static string canonical_name(string device)
        => device is null ? "" : DeviceSpec.from_string(device).ToString();

    /// <summary>
    /// Returns the string form of <paramref name="device"/>; a null spec yields the empty string.
    /// </summary>
    public static string canonical_name(DeviceSpec device)
        => device is null ? "" : device.ToString();
}
| } | |||
| @@ -19,6 +19,7 @@ using System.Collections; | |||
| using System.Collections.Generic; | |||
| using System.Collections.Specialized; | |||
| using System.Linq; | |||
| using Tensorflow.Graphs; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow | |||
| @@ -294,9 +295,15 @@ namespace Tensorflow | |||
| return op; | |||
| } | |||
| public void device(string device_name) | |||
| public ITensorFlowObject device(string device_name) | |||
| { | |||
| return new GraphDeviceContext(this, device_name); | |||
| } | |||
/// <summary>
/// Placeholder for pushing a device onto the graph's device stack; currently does nothing.
/// </summary>
/// <param name="device_name">Device to push (not used yet).</param>
/// <param name="offset">Extra stack offset supplied by the caller (not used yet).</param>
private void add_device_to_stack(string device_name, int offset = 0)
{
    // TODO(Rinne): deal with device spec. The eventual implementation is expected to
    // use a total offset of `offset + 1`.
}
| private void _create_op_helper(Operation op, bool compute_device = true) | |||
| @@ -0,0 +1,31 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| namespace Tensorflow.Graphs | |||
| { | |||
/// <summary>
/// Device scope for graph mode. At present it is a placeholder: entering, exiting and
/// disposing it have no effect.
/// </summary>
public class GraphDeviceContext : ITensorFlowObject
{
    private Graph _graph;

    /// <param name="graph">Graph this scope belongs to.</param>
    /// <param name="device_name">Requested device; accepted but not stored.</param>
    public GraphDeviceContext(Graph graph, string device_name)
    {
        _graph = graph;
    }

    /// <summary>No-op.</summary>
    public void __enter__() { }

    /// <summary>No-op.</summary>
    public void __exit__() { }

    /// <summary>No-op.</summary>
    public void Dispose() { }
}
| } | |||
| @@ -42,16 +42,20 @@ namespace Tensorflow | |||
| _var_device = var.Device; | |||
| _var_shape = var.shape; | |||
| Tensor _read_variable_closure(BaseResourceVariable v) | |||
| Tensor? _read_variable_closure(BaseResourceVariable v) | |||
| { | |||
| tf.device(v.Device); | |||
| if(tf.Context.executing_eagerly() && !((bool)v.is_initialized().numpy())) | |||
| return tf_with(ops.device(v.Device), _ => | |||
| { | |||
| return null; | |||
| } | |||
| var x = v.read_value_no_copy(); | |||
| tf.device("/device:CPU:0"); | |||
| return array_ops.identity(x); | |||
| if (tf.Context.executing_eagerly() && !((bool)v.is_initialized().numpy())) | |||
| { | |||
| return null; | |||
| } | |||
| var x = v.read_value_no_copy(); | |||
| return tf_with(ops.device("/device:CPU:0"), __ => | |||
| { | |||
| return array_ops.identity(x); | |||
| }); | |||
| }); | |||
| } | |||
| this.handle_op = var.Handle; | |||
| @@ -412,8 +412,10 @@ namespace Tensorflow | |||
| { | |||
| var variables_path = SavedModelUtils.get_variables_path(_export_dir); | |||
| var saver = new TrackableSaver(new ObjectGraphView(get(0))); | |||
| tf.device("CPU"); | |||
| saver.FilePrefixPlaceHolder = constant_op.constant(variables_path); | |||
| tf_with(ops.device("CPU"), _ => | |||
| { | |||
| saver.FilePrefixPlaceHolder = constant_op.constant(variables_path); | |||
| }); | |||
| LoadStatus load_status; | |||
| if (_save_options.allow_partial_checkpoint) | |||
| { | |||
| @@ -598,14 +600,16 @@ namespace Tensorflow | |||
| if (load_with_device) | |||
| { | |||
| tf.device(saved_device); | |||
| return (new UninitializedVariable( | |||
| shape: new Shape(proto.Shape.Dim.Select(x => (int)x.Size).ToArray()), | |||
| dtype: (TF_DataType)proto.Dtype, | |||
| name: name, | |||
| trainable: trainable, | |||
| aggregation: aggregation | |||
| ), setattr); | |||
| return tf_with(ops.device(saved_device), _ => | |||
| { | |||
| return (new UninitializedVariable( | |||
| shape: new Shape(proto.Shape.Dim.Select(x => (int)x.Size).ToArray()), | |||
| dtype: (TF_DataType)proto.Dtype, | |||
| name: name, | |||
| trainable: trainable, | |||
| aggregation: aggregation | |||
| ), setattr); | |||
| }); | |||
| } | |||
| else | |||
| { | |||
| @@ -266,9 +266,11 @@ namespace Tensorflow | |||
| BaseResourceVariable new_variable; | |||
| if (save_options.experimental_variable_policy.save_variable_devices()) | |||
| { | |||
| tf.device(this.Device); | |||
| Debug.Assert(this is ResourceVariable); | |||
| new_variable = resource_variable_ops.copy_to_graph_uninitialized((ResourceVariable)this); | |||
| new_variable = tf_with(ops.device(this.Device), _ => | |||
| { | |||
| return resource_variable_ops.copy_to_graph_uninitialized((ResourceVariable)this); | |||
| }); | |||
| } | |||
| else | |||
| { | |||
| @@ -49,9 +49,12 @@ namespace Tensorflow.Variables | |||
| { | |||
| tf_with(ops.name_scope("Read"), _ => | |||
| { | |||
| tf.device(handle.Device); | |||
| var value = gen_resource_variable_ops.read_variable_op(handle, dtype); | |||
| // _maybe_set_handle_data(dtype, handle, value) | |||
| var value = tf_with(ops.device(handle.Device), _ => | |||
| { | |||
| var result = gen_resource_variable_ops.read_variable_op(handle, dtype); | |||
| // TODO(Rinne): _maybe_set_handle_data(dtype, handle, value) | |||
| return result; | |||
| }); | |||
| _graph_element = value; | |||
| }); | |||
| ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES_, this); | |||
| @@ -577,6 +577,23 @@ namespace Tensorflow | |||
| } | |||
/// <summary>
/// Returns a device scope suited to the current execution mode: the eager context's scope
/// when executing eagerly, otherwise the default graph's scope.
/// </summary>
/// <param name="device_name">Device name passed through to the selected scope factory.</param>
public static ITensorFlowObject device(string device_name)
{
    // TODO(Rinne): deal with `ops.executing_eagerly_outside_functions()`.
    if (!tf.Context.executing_eagerly())
    {
        return get_default_graph().device(device_name);
    }
    return tf.Context.device(device_name);
}
| public class NullContextManager: IDisposable | |||
| { | |||
| public void Dispose() | |||