| @@ -0,0 +1,20 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Train; | |||
| namespace Tensorflow | |||
| { | |||
| public partial class tensorflow | |||
| { | |||
| public SavedModelAPI saved_model { get; } = new SavedModelAPI(); | |||
| } | |||
| public class SavedModelAPI | |||
| { | |||
| public Trackable load(string export_dir, LoadOptions? options = null) | |||
| { | |||
| return Loader.load(export_dir, options); | |||
| } | |||
| } | |||
| } | |||
| @@ -8,6 +8,7 @@ using Tensorflow.Exceptions; | |||
| using Tensorflow.Framework; | |||
| using Tensorflow.Framework.Models; | |||
| using Tensorflow.Functions; | |||
| using Tensorflow.NumPy; | |||
| using Tensorflow.Operations; | |||
| using Tensorflow.Util; | |||
| using static Tensorflow.Binding; | |||
| @@ -181,7 +182,7 @@ public class FuncGraph : Graph, IDisposable | |||
| const int _EAGER_CONST_THRESHOLD = 128; | |||
| public Tensor capture(Tensor tensor, string name = null, Shape shape = null) | |||
| { | |||
| if(tensor is EagerTensor) | |||
| if(tensor is EagerTensor or NDArray) | |||
| { | |||
| if (name == null) | |||
| name = ops.uid().ToString(); | |||
| @@ -10,4 +10,5 @@ public interface IOptimizer | |||
| void apply_gradients(IEnumerable<(Tensor, IVariableV1)> grads_and_vars, | |||
| string name = null, | |||
| bool experimental_aggregate_gradients = true); | |||
| IVariableV1 add_slot(IVariableV1 var, string slot_name, IInitializer initializer = null); | |||
| } | |||
| @@ -216,10 +216,12 @@ namespace Tensorflow | |||
| public virtual object get_attr(string name) | |||
| { | |||
| var buf = new Buffer(); | |||
| c_api.TF_OperationGetAttrValueProto(_handle, name, buf, tf.Status); | |||
| tf.Status.Check(true); | |||
| Status status = new(); | |||
| c_api.TF_OperationGetAttrValueProto(_handle, name, buf, status); | |||
| status.Check(true); | |||
| var tf_buffer = c_api.TF_GetBuffer(buf); | |||
| var x = AttrValue.Parser.ParseFrom(buf.ToArray()); | |||
| var x = AttrValue.Parser.ParseFrom(tf_buffer.AsSpan<byte>()); | |||
| var oneof_value = x.ValueCase; | |||
| if (oneof_value == AttrValue.ValueOneofCase.None) | |||
| @@ -64,36 +64,68 @@ namespace Tensorflow | |||
| var num_elements = shape.size; | |||
| var tensor_dtype = tensor.Dtype.as_tf_dtype(); | |||
| T[] ExpandArrayToSize<T>(IList<T> src) | |||
| { | |||
| if(src.Count == 0) | |||
| { | |||
| return new T[0]; | |||
| } | |||
| var pad_count = num_elements - src.Count; | |||
| var pre = pad_count / 2; | |||
| var after = pad_count - pre; | |||
| var first_elem = src[0]; | |||
| var last_elem = src[src.Count - 1]; | |||
| T[] res = new T[num_elements]; | |||
| for(long i = 0; i < num_elements; i++) | |||
| { | |||
| if (i < pre) res[i] = first_elem; | |||
| else if (i >= num_elements - after) res[i] = last_elem; | |||
| else res[i] = src[(int)(i - pre)]; | |||
| } | |||
| return res; | |||
| } | |||
| if (shape.ndim > 0 && tensor.TensorContent.Length > 0) | |||
| { | |||
| return np.frombuffer(tensor.TensorContent.ToByteArray(), shape, tensor_dtype); | |||
| } | |||
| else if (tensor.Dtype == DataType.DtHalf || tensor.Dtype == DataType.DtBfloat16) | |||
| NDArray values; | |||
| if (tensor.Dtype == DataType.DtHalf || tensor.Dtype == DataType.DtBfloat16) | |||
| { | |||
| return np.array(tensor.HalfVal.ToArray()).reshape(shape); | |||
| values = np.array(ExpandArrayToSize(tensor.HalfVal)); | |||
| } | |||
| else if (tensor.Dtype == DataType.DtFloat) | |||
| { | |||
| return np.array(tensor.FloatVal.ToArray()).reshape(shape); | |||
| values = np.array(ExpandArrayToSize(tensor.FloatVal)); | |||
| } | |||
| else if (new DataType[] { DataType.DtInt32, DataType.DtUint8 }.Contains(tensor.Dtype)) | |||
| { | |||
| return np.array(tensor.IntVal.ToArray()).reshape(shape); | |||
| values = np.array(ExpandArrayToSize(tensor.IntVal)); | |||
| } | |||
| else if (new DataType[] { DataType.DtInt64 }.Contains(tensor.Dtype)) | |||
| { | |||
| return np.array(tensor.Int64Val.ToArray()).reshape(shape); | |||
| values = np.array(ExpandArrayToSize(tensor.Int64Val)); | |||
| } | |||
| else if (new DataType[] { DataType.DtUint64 }.Contains(tensor.Dtype)) | |||
| { | |||
| return np.array(tensor.Uint64Val.ToArray()).reshape(shape); | |||
| values = np.array(ExpandArrayToSize(tensor.Uint64Val)); | |||
| } | |||
| else if (tensor.Dtype == DataType.DtBool) | |||
| { | |||
| return np.array(tensor.BoolVal.ToArray()).reshape(shape); | |||
| values = np.array(ExpandArrayToSize(tensor.BoolVal)); | |||
| } | |||
| else | |||
| { | |||
| throw new TypeError($"Unsupported tensor type: {tensor.Dtype}. See " + | |||
| $"https://www.tensorflow.org/api_docs/python/tf/dtypes for supported TF dtypes."); | |||
| } | |||
| if(values.size == 0) | |||
| { | |||
| return np.zeros(shape, tensor_dtype); | |||
| } | |||
| throw new NotImplementedException("MakeNdarray"); | |||
| return values.reshape(shape); | |||
| } | |||
| private static readonly TF_DataType[] quantized_types = new TF_DataType[] | |||
| @@ -1,5 +1,6 @@ | |||
| using Google.Protobuf.Collections; | |||
| using Tensorflow.Train; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Trackables; | |||
| @@ -11,12 +12,23 @@ public class TrackableConstant : Trackable | |||
| _constant = constant; | |||
| } | |||
| public static (Trackable, Action<object, object, object>) deserialize_from_proto(SavedObject object_proto, | |||
| public static (Tensor, Action<object, object, object>) deserialize_from_proto(SavedObject object_proto, | |||
| Dictionary<string, MapField<string, AttrValue>> operation_attributes) | |||
| { | |||
| var tensor_proto = operation_attributes[object_proto.Constant.Operation]["value"].Tensor; | |||
| var ndarray = tensor_util.MakeNdarray(tensor_proto); | |||
| var imported_constant = constant_op.constant(ndarray); | |||
| return (new TrackableConstant(imported_constant), null); | |||
| Tensor imported_constant; | |||
| if (tensor_proto.Dtype == DataType.DtString) | |||
| { | |||
| imported_constant = tf_with(ops.device("CPU"), _ => | |||
| { | |||
| return constant_op.constant(ndarray); | |||
| }); | |||
| } | |||
| else | |||
| { | |||
| imported_constant = constant_op.constant(ndarray); | |||
| } | |||
| return (imported_constant, null); | |||
| } | |||
| } | |||
| @@ -46,4 +46,9 @@ public class RevivedTypes | |||
| return (null, null); | |||
| } | |||
| } | |||
| public static void RegisterRevivedTypeCreator(string identifier, ITrackableWrapper obj) | |||
| { | |||
| _registered_revived_creator[identifier] = obj; | |||
| } | |||
| } | |||
| @@ -137,7 +137,7 @@ public class SaveableView | |||
| /// </summary> | |||
| public List<int> dependency_sorted_node_ids() | |||
| { | |||
| Dictionary<int, IEnumerable<int>> dependency_map = new(); | |||
| Dictionary<int, List<int>> dependency_map = new(); | |||
| foreach (var node in _nodes) | |||
| { | |||
| var node_id = _node_ids[node]; | |||
| @@ -116,17 +116,23 @@ namespace Tensorflow.Training.Saving.SavedModel | |||
| } | |||
| Dictionary<string, ConcreteFunction> loaded_gradients = new(); | |||
| foreach (var fdef in _sort_function_defs(library, function_deps)) | |||
| // Debug(Rinne) | |||
| var temp = _sort_function_defs(library, function_deps); | |||
| int i = 0; | |||
| foreach (var fdef in temp) | |||
| { | |||
| i++; | |||
| var orig_name = _fix_fdef_in_place(fdef, functions, load_shared_name_suffix, new_gradient_op_types); | |||
| object structured_input_signature = null; | |||
| object structured_outputs = null; | |||
| if (saved_object_graph is not null && saved_object_graph.ConcreteFunctions.ContainsKey(orig_name)) | |||
| { | |||
| var proto = saved_object_graph.ConcreteFunctions[orig_name]; | |||
| structured_input_signature = nested_structure_coder.decode_proto(proto.CanonicalizedInputSignature); | |||
| structured_outputs = nested_structure_coder.decode_proto(proto.OutputSignature); | |||
| // TODO(Rinne): deal with structured_input_signature and structured_outputs. | |||
| //var proto = saved_object_graph.ConcreteFunctions[orig_name]; | |||
| //structured_input_signature = nested_structure_coder.decode_proto(proto.CanonicalizedInputSignature); | |||
| //structured_outputs = nested_structure_coder.decode_proto(proto.OutputSignature); | |||
| } | |||
| graph.as_default(); | |||
| @@ -234,27 +240,41 @@ namespace Tensorflow.Training.Saving.SavedModel | |||
| private static void _restore_gradient_functions(FuncGraph func_graph, Dictionary<string, ConcreteFunction> renamed_functions, Dictionary<string, ConcreteFunction> loaded_gradients) | |||
| { | |||
| foreach(var op in func_graph.get_operations()) | |||
| if(loaded_gradients is null || loaded_gradients.Count == 0) | |||
| { | |||
| if(op.op.type == "StatefulPartitionedCall" || op.op.type == "PartitionedCall") | |||
| { | |||
| var function = renamed_functions[op.op.node_def.Attr["f"].Func.Name]; | |||
| op.op._gradient_function = function._get_gradient_function(); | |||
| } | |||
| string gradient_op_type = null; | |||
| try | |||
| { | |||
| gradient_op_type = op.op.get_attr("_gradient_op_type") as string; | |||
| } | |||
| catch(InvalidArgumentError) | |||
| foreach (var op in func_graph.get_operations()) | |||
| { | |||
| continue; | |||
| if (op.op.type == "StatefulPartitionedCall" || op.op.type == "PartitionedCall") | |||
| { | |||
| var function = renamed_functions[op.op.node_def.Attr["f"].Func.Name]; | |||
| op.op._gradient_function = function._get_gradient_function(); | |||
| } | |||
| } | |||
| if (loaded_gradients.ContainsKey(gradient_op_type)) | |||
| } | |||
| else | |||
| { | |||
| foreach (var op in func_graph.get_operations()) | |||
| { | |||
| var grad_fn = loaded_gradients[gradient_op_type]; | |||
| grad_fn.NumPositionArgs = op.op.inputs.Length; | |||
| grad_fn.ArgKeywords = op.op.inputs._inputs.Select(x => x.name); | |||
| if (op.op.type == "StatefulPartitionedCall" || op.op.type == "PartitionedCall") | |||
| { | |||
| var function = renamed_functions[op.op.node_def.Attr["f"].Func.Name]; | |||
| op.op._gradient_function = function._get_gradient_function(); | |||
| } | |||
| string gradient_op_type = null; | |||
| try | |||
| { | |||
| gradient_op_type = op.op.get_attr("_gradient_op_type") as string; | |||
| } | |||
| catch (InvalidArgumentError) | |||
| { | |||
| continue; | |||
| } | |||
| if (loaded_gradients.ContainsKey(gradient_op_type)) | |||
| { | |||
| var grad_fn = loaded_gradients[gradient_op_type]; | |||
| grad_fn.NumPositionArgs = op.op.inputs.Length; | |||
| grad_fn.ArgKeywords = op.op.inputs._inputs.Select(x => x.name); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -15,6 +15,7 @@ using Tensorflow.Functions; | |||
| using Tensorflow.Training.Saving.SavedModel; | |||
| using Tensorflow.Trackables; | |||
| using OneOf; | |||
| using Tensorflow.Keras.Engine; | |||
| namespace Tensorflow | |||
| { | |||
| @@ -34,7 +35,7 @@ namespace Tensorflow | |||
| private List<int>? _filtered_nodes; | |||
| private List<int> _ordered_node_ids; | |||
| private Dictionary<int, (Trackable, Action<object, object, object>)> _loaded_nodes; | |||
| private List<Trackable> _nodes; | |||
| private List<object> _nodes; | |||
| private Dictionary<int, Action<object, object, object>> _node_setters; | |||
| private Dictionary<string, ConcreteFunction> _concrete_functions; | |||
| private HashSet<string> _restored_concrete_functions; | |||
| @@ -213,7 +214,13 @@ namespace Tensorflow | |||
| continue; | |||
| } | |||
| var proto = _proto.Nodes[node_id]; | |||
| foreach(var dep in _get_node_dependencies(proto).Values.Distinct()) | |||
| if(node_id == 10522) | |||
| { | |||
| // Debug(Rinne) | |||
| Console.WriteLine(); | |||
| } | |||
| var temp = _get_node_dependencies(proto); | |||
| foreach (var dep in _get_node_dependencies(proto).Values.Distinct()) | |||
| { | |||
| deps.Add(dep); | |||
| if(_filtered_nodes is not null && !_filtered_nodes.Contains(dep)) | |||
| @@ -232,7 +239,7 @@ namespace Tensorflow | |||
| // The optimizer and original variable must be created before the slot | |||
| // variable, since the slot variable is generated using the Optimizer's | |||
| // add_slot API. | |||
| var slot_deps = dependency_map[slot_variable_node_id]; | |||
| var slot_deps = dependency_map.SetDefault(slot_variable_node_id, new List<int>()); | |||
| slot_deps.Add(node_id); | |||
| slot_deps.Add(slot_variable_proto.OriginalVariableNodeId); | |||
| @@ -245,7 +252,12 @@ namespace Tensorflow | |||
| } | |||
| try | |||
| { | |||
| return TrackableUtils.order_by_dependency(dependency_map.ToDictionary(x => x.Key, x => x.Value as IEnumerable<int>)); | |||
| int total = 0; | |||
| foreach(var v in dependency_map.Values) | |||
| { | |||
| total += v.Count; | |||
| } | |||
| return TrackableUtils.order_by_dependency(dependency_map); | |||
| } | |||
| catch (TrackableUtils.CyclicDependencyError ex) | |||
| { | |||
| @@ -339,9 +351,20 @@ namespace Tensorflow | |||
| var saveable_object_proto = item.Value; | |||
| var save_fn_id = saveable_object_proto.SaveFunction; | |||
| var restore_fn_id = saveable_object_proto.RestoreFunction; | |||
| saveable_fn_by_name[name] = (get(save_fn_id), get(restore_fn_id)); | |||
| saveable_fn_by_name[name] = ((Trackable)get(save_fn_id), (Trackable)get(restore_fn_id)); | |||
| } | |||
| var saveable_objects = saveable_object_util.recreate_saveable_objects(saveable_fn_by_name, null); | |||
| if (saveable_objects is not null && saveable_objects.Count > 0) | |||
| { | |||
| if(node is Trackable trackable) | |||
| { | |||
| trackable.SelfSaveableObjectFactories = saveable_objects; | |||
| } | |||
| else | |||
| { | |||
| throw new TypeError(); | |||
| } | |||
| } | |||
| node.SelfSaveableObjectFactories = saveable_object_util.recreate_saveable_objects(saveable_fn_by_name, null); | |||
| } | |||
| } | |||
| } | |||
| @@ -379,12 +402,12 @@ namespace Tensorflow | |||
| { | |||
| // Use the public Optimizer interface when creating slot variables. | |||
| var (optimizer_node_id, slot_variable_proto) = slot_variable_node_ids[node_id]; | |||
| var optimizer_object = nodes[optimizer_node_id]; | |||
| var optimizer_object = nodes[optimizer_node_id] as IOptimizer; | |||
| var optimizer_variable = nodes[slot_variable_proto.OriginalVariableNodeId]; | |||
| // TODO(Rinne): implement it. | |||
| throw new NotImplementedException("The model loading of SavedModel still has some incompleted part." + | |||
| " Please submit an issue to https://github.com/SciSharp/TensorFlow.NET/issues."); | |||
| var slot_variable = optimizer_object.add_slot(optimizer_variable as IVariableV1, slot_variable_proto.SlotName); | |||
| nodes[slot_variable_proto.SlotVariableNodeId] = slot_variable as Trackable; | |||
| node_setters[slot_variable_proto.SlotVariableNodeId] = setattr; | |||
| } | |||
| else | |||
| { | |||
| @@ -398,7 +421,7 @@ namespace Tensorflow | |||
| { | |||
| nodes[0] = _recreate_base_user_object().Item1; | |||
| } | |||
| _nodes = new List<Trackable>(); | |||
| _nodes = new List<object>(); | |||
| for(int i = 0; i < _proto.Nodes.Count; i++) | |||
| { | |||
| _nodes.Add(nodes[i]); | |||
| @@ -412,7 +435,7 @@ namespace Tensorflow | |||
| private void _restore_checkpoint() | |||
| { | |||
| var variables_path = SavedModelUtils.get_variables_path(_export_dir); | |||
| var saver = new TrackableSaver(new ObjectGraphView(get(0))); | |||
| var saver = new TrackableSaver(new ObjectGraphView((Trackable)get(0))); | |||
| tf_with(ops.device("CPU"), _ => | |||
| { | |||
| saver.FilePrefixPlaceHolder = constant_op.constant(variables_path); | |||
| @@ -467,7 +490,7 @@ namespace Tensorflow | |||
| } | |||
| } | |||
| private void _setup_function_captures(string concrete_function_name, IDictionary<OneOf<string, int>, Trackable> nodes) | |||
| private void _setup_function_captures(string concrete_function_name, IDictionary<OneOf<string, int>, object> nodes) | |||
| { | |||
| if (_restored_concrete_functions.Contains(concrete_function_name)) | |||
| { | |||
| @@ -485,12 +508,12 @@ namespace Tensorflow | |||
| // TODO: implement it with concrete functions. | |||
| } | |||
| public Trackable get(int node_id) | |||
| public object get(int node_id) | |||
| { | |||
| return _nodes[node_id]; | |||
| } | |||
| public Trackable get(string node_id) | |||
| public object get(string node_id) | |||
| { | |||
| return get(_node_path_to_id[node_id]); | |||
| } | |||
| @@ -512,9 +535,9 @@ namespace Tensorflow | |||
| } | |||
| } | |||
| private (Dictionary<int, Trackable>, Dictionary<int, Action<object, object, object>>) _initialize_loaded_nodes() | |||
| private (Dictionary<int, object>, Dictionary<int, Action<object, object, object>>) _initialize_loaded_nodes() | |||
| { | |||
| Dictionary<int, Trackable> nodes = new(); | |||
| Dictionary<int, object> nodes = new(); | |||
| Dictionary<int, Action<object, object, object>> node_setters = new(); | |||
| foreach(var item in _loaded_nodes) | |||
| { | |||
| @@ -534,10 +557,10 @@ namespace Tensorflow | |||
| } | |||
| } | |||
| private (Trackable, Action<object, object, object>) _recreate(SavedObject proto, int node_id, IDictionary<int, Trackable> nodes) | |||
| private (object, Action<object, object, object>) _recreate(SavedObject proto, int node_id, IDictionary<int, object> nodes) | |||
| { | |||
| // skip the registered classes. | |||
| Dictionary<OneOf<string, int>, Trackable> dependencies = new(); | |||
| Dictionary<OneOf<string, int>, object> dependencies = new(); | |||
| foreach(var item in _get_node_dependencies(proto)) | |||
| { | |||
| dependencies[item.Key] = nodes[item.Value]; | |||
| @@ -558,7 +581,7 @@ namespace Tensorflow | |||
| /// <param name="proto"></param> | |||
| /// <param name="node_id"></param> | |||
| /// <param name="dependencies"></param> | |||
| private (Trackable, Action<object, object, object>) _recreate_default(SavedObject proto, int node_id, IDictionary<OneOf<string, int>, Trackable> dependencies) | |||
| private (Trackable, Action<object, object, object>) _recreate_default(SavedObject proto, int node_id, IDictionary<OneOf<string, int>, object> dependencies) | |||
| { | |||
| return proto.KindCase switch | |||
| { | |||
| @@ -626,7 +649,7 @@ namespace Tensorflow | |||
| } | |||
| private (Function, Action<object, object, object>) _recreate_function(SavedFunction proto, | |||
| IDictionary<OneOf<string, int>, Trackable> dependencies) | |||
| IDictionary<OneOf<string, int>, object> dependencies) | |||
| { | |||
| var fn = function_deserialization.recreate_function(proto, _concrete_functions); | |||
| foreach (var name in proto.ConcreteFunctions) | |||
| @@ -637,7 +660,7 @@ namespace Tensorflow | |||
| } | |||
| private (ConcreteFunction, Action<object, object, object>) _recreate_bare_concrete_function(SavedBareConcreteFunction proto, | |||
| IDictionary<OneOf<string, int>, Trackable> dependencies) | |||
| IDictionary<OneOf<string, int>, object> dependencies) | |||
| { | |||
| var fn = function_deserialization.setup_bare_concrete_function(proto, _concrete_functions); | |||
| _setup_function_captures(proto.ConcreteFunctionName, dependencies); | |||
| @@ -78,7 +78,7 @@ namespace Tensorflow | |||
| tf_with(ops.init_scope(), x => | |||
| { | |||
| loader = new Loader(object_graph_proto, saved_model_proto, export_dir, ckpt_options, options, filters); | |||
| root = loader.get(0); | |||
| root = (Trackable)loader.get(0); | |||
| // skip the assignment of `graph_debug_info`. | |||
| }); | |||
| // skip the assignment of `tensorflow_version` | |||
| @@ -99,7 +99,7 @@ namespace Tensorflow | |||
| } | |||
| if(filters != null && filters.Count > 0) | |||
| { | |||
| return filters.Keys.ToDictionary(x => x, x => loader.get(x)); | |||
| return filters.Keys.ToDictionary(x => x, x => (Trackable)loader.get(x)); | |||
| } | |||
| else | |||
| { | |||
| @@ -52,7 +52,7 @@ public static class TrackableUtils | |||
| /// </summary> | |||
| /// <param name="dependency_map"></param> | |||
| /// <exception cref="ValueError"></exception> | |||
| public static List<int> order_by_dependency(IDictionary<int, IEnumerable<int>> dependency_map) | |||
| public static List<int> order_by_dependency(IDictionary<int, List<int>> dependency_map) | |||
| { | |||
| Dictionary<int, HashSet<int>> reverse_dependency_map = new(); | |||
| foreach (var pair in dependency_map) | |||
| @@ -102,7 +102,7 @@ public static class TrackableUtils | |||
| edges.Remove(x); | |||
| if (edges.Count == 0) | |||
| { | |||
| to_visit.Enqueue(dep); | |||
| to_visit.Enqueue(dep); | |||
| if (!reverse_dependency_map.Remove(dep)) | |||
| { | |||
| throw new KeyError($"Cannot find the key {dep} in reverse_dependency_map"); | |||
| @@ -333,5 +333,23 @@ namespace Tensorflow | |||
| }); | |||
| return array_ops.identity(value); | |||
| } | |||
| //public static Tensor operator +(BaseResourceVariable x, int y) => x.value() + y; | |||
| //public static Tensor operator +(BaseResourceVariable x, float y) => x.value() + y; | |||
| //public static Tensor operator +(BaseResourceVariable x, double y) => x.value() + y; | |||
| //public static Tensor operator +(BaseResourceVariable x, BaseResourceVariable y) => x.value() + y.value(); | |||
| //public static Tensor operator -(BaseResourceVariable x, int y) => x.value() - y; | |||
| //public static Tensor operator -(BaseResourceVariable x, float y) => x.value() - y; | |||
| //public static Tensor operator -(BaseResourceVariable x, double y) => x.value() - y; | |||
| //public static Tensor operator -(BaseResourceVariable x, Tensor y) => x.value() - y; | |||
| //public static Tensor operator -(BaseResourceVariable x, BaseResourceVariable y) => x.value() - y.value(); | |||
| //public static Tensor operator *(BaseResourceVariable x, BaseResourceVariable y) => x.value() * y.value(); | |||
| //public static Tensor operator *(BaseResourceVariable x, Tensor y) => x.value() * y; | |||
| //public static Tensor operator *(BaseResourceVariable x, NDArray y) => x.value() * y; | |||
| //public static Tensor operator <(BaseResourceVariable x, Tensor y) => x.value() < y; | |||
| //public static Tensor operator >(BaseResourceVariable x, Tensor y) => x.value() > y; | |||
| } | |||
| } | |||
| @@ -1,19 +1,6 @@ | |||
| /***************************************************************************** | |||
| Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. | |||
| Licensed under the Apache License, Version 2.0 (the "License"); | |||
| you may not use this file except in compliance with the License. | |||
| You may obtain a copy of the License at | |||
| http://www.apache.org/licenses/LICENSE-2.0 | |||
| Unless required by applicable law or agreed to in writing, software | |||
| distributed under the License is distributed on an "AS IS" BASIS, | |||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| See the License for the specific language governing permissions and | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.NumPy; | |||
| namespace Tensorflow | |||
| @@ -169,6 +169,12 @@ namespace Tensorflow.Keras | |||
| _GRAPH_LEARNING_PHASES[tf.get_default_graph()] = (GraphLearningPhase)((value) ? 1 : 0); | |||
| } | |||
| public void set_value(IVariableV1 x, object value) | |||
| { | |||
| // TODO(Rinne): check the implementation. | |||
| x.assign(value); | |||
| } | |||
| public void batch_set_value(List<(IVariableV1, NDArray)> tuples) | |||
| { | |||
| if (ops.executing_eagerly_outside_functions()) | |||
| @@ -36,6 +36,11 @@ namespace Tensorflow.Keras | |||
| } | |||
| } | |||
| static KerasInterface() | |||
| { | |||
| RevivedTypes.RegisterRevivedTypeCreator("optimizer", new RestoredOptimizer()); | |||
| } | |||
| public KerasDataset datasets { get; } = new KerasDataset(); | |||
| public IInitializersApi initializers { get; } = new InitializersApi(); | |||
| public Regularizers regularizers { get; } = new Regularizers(); | |||
| @@ -14,11 +14,11 @@ namespace Tensorflow.Keras.Optimizers | |||
| protected bool _hypers_created; | |||
| protected virtual string _name { get; } | |||
| IVariableV1 _iterations; | |||
| protected IVariableV1 _iterations; | |||
| protected ResourceVariable iterations => _iterations as ResourceVariable; | |||
| List<IVariableV1> _weights; | |||
| Dictionary<string, float> _hyper; | |||
| Dictionary<string, IVariableV1> _hyper_variables; | |||
| protected Dictionary<string, float> _hyper; | |||
| protected Dictionary<string, IVariableV1> _hyper_variables; | |||
| protected bool _momentum; | |||
| protected float _initial_decay = 0.0f; | |||
| protected bool _use_locking = true; | |||
| @@ -224,7 +224,7 @@ namespace Tensorflow.Keras.Optimizers | |||
| } | |||
| } | |||
| protected IVariableV1 add_slot(IVariableV1 var, string slot_name, IInitializer initializer = null) | |||
| public IVariableV1 add_slot(IVariableV1 var, string slot_name, IInitializer initializer = null) | |||
| { | |||
| if (initializer == null) | |||
| initializer = tf.zeros_initializer; | |||
| @@ -0,0 +1,63 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using Tensorflow.Keras.Saving; | |||
| using Tensorflow.Train; | |||
| using Tensorflow.Training; | |||
| namespace Tensorflow.Keras.Optimizers | |||
| { | |||
| public class RestoredOptimizer: OptimizerV2, ITrackableWrapper, IKerasConfig | |||
| { | |||
| public String Identifier { get; } = "optimizer"; | |||
| public int Version { get; } = 2; | |||
| public int MinConsumerVersion { get; } = 1; | |||
| public int MinProducerVersion { get; } = 1; | |||
| public RestoredOptimizer(): base(new ArgsDefinition.OptimizerV2Args() { Name = "RestoredOptimizer" }) | |||
| { | |||
| _hypers_created = true; | |||
| } | |||
| public IKerasConfig get_config() | |||
| { | |||
| throw new NotImplementedException("Restoring functional Optimizers from SavedModels is not currently " + | |||
| "supported. Please file a feature request if this limitation bothers you."); | |||
| } | |||
| public void SetValue(object name, object value) | |||
| { | |||
| if(name is not String str) | |||
| { | |||
| throw new TypeError($"The name of value to set must be string, but got {name.GetType()}"); | |||
| } | |||
| if(value is Trackable trackable) | |||
| { | |||
| _track_trackable(trackable, str, overwrite: true); | |||
| } | |||
| if(value is IVariableV1 resource_variable) | |||
| { | |||
| if (!_hyper_variables.ContainsKey(str)) | |||
| { | |||
| _hyper_variables[str] = resource_variable; | |||
| } | |||
| else | |||
| { | |||
| keras.backend.set_value(resource_variable, value); | |||
| } | |||
| } | |||
| else if (value is float f) | |||
| { | |||
| _hyper[str] = f; | |||
| } | |||
| else | |||
| { | |||
| throw new NotImplementedException(); | |||
| } | |||
| } | |||
| public Trackable FromProto(SavedUserObject proto) | |||
| { | |||
| return new RestoredOptimizer(); | |||
| } | |||
| } | |||
| } | |||
| @@ -2,6 +2,7 @@ | |||
| using System; | |||
| using System.Linq; | |||
| using Tensorflow; | |||
| using Tensorflow.Keras.Engine; | |||
| using Tensorflow.Keras.Optimizers; | |||
| using Tensorflow.Keras.UnitTest.Helpers; | |||
| using Tensorflow.NumPy; | |||
| @@ -103,4 +104,13 @@ public class SequentialModelLoad | |||
| classify_model.fit(x, y, batch_size: 4); | |||
| } | |||
| [Ignore] | |||
| [TestMethod] | |||
| public void TestModelBeforeTF2_5() | |||
| { | |||
| var a = keras.layers; | |||
| var model = tf.saved_model.load(@"D:\development\temp\saved_model") as Model; | |||
| model.summary(); | |||
| } | |||
| } | |||