| @@ -208,7 +208,6 @@ namespace Tensorflow.Checkpoint | |||||
| _keys_to_restore_fn[(checkpoint_key, slice_spec)] = restore_fn; | _keys_to_restore_fn[(checkpoint_key, slice_spec)] = restore_fn; | ||||
| _restore_fn_to_keys.SetDefault(restore_fn, new List<(string, string)>()).Add((checkpoint_key, slice_spec)); | _restore_fn_to_keys.SetDefault(restore_fn, new List<(string, string)>()).Add((checkpoint_key, slice_spec)); | ||||
| // skip the process of device name because lack of API. | |||||
| string host_device; | string host_device; | ||||
| if (tensor.IsT0) | if (tensor.IsT0) | ||||
| { | { | ||||
| @@ -218,6 +217,7 @@ namespace Tensorflow.Checkpoint | |||||
| { | { | ||||
| host_device = tensor.AsT1.device; | host_device = tensor.AsT1.device; | ||||
| } | } | ||||
| host_device = saveable_object_util.set_cpu0(host_device); | |||||
| var internal_dict = tensors_by_device.SetDefault(host_device, new Dictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>()); | var internal_dict = tensors_by_device.SetDefault(host_device, new Dictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>()); | ||||
| if (!internal_dict.ContainsKey(checkpoint_key)) | if (!internal_dict.ContainsKey(checkpoint_key)) | ||||
| { | { | ||||
| @@ -329,51 +329,52 @@ namespace Tensorflow.Checkpoint | |||||
| { | { | ||||
| var restored_tensor_dict = saver.restore(file_prefix, options); | var restored_tensor_dict = saver.restore(file_prefix, options); | ||||
| foreach(var pair in restored_tensor_dict) | |||||
| { | |||||
| var checkpoint_key = pair.Key; | |||||
| var slice_and_tensor = pair.Value; | |||||
| foreach(var item in slice_and_tensor) | |||||
| foreach (var pair in restored_tensor_dict) | |||||
| { | { | ||||
| var slice_spec = item.Key; | |||||
| var tensor = item.Value; | |||||
| var restore_fn = _keys_to_restore_fn[(checkpoint_key, slice_spec)]; | |||||
| var internal_dict = restore_fn_inputs.SetDefault(restore_fn, new Dictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>()); | |||||
| if (!string.IsNullOrEmpty(slice_spec)) | |||||
| var checkpoint_key = pair.Key; | |||||
| var slice_and_tensor = pair.Value; | |||||
| foreach (var item in slice_and_tensor) | |||||
| { | { | ||||
| if (!internal_dict.ContainsKey(checkpoint_key)) | |||||
| var slice_spec = item.Key; | |||||
| var tensor = item.Value; | |||||
| var restore_fn = _keys_to_restore_fn[(checkpoint_key, slice_spec)]; | |||||
| var internal_dict = restore_fn_inputs.SetDefault(restore_fn, new Dictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>()); | |||||
| if (!string.IsNullOrEmpty(slice_spec)) | |||||
| { | { | ||||
| Dictionary<string, Tensor> dict = new(); | |||||
| dict[slice_spec] = tensor; | |||||
| internal_dict[checkpoint_key] = OneOf<Tensor, IDictionary<string, Tensor>>.FromT1(dict); | |||||
| if (!internal_dict.ContainsKey(checkpoint_key)) | |||||
| { | |||||
| Dictionary<string, Tensor> dict = new(); | |||||
| dict[slice_spec] = tensor; | |||||
| internal_dict[checkpoint_key] = OneOf<Tensor, IDictionary<string, Tensor>>.FromT1(dict); | |||||
| } | |||||
| else | |||||
| { | |||||
| internal_dict[checkpoint_key].AsT1[slice_spec] = tensor; | |||||
| } | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| internal_dict[checkpoint_key].AsT1[slice_spec] = tensor; | |||||
| internal_dict[checkpoint_key] = OneOf<Tensor, IDictionary<string, Tensor>>.FromT0(tensor); | |||||
| } | } | ||||
| } | |||||
| else | |||||
| { | |||||
| internal_dict[checkpoint_key] = OneOf<Tensor, IDictionary<string, Tensor>>.FromT0(tensor); | |||||
| } | |||||
| restore_fn_input_count[restore_fn]--; | |||||
| restore_fn_input_count[restore_fn]--; | |||||
| if (restore_fn_input_count[restore_fn] == 0) | |||||
| { | |||||
| Dictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> restored_tensors = new(); | |||||
| foreach(var input in restore_fn_inputs[restore_fn]) | |||||
| if (restore_fn_input_count[restore_fn] == 0) | |||||
| { | { | ||||
| restored_tensors[TrackableUtils.extract_local_name(input.Key)] = input.Value; | |||||
| } | |||||
| var ret = restore_fn.DynamicInvoke(restored_tensors); | |||||
| if(ret is IDictionary<string, Operation>) | |||||
| { | |||||
| var dict = (IDictionary<string, Operation>)ret; | |||||
| restore_ops = restore_ops.Concat(dict).ToDictionary(x => x.Key, x => x.Value); | |||||
| Dictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> restored_tensors = new(); | |||||
| foreach (var input in restore_fn_inputs[restore_fn]) | |||||
| { | |||||
| restored_tensors[TrackableUtils.extract_local_name(input.Key)] = input.Value; | |||||
| } | |||||
| var ret = restore_fn.DynamicInvoke(restored_tensors); | |||||
| if (ret is IDictionary<string, Operation>) | |||||
| { | |||||
| var dict = (IDictionary<string, Operation>)ret; | |||||
| restore_ops = restore_ops.Concat(dict).ToDictionary(x => x.Key, x => x.Value); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| } | |||||
| }); | |||||
| } | } | ||||
| foreach(var item in _registered_savers) | foreach(var item in _registered_savers) | ||||
| @@ -111,6 +111,14 @@ namespace Tensorflow.Contexts | |||||
| return results.ToArray(); | return results.ToArray(); | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Reports whether <paramref name="device_name"/> refers to a custom device. | |||||
| /// This is currently a stub that always returns <c>false</c>. | |||||
| /// </summary> | |||||
| /// <param name="device_name">The device name to check.</param> | |||||
| /// <returns>Always <c>false</c> in the current implementation.</returns> | |||||
| public bool is_custom_device(string device_name) | |||||
| { | |||||
| // TODO(Rinne): After tf2.11 TFE_IsCustomDevice has been added to C APIs. | |||||
| // Intended implementation, kept disabled for now: | |||||
| //   ensure_initialized(); | |||||
| //   return c_api.TFE_IsCustomDevice(_handle, device_name); | |||||
| const bool is_custom = false; | |||||
| return is_custom; | |||||
| } | |||||
| public EagerDeviceContext device(string name) | public EagerDeviceContext device(string name) | ||||
| { | { | ||||
| return new EagerDeviceContext(this, name); | return new EagerDeviceContext(this, name); | ||||
| @@ -483,5 +483,8 @@ namespace Tensorflow | |||||
| IntPtr[] target, int target_size, | IntPtr[] target, int target_size, | ||||
| IntPtr[] sources, int source_size, | IntPtr[] sources, int source_size, | ||||
| IntPtr[] outputs, int output_size); | IntPtr[] outputs, int output_size); | ||||
| /// <summary> | |||||
| /// P/Invoke declaration for the native <c>TFE_IsCustomDevice</c> function exported by the TensorFlow library. | |||||
| /// </summary> | |||||
| /// <param name="ctx">Handle of the eager context to query.</param> | |||||
| /// <param name="device_name">Name of the device to check.</param> | |||||
| /// <returns>The boolean value returned by the native call.</returns> | |||||
| [DllImport(TensorFlowLibName)] | |||||
| // The native function returns a 1-byte C bool; without U1 the runtime marshals the return value as a 4-byte Win32 BOOL. | |||||
| [return: MarshalAs(UnmanagedType.U1)] | |||||
| public static extern bool TFE_IsCustomDevice(SafeContextHandle ctx, string device_name); | |||||
| } | } | ||||
| } | } | ||||
| @@ -46,14 +46,18 @@ namespace Tensorflow | |||||
| { | { | ||||
| return () => | return () => | ||||
| { | { | ||||
| tf.device(v.Device); | |||||
| if (tf.Context.executing_eagerly() && !((bool)v.is_initialized().numpy())) | |||||
| return tf_with(ops.device(v.Device), _ => | |||||
| { | { | ||||
| return null; | |||||
| } | |||||
| var x = v.read_value_no_copy(); | |||||
| tf.device("/device:CPU:0"); | |||||
| return array_ops.identity(x); | |||||
| if (tf.Context.executing_eagerly() && !((bool)v.is_initialized().numpy())) | |||||
| { | |||||
| return null; | |||||
| } | |||||
| var x = v.read_value_no_copy(); | |||||
| return tf_with(ops.device("/device:CPU:0"), _ => | |||||
| { | |||||
| return array_ops.identity(x); | |||||
| }); | |||||
| }); | |||||
| }; | }; | ||||
| } | } | ||||
| @@ -69,10 +73,12 @@ namespace Tensorflow | |||||
| public override Operation restore(Tensor[] restored_tensors, Shape[] restored_shapes = null) | public override Operation restore(Tensor[] restored_tensors, Shape[] restored_shapes = null) | ||||
| { | { | ||||
| var restored_tensor = restored_tensors[0]; | var restored_tensor = restored_tensors[0]; | ||||
| tf.device(_var_device); | |||||
| restored_tensor = array_ops.identity(restored_tensor); | |||||
| return resource_variable_ops.shape_safe_assign_variable_handle( | |||||
| return tf_with(ops.device(_var_device), _ => | |||||
| { | |||||
| restored_tensor = array_ops.identity(restored_tensor); | |||||
| return resource_variable_ops.shape_safe_assign_variable_handle( | |||||
| handle_op, _var_shape, restored_tensor); | handle_op, _var_shape, restored_tensor); | ||||
| }); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -20,6 +20,8 @@ using System.Collections.Generic; | |||||
| using System.Diagnostics; | using System.Diagnostics; | ||||
| using System.Linq; | using System.Linq; | ||||
| using Tensorflow.Checkpoint; | using Tensorflow.Checkpoint; | ||||
| using Tensorflow.Contexts; | |||||
| using Tensorflow.Device; | |||||
| using Tensorflow.Operations.Activation; | using Tensorflow.Operations.Activation; | ||||
| using Tensorflow.Train; | using Tensorflow.Train; | ||||
| using Tensorflow.Training; | using Tensorflow.Training; | ||||
| @@ -406,6 +408,17 @@ namespace Tensorflow | |||||
| return factory(key); | return factory(key); | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Builds a device string that targets CPU device 0, based on <paramref name="device_string"/>. | |||||
| /// When <c>tf.Context.is_custom_device</c> reports the device as custom, the input is returned unchanged. | |||||
| /// </summary> | |||||
| /// <param name="device_string">The device string to rewrite.</param> | |||||
| /// <returns> | |||||
| /// The original string for custom devices; otherwise the string form of the parsed | |||||
| /// device spec with its device type replaced by "CPU" and its device index by 0. | |||||
| /// </returns> | |||||
| public static string set_cpu0(string device_string) | |||||
| { | |||||
| if (!tf.Context.is_custom_device(device_string)) | |||||
| { | |||||
| // Parse the spec and override only the device type and index. | |||||
| var cpu_spec = DeviceSpec.from_string(device_string).replace(device_type: "CPU", device_index: 0); | |||||
| return cpu_spec.ToString(); | |||||
| } | |||||
| // Custom devices are passed through untouched. | |||||
| return device_string; | |||||
| } | |||||
| private static bool _tensor_comes_from_variable(object v) | private static bool _tensor_comes_from_variable(object v) | ||||
| { | { | ||||
| return v is Tensor tensor && _VARIABLE_OPS.Contains(tensor.op.type); | return v is Tensor tensor && _VARIABLE_OPS.Contains(tensor.op.type); | ||||
| @@ -124,16 +124,29 @@ namespace Tensorflow | |||||
| if (_in_graph_mode) | if (_in_graph_mode) | ||||
| { | { | ||||
| // TODO(Rinne): deal with initializer_op. | |||||
| //if(initial_value is not null) | |||||
| //{ | |||||
| // tf_with(ops.name_scope("Assign"), n => | |||||
| // { | |||||
| // tf_with(ops.device(handle.Device), _ => | |||||
| // { | |||||
| // }); | |||||
| // }); | |||||
| //} | |||||
| handle = state_ops.variable_op_v2(_initial_value.shape, _initial_value.dtype.as_base_dtype(), name: name); | handle = state_ops.variable_op_v2(_initial_value.shape, _initial_value.dtype.as_base_dtype(), name: name); | ||||
| initializer_op = gen_state_ops.assign(handle, _initial_value, true).op; | initializer_op = gen_state_ops.assign(handle, _initial_value, true).op; | ||||
| ops.colocate_with(initializer_op); | ops.colocate_with(initializer_op); | ||||
| tf.device(handle.Device); | |||||
| var value = gen_resource_variable_ops.read_variable_op(handle, dtype); | |||||
| resource_variable_ops._maybe_set_handle_data(dtype, handle, value); | |||||
| _graph_element = gen_array_ops.identity(handle, name = "read"); | |||||
| ops.add_to_collections<IVariableV1>(collections, this); | |||||
| _dtype = handle.dtype; | |||||
| tf_with(ops.device(handle.Device), _ => | |||||
| { | |||||
| var value = gen_resource_variable_ops.read_variable_op(handle, dtype); | |||||
| resource_variable_ops._maybe_set_handle_data(dtype, handle, value); | |||||
| _graph_element = gen_array_ops.identity(handle, name = "read"); | |||||
| ops.add_to_collections<IVariableV1>(collections, this); | |||||
| _dtype = handle.dtype; | |||||
| }); | |||||
| } | } | ||||
| else | else | ||||
| { | { | ||||
| @@ -149,9 +162,11 @@ namespace Tensorflow | |||||
| _graph_element = null; | _graph_element = null; | ||||
| if (!string.IsNullOrEmpty(caching_device)) | if (!string.IsNullOrEmpty(caching_device)) | ||||
| { | { | ||||
| tf.device(caching_device); | |||||
| var value = gen_resource_variable_ops.read_variable_op(handle, dtype); | |||||
| resource_variable_ops._maybe_set_handle_data(dtype, handle, value); | |||||
| tf_with(ops.device(caching_device), _ => | |||||
| { | |||||
| var value = gen_resource_variable_ops.read_variable_op(handle, dtype); | |||||
| resource_variable_ops._maybe_set_handle_data(dtype, handle, value); | |||||
| }); | |||||
| } | } | ||||
| _dtype = _initial_value.dtype.as_base_dtype(); | _dtype = _initial_value.dtype.as_base_dtype(); | ||||
| // initial_value = _in_graph_mode ? initial_value : null; | // initial_value = _in_graph_mode ? initial_value : null; | ||||