From a59ebaeea41f6b304500e15395514d40c81d9d72 Mon Sep 17 00:00:00 2001 From: Yaohui Liu Date: Tue, 18 Apr 2023 18:22:51 +0800 Subject: [PATCH] Fix the errors caused by branch merge. --- .../Checkpoint/functional_saver.cs | 69 ++++++++++--------- .../Contexts/Context.Device.cs | 8 +++ src/TensorFlowNET.Core/Eager/c_api.eager.cs | 3 + .../Saving/ResourceVariableSaveable.cs | 26 ++++--- .../Saving/saveable_object_util.py.cs | 13 ++++ .../Variables/ResourceVariable.cs | 33 ++++++--- 6 files changed, 99 insertions(+), 53 deletions(-) diff --git a/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs b/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs index a6aa7640..211d7d6f 100644 --- a/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs +++ b/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs @@ -208,7 +208,6 @@ namespace Tensorflow.Checkpoint _keys_to_restore_fn[(checkpoint_key, slice_spec)] = restore_fn; _restore_fn_to_keys.SetDefault(restore_fn, new List<(string, string)>()).Add((checkpoint_key, slice_spec)); - // skip the process of device name because lack of API. string host_device; if (tensor.IsT0) { @@ -218,6 +217,7 @@ namespace Tensorflow.Checkpoint { host_device = tensor.AsT1.device; } + host_device = saveable_object_util.set_cpu0(host_device); var internal_dict = tensors_by_device.SetDefault(host_device, new Dictionary>>()); if (!internal_dict.ContainsKey(checkpoint_key)) { @@ -329,51 +329,52 @@ namespace Tensorflow.Checkpoint { var restored_tensor_dict = saver.restore(file_prefix, options); - foreach(var pair in restored_tensor_dict) - { - var checkpoint_key = pair.Key; - var slice_and_tensor = pair.Value; - foreach(var item in slice_and_tensor) + foreach (var pair in restored_tensor_dict) { - var slice_spec = item.Key; - var tensor = item.Value; - var restore_fn = _keys_to_restore_fn[(checkpoint_key, slice_spec)]; - var internal_dict = restore_fn_inputs.SetDefault(restore_fn, new Dictionary>>()); - if (!string.IsNullOrEmpty(slice_spec)) + var checkpoint_key = pair.Key; + var slice_and_tensor = pair.Value; + foreach (var item in slice_and_tensor) { - if (!internal_dict.ContainsKey(checkpoint_key)) + var slice_spec = item.Key; + var tensor = item.Value; + var restore_fn = _keys_to_restore_fn[(checkpoint_key, slice_spec)]; + var internal_dict = restore_fn_inputs.SetDefault(restore_fn, new Dictionary>>()); + if (!string.IsNullOrEmpty(slice_spec)) { - Dictionary dict = new(); - dict[slice_spec] = tensor; - internal_dict[checkpoint_key] = OneOf>.FromT1(dict); + if (!internal_dict.ContainsKey(checkpoint_key)) + { + Dictionary dict = new(); + dict[slice_spec] = tensor; + internal_dict[checkpoint_key] = OneOf>.FromT1(dict); + } + else + { + internal_dict[checkpoint_key].AsT1[slice_spec] = tensor; + } } else { - internal_dict[checkpoint_key].AsT1[slice_spec] = tensor; + internal_dict[checkpoint_key] = OneOf>.FromT0(tensor); } - } - else - { - internal_dict[checkpoint_key] = OneOf>.FromT0(tensor); - } - restore_fn_input_count[restore_fn]--; + restore_fn_input_count[restore_fn]--; - if (restore_fn_input_count[restore_fn] == 0) - { - Dictionary>> restored_tensors = new(); - foreach(var input in restore_fn_inputs[restore_fn]) + if (restore_fn_input_count[restore_fn] == 0) { - restored_tensors[TrackableUtils.extract_local_name(input.Key)] = input.Value; - } - var ret = restore_fn.DynamicInvoke(restored_tensors); - if(ret is IDictionary) - { - var dict = (IDictionary)ret; - restore_ops = restore_ops.Concat(dict).ToDictionary(x => x.Key, x => x.Value); + Dictionary>> restored_tensors = new(); + foreach (var input in restore_fn_inputs[restore_fn]) + { + restored_tensors[TrackableUtils.extract_local_name(input.Key)] = input.Value; + } + var ret = restore_fn.DynamicInvoke(restored_tensors); + if (ret is IDictionary) + { + var dict = (IDictionary)ret; + restore_ops = restore_ops.Concat(dict).ToDictionary(x => x.Key, x => x.Value); + } } } } - } + }); } foreach(var item in _registered_savers) diff --git a/src/TensorFlowNET.Core/Contexts/Context.Device.cs b/src/TensorFlowNET.Core/Contexts/Context.Device.cs index 32e6682e..d35d1084 100644 --- a/src/TensorFlowNET.Core/Contexts/Context.Device.cs +++ b/src/TensorFlowNET.Core/Contexts/Context.Device.cs @@ -111,6 +111,14 @@ namespace Tensorflow.Contexts return results.ToArray(); } + public bool is_custom_device(string device_name) + { + return false; + // TODO(Rinne): After tf2.11 TFE_IsCustomDevice has been added to C APIs. + //ensure_initialized(); + //return c_api.TFE_IsCustomDevice(_handle, device_name); + } + public EagerDeviceContext device(string name) { return new EagerDeviceContext(this, name); diff --git a/src/TensorFlowNET.Core/Eager/c_api.eager.cs b/src/TensorFlowNET.Core/Eager/c_api.eager.cs index 665e537f..11de4960 100644 --- a/src/TensorFlowNET.Core/Eager/c_api.eager.cs +++ b/src/TensorFlowNET.Core/Eager/c_api.eager.cs @@ -483,5 +483,8 @@ namespace Tensorflow IntPtr[] target, int target_size, IntPtr[] sources, int source_size, IntPtr[] outputs, int output_size); + + [DllImport(TensorFlowLibName)] + public static extern bool TFE_IsCustomDevice(SafeContextHandle ctx, string device_name); } } diff --git a/src/TensorFlowNET.Core/Training/Saving/ResourceVariableSaveable.cs b/src/TensorFlowNET.Core/Training/Saving/ResourceVariableSaveable.cs index e2bdafab..587dede4 100644 --- a/src/TensorFlowNET.Core/Training/Saving/ResourceVariableSaveable.cs +++ b/src/TensorFlowNET.Core/Training/Saving/ResourceVariableSaveable.cs @@ -46,14 +46,18 @@ namespace Tensorflow { return () => { - tf.device(v.Device); - if (tf.Context.executing_eagerly() && !((bool)v.is_initialized().numpy())) + return tf_with(ops.device(v.Device), _ => { - return null; - } - var x = v.read_value_no_copy(); - tf.device("/device:CPU:0"); - return array_ops.identity(x); + if (tf.Context.executing_eagerly() && !((bool)v.is_initialized().numpy())) + { + return null; + } + var x = v.read_value_no_copy(); + return tf_with(ops.device("/device:CPU:0"), _ => + { + return array_ops.identity(x); + }); + }); }; } @@ -69,10 +73,12 @@ namespace Tensorflow public override Operation restore(Tensor[] restored_tensors, Shape[] restored_shapes = null) { var restored_tensor = restored_tensors[0]; - tf.device(_var_device); - restored_tensor = array_ops.identity(restored_tensor); - return resource_variable_ops.shape_safe_assign_variable_handle( + return tf_with(ops.device(_var_device), _ => + { + restored_tensor = array_ops.identity(restored_tensor); + return resource_variable_ops.shape_safe_assign_variable_handle( handle_op, _var_shape, restored_tensor); + }); } } } diff --git a/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs b/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs index c4ef751b..5f198a4f 100644 --- a/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs +++ b/src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs @@ -20,6 +20,8 @@ using System.Collections.Generic; using System.Diagnostics; using System.Linq; using Tensorflow.Checkpoint; +using Tensorflow.Contexts; +using Tensorflow.Device; using Tensorflow.Operations.Activation; using Tensorflow.Train; using Tensorflow.Training; @@ -406,6 +408,17 @@ namespace Tensorflow return factory(key); } + public static string set_cpu0(string device_string) + { + if (tf.Context.is_custom_device(device_string)) + { + return device_string; + } + var parsed_device = DeviceSpec.from_string(device_string); + parsed_device = parsed_device.replace(device_type: "CPU", device_index: 0); + return parsed_device.ToString(); + } + private static bool _tensor_comes_from_variable(object v) { return v is Tensor tensor && _VARIABLE_OPS.Contains(tensor.op.type); diff --git a/src/TensorFlowNET.Core/Variables/ResourceVariable.cs b/src/TensorFlowNET.Core/Variables/ResourceVariable.cs index 512e8152..dbd934af 100644 --- a/src/TensorFlowNET.Core/Variables/ResourceVariable.cs +++ b/src/TensorFlowNET.Core/Variables/ResourceVariable.cs @@ -124,16 +124,29 @@ namespace Tensorflow if (_in_graph_mode) { + // TODO(Rinne): deal with initializer_op. + //if(initial_value is not null) + //{ + // tf_with(ops.name_scope("Assign"), n => + // { + // tf_with(ops.device(handle.Device), _ => + // { + + // }); + // }); + //} handle = state_ops.variable_op_v2(_initial_value.shape, _initial_value.dtype.as_base_dtype(), name: name); initializer_op = gen_state_ops.assign(handle, _initial_value, true).op; ops.colocate_with(initializer_op); - tf.device(handle.Device); - var value = gen_resource_variable_ops.read_variable_op(handle, dtype); - resource_variable_ops._maybe_set_handle_data(dtype, handle, value); - _graph_element = gen_array_ops.identity(handle, name = "read"); - ops.add_to_collections(collections, this); - _dtype = handle.dtype; + tf_with(ops.device(handle.Device), _ => + { + var value = gen_resource_variable_ops.read_variable_op(handle, dtype); + resource_variable_ops._maybe_set_handle_data(dtype, handle, value); + _graph_element = gen_array_ops.identity(handle, name = "read"); + ops.add_to_collections(collections, this); + _dtype = handle.dtype; + }); } else { @@ -149,9 +162,11 @@ namespace Tensorflow _graph_element = null; if (!string.IsNullOrEmpty(caching_device)) { - tf.device(caching_device); - var value = gen_resource_variable_ops.read_variable_op(handle, dtype); - resource_variable_ops._maybe_set_handle_data(dtype, handle, value); + tf_with(ops.device(caching_device), _ => + { + var value = gen_resource_variable_ops.read_variable_op(handle, dtype); + resource_variable_ops._maybe_set_handle_data(dtype, handle, value); + }); } _dtype = _initial_value.dtype.as_base_dtype(); // initial_value = _in_graph_mode ? initial_value : null;