Browse Source

Fix the errors caused by branch merge.

tags/v0.100.5-BERT-load
Yaohui Liu 2 years ago
parent
commit
a59ebaeea4
No known key found for this signature in database GPG Key ID: E86D01E1809BD23E
6 changed files with 99 additions and 53 deletions
  1. +35
    -34
      src/TensorFlowNET.Core/Checkpoint/functional_saver.cs
  2. +8
    -0
      src/TensorFlowNET.Core/Contexts/Context.Device.cs
  3. +3
    -0
      src/TensorFlowNET.Core/Eager/c_api.eager.cs
  4. +16
    -10
      src/TensorFlowNET.Core/Training/Saving/ResourceVariableSaveable.cs
  5. +13
    -0
      src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs
  6. +24
    -9
      src/TensorFlowNET.Core/Variables/ResourceVariable.cs

+ 35
- 34
src/TensorFlowNET.Core/Checkpoint/functional_saver.cs View File

@@ -208,7 +208,6 @@ namespace Tensorflow.Checkpoint
_keys_to_restore_fn[(checkpoint_key, slice_spec)] = restore_fn;
_restore_fn_to_keys.SetDefault(restore_fn, new List<(string, string)>()).Add((checkpoint_key, slice_spec));

// skip the process of device name because lack of API.
string host_device;
if (tensor.IsT0)
{
@@ -218,6 +217,7 @@ namespace Tensorflow.Checkpoint
{
host_device = tensor.AsT1.device;
}
host_device = saveable_object_util.set_cpu0(host_device);
var internal_dict = tensors_by_device.SetDefault(host_device, new Dictionary<string, IDictionary<string, OneOf<Tensor, SaveSpec>>>());
if (!internal_dict.ContainsKey(checkpoint_key))
{
@@ -329,51 +329,52 @@ namespace Tensorflow.Checkpoint
{
var restored_tensor_dict = saver.restore(file_prefix, options);

foreach(var pair in restored_tensor_dict)
{
var checkpoint_key = pair.Key;
var slice_and_tensor = pair.Value;
foreach(var item in slice_and_tensor)
foreach (var pair in restored_tensor_dict)
{
var slice_spec = item.Key;
var tensor = item.Value;
var restore_fn = _keys_to_restore_fn[(checkpoint_key, slice_spec)];
var internal_dict = restore_fn_inputs.SetDefault(restore_fn, new Dictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>());
if (!string.IsNullOrEmpty(slice_spec))
var checkpoint_key = pair.Key;
var slice_and_tensor = pair.Value;
foreach (var item in slice_and_tensor)
{
if (!internal_dict.ContainsKey(checkpoint_key))
var slice_spec = item.Key;
var tensor = item.Value;
var restore_fn = _keys_to_restore_fn[(checkpoint_key, slice_spec)];
var internal_dict = restore_fn_inputs.SetDefault(restore_fn, new Dictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>>());
if (!string.IsNullOrEmpty(slice_spec))
{
Dictionary<string, Tensor> dict = new();
dict[slice_spec] = tensor;
internal_dict[checkpoint_key] = OneOf<Tensor, IDictionary<string, Tensor>>.FromT1(dict);
if (!internal_dict.ContainsKey(checkpoint_key))
{
Dictionary<string, Tensor> dict = new();
dict[slice_spec] = tensor;
internal_dict[checkpoint_key] = OneOf<Tensor, IDictionary<string, Tensor>>.FromT1(dict);
}
else
{
internal_dict[checkpoint_key].AsT1[slice_spec] = tensor;
}
}
else
{
internal_dict[checkpoint_key].AsT1[slice_spec] = tensor;
internal_dict[checkpoint_key] = OneOf<Tensor, IDictionary<string, Tensor>>.FromT0(tensor);
}
}
else
{
internal_dict[checkpoint_key] = OneOf<Tensor, IDictionary<string, Tensor>>.FromT0(tensor);
}
restore_fn_input_count[restore_fn]--;
restore_fn_input_count[restore_fn]--;

if (restore_fn_input_count[restore_fn] == 0)
{
Dictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> restored_tensors = new();
foreach(var input in restore_fn_inputs[restore_fn])
if (restore_fn_input_count[restore_fn] == 0)
{
restored_tensors[TrackableUtils.extract_local_name(input.Key)] = input.Value;
}
var ret = restore_fn.DynamicInvoke(restored_tensors);
if(ret is IDictionary<string, Operation>)
{
var dict = (IDictionary<string, Operation>)ret;
restore_ops = restore_ops.Concat(dict).ToDictionary(x => x.Key, x => x.Value);
Dictionary<string, OneOf<Tensor, IDictionary<string, Tensor>>> restored_tensors = new();
foreach (var input in restore_fn_inputs[restore_fn])
{
restored_tensors[TrackableUtils.extract_local_name(input.Key)] = input.Value;
}
var ret = restore_fn.DynamicInvoke(restored_tensors);
if (ret is IDictionary<string, Operation>)
{
var dict = (IDictionary<string, Operation>)ret;
restore_ops = restore_ops.Concat(dict).ToDictionary(x => x.Key, x => x.Value);
}
}
}
}
}
});
}

foreach(var item in _registered_savers)


+ 8
- 0
src/TensorFlowNET.Core/Contexts/Context.Device.cs View File

@@ -111,6 +111,14 @@ namespace Tensorflow.Contexts
return results.ToArray();
}

public bool is_custom_device(string device_name)
{
return false;
// TODO(Rinne): After tf2.11 TFE_IsCustomDevice has been added to C APIs.
//ensure_initialized();
//return c_api.TFE_IsCustomDevice(_handle, device_name);
}

public EagerDeviceContext device(string name)
{
return new EagerDeviceContext(this, name);


+ 3
- 0
src/TensorFlowNET.Core/Eager/c_api.eager.cs View File

@@ -483,5 +483,8 @@ namespace Tensorflow
IntPtr[] target, int target_size,
IntPtr[] sources, int source_size,
IntPtr[] outputs, int output_size);

[DllImport(TensorFlowLibName)]
public static extern bool TFE_IsCustomDevice(SafeContextHandle ctx, string device_name);
}
}

+ 16
- 10
src/TensorFlowNET.Core/Training/Saving/ResourceVariableSaveable.cs View File

@@ -46,14 +46,18 @@ namespace Tensorflow
{
return () =>
{
tf.device(v.Device);
if (tf.Context.executing_eagerly() && !((bool)v.is_initialized().numpy()))
return tf_with(ops.device(v.Device), _ =>
{
return null;
}
var x = v.read_value_no_copy();
tf.device("/device:CPU:0");
return array_ops.identity(x);
if (tf.Context.executing_eagerly() && !((bool)v.is_initialized().numpy()))
{
return null;
}
var x = v.read_value_no_copy();
return tf_with(ops.device("/device:CPU:0"), _ =>
{
return array_ops.identity(x);
});
});
};
}

@@ -69,10 +73,12 @@ namespace Tensorflow
public override Operation restore(Tensor[] restored_tensors, Shape[] restored_shapes = null)
{
var restored_tensor = restored_tensors[0];
tf.device(_var_device);
restored_tensor = array_ops.identity(restored_tensor);
return resource_variable_ops.shape_safe_assign_variable_handle(
return tf_with(ops.device(_var_device), _ =>
{
restored_tensor = array_ops.identity(restored_tensor);
return resource_variable_ops.shape_safe_assign_variable_handle(
handle_op, _var_shape, restored_tensor);
});
}
}
}

+ 13
- 0
src/TensorFlowNET.Core/Training/Saving/saveable_object_util.py.cs View File

@@ -20,6 +20,8 @@ using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using Tensorflow.Checkpoint;
using Tensorflow.Contexts;
using Tensorflow.Device;
using Tensorflow.Operations.Activation;
using Tensorflow.Train;
using Tensorflow.Training;
@@ -406,6 +408,17 @@ namespace Tensorflow
return factory(key);
}

public static string set_cpu0(string device_string)
{
if (tf.Context.is_custom_device(device_string))
{
return device_string;
}
var parsed_device = DeviceSpec.from_string(device_string);
parsed_device = parsed_device.replace(device_type: "CPU", device_index: 0);
return parsed_device.ToString();
}

private static bool _tensor_comes_from_variable(object v)
{
return v is Tensor tensor && _VARIABLE_OPS.Contains(tensor.op.type);


+ 24
- 9
src/TensorFlowNET.Core/Variables/ResourceVariable.cs View File

@@ -124,16 +124,29 @@ namespace Tensorflow

if (_in_graph_mode)
{
// TODO(Rinne): deal with initializer_op.
//if(initial_value is not null)
//{
// tf_with(ops.name_scope("Assign"), n =>
// {
// tf_with(ops.device(handle.Device), _ =>
// {

// });
// });
//}
handle = state_ops.variable_op_v2(_initial_value.shape, _initial_value.dtype.as_base_dtype(), name: name);
initializer_op = gen_state_ops.assign(handle, _initial_value, true).op;

ops.colocate_with(initializer_op);
tf.device(handle.Device);
var value = gen_resource_variable_ops.read_variable_op(handle, dtype);
resource_variable_ops._maybe_set_handle_data(dtype, handle, value);
_graph_element = gen_array_ops.identity(handle, name = "read");
ops.add_to_collections<IVariableV1>(collections, this);
_dtype = handle.dtype;
tf_with(ops.device(handle.Device), _ =>
{
var value = gen_resource_variable_ops.read_variable_op(handle, dtype);
resource_variable_ops._maybe_set_handle_data(dtype, handle, value);
_graph_element = gen_array_ops.identity(handle, name = "read");
ops.add_to_collections<IVariableV1>(collections, this);
_dtype = handle.dtype;
});
}
else
{
@@ -149,9 +162,11 @@ namespace Tensorflow
_graph_element = null;
if (!string.IsNullOrEmpty(caching_device))
{
tf.device(caching_device);
var value = gen_resource_variable_ops.read_variable_op(handle, dtype);
resource_variable_ops._maybe_set_handle_data(dtype, handle, value);
tf_with(ops.device(caching_device), _ =>
{
var value = gen_resource_variable_ops.read_variable_op(handle, dtype);
resource_variable_ops._maybe_set_handle_data(dtype, handle, value);
});
}
_dtype = _initial_value.dtype.as_base_dtype();
// initial_value = _in_graph_mode ? initial_value : null;


Loading…
Cancel
Save