| @@ -208,7 +208,6 @@ namespace Tensorflow.Checkpoint | |||
| _keys_to_restore_fn[(checkpoint_key, slice_spec)] = restore_fn; | |||
| _restore_fn_to_keys.SetDefault(restore_fn, new List<(string, string)>()).Add((checkpoint_key, slice_spec)); | |||
| // skip the process of device name because lack of API. | |||
| string host_device; | |||
| if (tensor.IsT0) | |||
| { | |||
| @@ -218,6 +217,7 @@ namespace Tensorflow.Checkpoint | |||
| { | |||
| host_device = tensor.AsT1.device; | |||
| } | |||
| host_device = saveable_object_util.set_cpu0(host_device); | |||
| var internal_dict = tensors_by_device.SetDefault(host_device, new Dictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>()); | |||
| if (!internal_dict.ContainsKey(checkpoint_key)) | |||
| { | |||
| @@ -329,51 +329,52 @@ namespace Tensorflow.Checkpoint | |||
| { | |||
| var restored_tensor_dict = saver.restore(file_prefix, options); | |||
| foreach(var pair in restored_tensor_dict) | |||
| { | |||
| var checkpoint_key = pair.Key; | |||
| var slice_and_tensor = pair.Value; | |||
| foreach(var item in slice_and_tensor) | |||
| foreach (var pair in restored_tensor_dict) | |||
| { | |||
| var slice_spec = item.Key; | |||
| var tensor = item.Value; | |||
| var restore_fn = _keys_to_restore_fn[(checkpoint_key, slice_spec)]; | |||
| var internal_dict = restore_fn_inputs.SetDefault(restore_fn, new Dictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>()); | |||
| if (!string.IsNullOrEmpty(slice_spec)) | |||
| var checkpoint_key = pair.Key; | |||
| var slice_and_tensor = pair.Value; | |||
| foreach (var item in slice_and_tensor) | |||
| { | |||
| if (!internal_dict.ContainsKey(checkpoint_key)) | |||
| var slice_spec = item.Key; | |||
| var tensor = item.Value; | |||
| var restore_fn = _keys_to_restore_fn[(checkpoint_key, slice_spec)]; | |||
| var internal_dict = restore_fn_inputs.SetDefault(restore_fn, new Dictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>()); | |||
| if (!string.IsNullOrEmpty(slice_spec)) | |||
| { | |||
| Dictionary<string, Tensor> dict = new(); | |||
| dict[slice_spec] = tensor; | |||
| internal_dict[checkpoint_key] = OneOf<Tensor, IDictionary<string, Tensor>>.FromT1(dict); | |||
| if (!internal_dict.ContainsKey(checkpoint_key)) | |||
| { | |||
| Dictionary<string, Tensor> dict = new(); | |||
| dict[slice_spec] = tensor; | |||
| internal_dict[checkpoint_key] = OneOf<Tensor, IDictionary<string, Tensor>>.FromT1(dict); | |||
| } | |||
| else | |||
| { | |||
| internal_dict[checkpoint_key].AsT1[slice_spec] = tensor; | |||
| } | |||
| } | |||
| else | |||
| { | |||
| internal_dict[checkpoint_key].AsT1[slice_spec] = tensor; | |||
| internal_dict[checkpoint_key] = OneOf<Tensor, IDictionary<string, Tensor>>.FromT0(tensor); | |||
| } | |||
| } | |||
| else | |||
| { | |||
| internal_dict[checkpoint_key] = OneOf<Tensor, IDictionary<string, Tensor>>.FromT0(tensor); | |||
| } | |||
| restore_fn_input_count[restore_fn]--; | |||
| restore_fn_input_count[restore_fn]--; | |||
| if (restore_fn_input_count[restore_fn] == 0) | |||
| { | |||
| Dictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> restored_tensors = new(); | |||
| foreach(var input in restore_fn_inputs[restore_fn]) | |||
| if (restore_fn_input_count[restore_fn] == 0) | |||
| { | |||
| restored_tensors[TrackableUtils.extract_local_name(input.Key)] = input.Value; | |||
| } | |||
| var ret = restore_fn.DynamicInvoke(restored_tensors); | |||
| if(ret is IDictionary<string, Operation>) | |||
| { | |||
| var dict = (IDictionary<string, Operation>)ret; | |||
| restore_ops = restore_ops.Concat(dict).ToDictionary(x => x.Key, x => x.Value); | |||
| Dictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> restored_tensors = new(); | |||
| foreach (var input in restore_fn_inputs[restore_fn]) | |||
| { | |||
| restored_tensors[TrackableUtils.extract_local_name(input.Key)] = input.Value; | |||
| } | |||
| var ret = restore_fn.DynamicInvoke(restored_tensors); | |||
| if (ret is IDictionary<string, Operation>) | |||
| { | |||
| var dict = (IDictionary<string, Operation>)ret; | |||
| restore_ops = restore_ops.Concat(dict).ToDictionary(x => x.Key, x => x.Value); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| }); | |||
| } | |||
| foreach(var item in _registered_savers) | |||
| @@ -111,6 +111,14 @@ namespace Tensorflow.Contexts | |||
| return results.ToArray(); | |||
| } | |||
| public bool is_custom_device(string device_name) | |||
| { | |||
| return false; | |||
| // TODO(Rinne): After tf2.11 TFE_IsCustomDevice has been added to C APIs. | |||
| //ensure_initialized(); | |||
| //return c_api.TFE_IsCustomDevice(_handle, device_name); | |||
| } | |||
| public EagerDeviceContext device(string name) | |||
| { | |||
| return new EagerDeviceContext(this, name); | |||
| @@ -483,5 +483,8 @@ namespace Tensorflow | |||
| IntPtr[] target, int target_size, | |||
| IntPtr[] sources, int source_size, | |||
| IntPtr[] outputs, int output_size); | |||
| [DllImport(TensorFlowLibName)] | |||
| public static extern bool TFE_IsCustomDevice(SafeContextHandle ctx, string device_name); | |||
| } | |||
| } | |||
| @@ -46,14 +46,18 @@ namespace Tensorflow | |||
| { | |||
| return () => | |||
| { | |||
| tf.device(v.Device); | |||
| if (tf.Context.executing_eagerly() && !((bool)v.is_initialized().numpy())) | |||
| return tf_with(ops.device(v.Device), _ => | |||
| { | |||
| return null; | |||
| } | |||
| var x = v.read_value_no_copy(); | |||
| tf.device("/device:CPU:0"); | |||
| return array_ops.identity(x); | |||
| if (tf.Context.executing_eagerly() && !((bool)v.is_initialized().numpy())) | |||
| { | |||
| return null; | |||
| } | |||
| var x = v.read_value_no_copy(); | |||
| return tf_with(ops.device("/device:CPU:0"), _ => | |||
| { | |||
| return array_ops.identity(x); | |||
| }); | |||
| }); | |||
| }; | |||
| } | |||
| @@ -69,10 +73,12 @@ namespace Tensorflow | |||
| public override Operation restore(Tensor[] restored_tensors, Shape[] restored_shapes = null) | |||
| { | |||
| var restored_tensor = restored_tensors[0]; | |||
| tf.device(_var_device); | |||
| restored_tensor = array_ops.identity(restored_tensor); | |||
| return resource_variable_ops.shape_safe_assign_variable_handle( | |||
| return tf_with(ops.device(_var_device), _ => | |||
| { | |||
| restored_tensor = array_ops.identity(restored_tensor); | |||
| return resource_variable_ops.shape_safe_assign_variable_handle( | |||
| handle_op, _var_shape, restored_tensor); | |||
| }); | |||
| } | |||
| } | |||
| } | |||
| @@ -20,6 +20,8 @@ using System.Collections.Generic; | |||
| using System.Diagnostics; | |||
| using System.Linq; | |||
| using Tensorflow.Checkpoint; | |||
| using Tensorflow.Contexts; | |||
| using Tensorflow.Device; | |||
| using Tensorflow.Operations.Activation; | |||
| using Tensorflow.Train; | |||
| using Tensorflow.Training; | |||
| @@ -406,6 +408,17 @@ namespace Tensorflow | |||
| return factory(key); | |||
| } | |||
| public static string set_cpu0(string device_string) | |||
| { | |||
| if (tf.Context.is_custom_device(device_string)) | |||
| { | |||
| return device_string; | |||
| } | |||
| var parsed_device = DeviceSpec.from_string(device_string); | |||
| parsed_device = parsed_device.replace(device_type: "CPU", device_index: 0); | |||
| return parsed_device.ToString(); | |||
| } | |||
| private static bool _tensor_comes_from_variable(object v) | |||
| { | |||
| return v is Tensor tensor && _VARIABLE_OPS.Contains(tensor.op.type); | |||
| @@ -124,16 +124,29 @@ namespace Tensorflow | |||
| if (_in_graph_mode) | |||
| { | |||
| // TODO(Rinne): deal with initializer_op. | |||
| //if(initial_value is not null) | |||
| //{ | |||
| // tf_with(ops.name_scope("Assign"), n => | |||
| // { | |||
| // tf_with(ops.device(handle.Device), _ => | |||
| // { | |||
| // }); | |||
| // }); | |||
| //} | |||
| handle = state_ops.variable_op_v2(_initial_value.shape, _initial_value.dtype.as_base_dtype(), name: name); | |||
| initializer_op = gen_state_ops.assign(handle, _initial_value, true).op; | |||
| ops.colocate_with(initializer_op); | |||
| tf.device(handle.Device); | |||
| var value = gen_resource_variable_ops.read_variable_op(handle, dtype); | |||
| resource_variable_ops._maybe_set_handle_data(dtype, handle, value); | |||
| _graph_element = gen_array_ops.identity(handle, name = "read"); | |||
| ops.add_to_collections<IVariableV1>(collections, this); | |||
| _dtype = handle.dtype; | |||
| tf_with(ops.device(handle.Device), _ => | |||
| { | |||
| var value = gen_resource_variable_ops.read_variable_op(handle, dtype); | |||
| resource_variable_ops._maybe_set_handle_data(dtype, handle, value); | |||
| _graph_element = gen_array_ops.identity(handle, name = "read"); | |||
| ops.add_to_collections<IVariableV1>(collections, this); | |||
| _dtype = handle.dtype; | |||
| }); | |||
| } | |||
| else | |||
| { | |||
| @@ -149,9 +162,11 @@ namespace Tensorflow | |||
| _graph_element = null; | |||
| if (!string.IsNullOrEmpty(caching_device)) | |||
| { | |||
| tf.device(caching_device); | |||
| var value = gen_resource_variable_ops.read_variable_op(handle, dtype); | |||
| resource_variable_ops._maybe_set_handle_data(dtype, handle, value); | |||
| tf_with(ops.device(caching_device), _ => | |||
| { | |||
| var value = gen_resource_variable_ops.read_variable_op(handle, dtype); | |||
| resource_variable_ops._maybe_set_handle_data(dtype, handle, value); | |||
| }); | |||
| } | |||
| _dtype = _initial_value.dtype.as_base_dtype(); | |||
| // initial_value = _in_graph_mode ? initial_value : null; | |||