| @@ -28,9 +28,9 @@ namespace Tensorflow | |||
| /// </summary> | |||
| /// <param name="input_ops">The data input ops for an op to be created.</param> | |||
| /// <returns>A list of control inputs for the op to be created.</returns> | |||
| private Operation[] _control_dependencies_for_inputs(Operation[] input_ops) | |||
| private ITensorOrOperation[] _control_dependencies_for_inputs(ITensorOrOperation[] input_ops) | |||
| { | |||
| Operation[] ret = new Operation[0]; | |||
| var ret = new ITensorOrOperation[0]; | |||
| foreach(var controller in _control_dependencies_stack) | |||
| { | |||
| @@ -54,12 +54,12 @@ namespace Tensorflow | |||
| return ret; | |||
| } | |||
| public _ControlDependenciesController control_dependencies(Operation[] control_inputs) | |||
| public _ControlDependenciesController control_dependencies(ITensorOrOperation[] control_inputs) | |||
| { | |||
| if (control_inputs == null) | |||
| return new _ControlDependenciesController(this, null); | |||
| var control_ops = new List<Operation>(); | |||
| var control_ops = new List<ITensorOrOperation>(); | |||
| foreach (var c in control_inputs) | |||
| { | |||
| control_ops.Add(c); | |||
| @@ -298,6 +298,11 @@ namespace Tensorflow | |||
| return _nodes_by_name.Values.Select(x => x).ToArray(); | |||
| } | |||
| public string[] get_all_collection_keys() | |||
| { | |||
| return _collections.Keys.Where(x => !x.StartsWith("__")).ToArray(); | |||
| } | |||
| public object get_collection(string name, string scope = "") | |||
| { | |||
| return _collections.ContainsKey(name) ? _collections[name] : null; | |||
| @@ -11,20 +11,20 @@ namespace Tensorflow | |||
| public class _ControlDependenciesController : IPython | |||
| { | |||
| private Graph _graph; | |||
| private List<Operation> _control_inputs_val; | |||
| private List<Operation> _seen_nodes; | |||
| private List<ITensorOrOperation> _control_inputs_val; | |||
| private List<ITensorOrOperation> _seen_nodes; | |||
| private Queue<_ControlDependenciesController> _old_stack; | |||
| private bool _new_stack; | |||
| private Context _old_control_flow_context; | |||
| public Operation[] control_inputs => _control_inputs_val.ToArray(); | |||
| public ITensorOrOperation[] control_inputs => _control_inputs_val.ToArray(); | |||
| public _ControlDependenciesController(Graph graph, List<Operation> control_inputs) | |||
| public _ControlDependenciesController(Graph graph, List<ITensorOrOperation> control_inputs) | |||
| { | |||
| _graph = graph; | |||
| if (control_inputs == null) | |||
| { | |||
| _control_inputs_val = new List<Operation>(); | |||
| _control_inputs_val = new List<ITensorOrOperation>(); | |||
| _new_stack = true; | |||
| } | |||
| else | |||
| @@ -33,15 +33,15 @@ namespace Tensorflow | |||
| _new_stack = false; | |||
| } | |||
| _seen_nodes = new List<Operation>(); | |||
| _seen_nodes = new List<ITensorOrOperation>(); | |||
| } | |||
| public void add_op(Operation op) | |||
| public void add_op(ITensorOrOperation op) | |||
| { | |||
| _seen_nodes.Add(op); | |||
| } | |||
| public bool op_in_group(Operation op) | |||
| public bool op_in_group(ITensorOrOperation op) | |||
| { | |||
| return _seen_nodes.Contains(op); | |||
| } | |||
| @@ -11,5 +11,6 @@ namespace Tensorflow | |||
| public interface ITensorOrOperation | |||
| { | |||
| string Device { get; } | |||
| Operation op { get; } | |||
| } | |||
| } | |||
| @@ -107,7 +107,9 @@ namespace Tensorflow | |||
| values = ops.internal_convert_to_tensor(values, | |||
| name: input_name, | |||
| as_ref: input_arg.IsRef); | |||
| dtype: dtype, | |||
| as_ref: input_arg.IsRef, | |||
| preferred_dtype: default_dtype); | |||
| //if (!String.IsNullOrEmpty(input_arg.TypeAttr)) | |||
| //attrs[input_arg.TypeAttr] = values.dtype; | |||
| @@ -163,14 +165,20 @@ namespace Tensorflow | |||
| foreach (var arg in op_def.OutputArg) | |||
| { | |||
| types = new List<TF_DataType>(); | |||
| if (!string.IsNullOrEmpty(arg.NumberAttr)) | |||
| { | |||
| } | |||
| else if (!string.IsNullOrEmpty(arg.TypeAttr)) | |||
| { | |||
| output_types.Add((TF_DataType)attr_protos[arg.TypeAttr].Type); | |||
| types = new List<TF_DataType>() { (TF_DataType)attr_protos[arg.TypeAttr].Type }; | |||
| } | |||
| if (arg.IsRef) | |||
| types = types.Select(x => x.as_ref()).ToList(); | |||
| output_types.AddRange(types); | |||
| } | |||
| // Add Op to graph | |||
| @@ -16,6 +16,7 @@ namespace Tensorflow | |||
| private int _id_value; | |||
| public string type => OpType; | |||
| public Operation op => this; | |||
| private Status status = new Status(); | |||
| @@ -75,7 +76,7 @@ namespace Tensorflow | |||
| /// </param> | |||
| /// <param name="original_op"></param> | |||
| /// <param name="op_def"></param> | |||
| public Operation(NodeDef node_def, Graph g, Tensor[] inputs = null, TF_DataType[] output_types = null, Operation[] control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null) | |||
| public Operation(NodeDef node_def, Graph g, Tensor[] inputs = null, TF_DataType[] output_types = null, ITensorOrOperation[] control_inputs = null, TF_DataType[] input_types = null, string original_op = "", OpDef op_def = null) | |||
| { | |||
| Graph = g; | |||
| @@ -120,6 +121,11 @@ namespace Tensorflow | |||
| _control_flow_post_processing(); | |||
| } | |||
| public void run(FeedItem[] feed_dict = null, Session session = null) | |||
| { | |||
| ops._run_using_default_session(this, feed_dict, Graph, session); | |||
| } | |||
| private object[] _reconstruct_sequence_inputs(OpDef op_def, Tensor[] inputs, MapField<string, AttrValue> attrs) | |||
| { | |||
| var grouped_inputs = new List<object>(); | |||
| @@ -204,7 +210,7 @@ namespace Tensorflow | |||
| public override string ToString() | |||
| { | |||
| return _handle == IntPtr.Zero ? "Undefined" : $"'{Name}' type={OpType}"; | |||
| return _handle == IntPtr.Zero ? "tf.Operation Undefined" : $"tf.Operation '{Name}' type={OpType}"; | |||
| } | |||
| public static implicit operator Operation(IntPtr handle) => new Operation(handle); | |||
| @@ -28,10 +28,7 @@ namespace Tensorflow | |||
| { | |||
| var dev = ops_on_device.Keys.First(); | |||
| var deps = ops_on_device.Values.First(); | |||
| if (typeof(T).Name == "Operation") | |||
| return _GroupControlDeps(dev, deps.Select(x => x as Operation).ToArray(), name); | |||
| else | |||
| throw new NotImplementedException("control_flow_ops.group"); | |||
| return _GroupControlDeps(dev, deps.Select(x => x.op).ToArray(), name); | |||
| } | |||
| // 2-level tree. The root node is the returned NoOp node. | |||
| @@ -35,17 +35,8 @@ namespace Tensorflow | |||
| c_api.TF_DeleteSessionOptions(opts); | |||
| } | |||
| public virtual NDArray run(RefVariable fetches, FeedItem[] feed_dict = null) | |||
| { | |||
| return _run(fetches, feed_dict); | |||
| } | |||
| public virtual NDArray run(Tensor fetches, FeedItem[] feed_dict = null) | |||
| { | |||
| return _run(fetches, feed_dict); | |||
| } | |||
| public virtual NDArray run(Operation fetches, FeedItem[] feed_dict = null) | |||
| public virtual NDArray run<T>(T fetches, FeedItem[] feed_dict = null) | |||
| { | |||
| return _run(fetches, feed_dict); | |||
| } | |||
| @@ -30,6 +30,11 @@ namespace Tensorflow | |||
| return tensor._handle; | |||
| } | |||
| public static implicit operator Operation(Tensor tensor) | |||
| { | |||
| return tensor.op; | |||
| } | |||
| public static implicit operator Tensor(IntPtr handle) | |||
| { | |||
| return new Tensor(handle); | |||
| @@ -261,7 +261,7 @@ namespace Tensorflow | |||
| } | |||
| } | |||
| return $"tf.Tensor {name} shape=({string.Join(",", shape)}) dtype={dtype.ToString()}"; | |||
| return $"tf.Tensor '{name}' shape=({string.Join(",", shape)}) dtype={dtype.ToString()}"; | |||
| } | |||
| public void Dispose() | |||
| @@ -83,6 +83,13 @@ namespace Tensorflow | |||
| type; | |||
| } | |||
| public static TF_DataType as_ref(this TF_DataType type) | |||
| { | |||
| return (int)type < 100 ? | |||
| (TF_DataType)Enum.Parse(typeof(TF_DataType), ((int)type + 100).ToString()) : | |||
| type; | |||
| } | |||
| public static bool is_complex(this TF_DataType type) | |||
| { | |||
| return type == TF_DataType.TF_COMPLEX || type == TF_DataType.TF_COMPLEX64 || type == TF_DataType.TF_COMPLEX128; | |||
| @@ -7,9 +7,9 @@ namespace Tensorflow | |||
| { | |||
| public class BaseSaverBuilder | |||
| { | |||
| protected int _write_version; | |||
| protected SaverDef.Types.CheckpointFormatVersion _write_version; | |||
| public BaseSaverBuilder(int write_version = 2) | |||
| public BaseSaverBuilder(SaverDef.Types.CheckpointFormatVersion write_version = SaverDef.Types.CheckpointFormatVersion.V2) | |||
| { | |||
| _write_version = write_version; | |||
| } | |||
| @@ -30,7 +30,7 @@ namespace Tensorflow | |||
| } | |||
| } | |||
| if (_write_version == 2) | |||
| if (_write_version == SaverDef.Types.CheckpointFormatVersion.V2) | |||
| { | |||
| return gen_io_ops.save_v2(filename_tensor, tensor_names.ToArray(), tensor_slices.ToArray(), tensors.ToArray()); | |||
| } | |||
| @@ -60,7 +60,7 @@ namespace Tensorflow | |||
| bool reshape = false, | |||
| bool sharded = false, | |||
| int max_to_keep = 5, | |||
| double keep_checkpoint_every_n_hours = 10000, | |||
| float keep_checkpoint_every_n_hours = 10000, | |||
| string name = "", | |||
| bool restore_sequentially = false, | |||
| string filename = "model", | |||
| @@ -76,7 +76,10 @@ namespace Tensorflow | |||
| if (max_to_keep < 0) | |||
| max_to_keep = 0; | |||
| Python.with<ops.name_scope>(new ops.name_scope(name, "save", saveables.Select(x => x.op).ToArray()), scope => | |||
| Tensor save_tensor = null; | |||
| Operation restore_op = null; | |||
| return Python.with<ops.name_scope, SaverDef>(new ops.name_scope(name, "save", saveables.Select(x => x.op).ToArray()), scope => | |||
| { | |||
| name = scope; | |||
| @@ -93,14 +96,35 @@ namespace Tensorflow | |||
| else | |||
| { | |||
| if (build_save) | |||
| _AddSaveOps(filename_tensor, saveables); | |||
| save_tensor = _AddSaveOps(filename_tensor, saveables); | |||
| if (build_restore) | |||
| _AddRestoreOps(filename_tensor, saveables, restore_sequentially, reshape); | |||
| restore_op = _AddRestoreOps(filename_tensor, saveables, restore_sequentially, reshape); | |||
| } | |||
| var graph = ops.get_default_graph(); | |||
| var check_collection_list = graph.get_all_collection_keys(); | |||
| foreach (var collection_type in check_collection_list) | |||
| { | |||
| foreach (var element in graph.get_collection(collection_type) as IList<RefVariable>) | |||
| { | |||
| } | |||
| } | |||
| return new SaverDef() | |||
| { | |||
| FilenameTensorName = filename_tensor.name, | |||
| SaveTensorName = save_tensor.name, | |||
| RestoreOpName = restore_op.Name, | |||
| MaxToKeep = max_to_keep, | |||
| Sharded = sharded, | |||
| KeepCheckpointEveryNHours = keep_checkpoint_every_n_hours, | |||
| Version = _write_version | |||
| }; | |||
| }); | |||
| throw new NotImplementedException(""); | |||
| } | |||
| public Tensor _AddSaveOps(Tensor filename_tensor, SaveableObject[] saveables) | |||
| @@ -6,7 +6,7 @@ namespace Tensorflow | |||
| { | |||
| public class BulkSaverBuilder : BaseSaverBuilder, ISaverBuilder | |||
| { | |||
| public BulkSaverBuilder(int write_version = 2) : base(write_version) | |||
| public BulkSaverBuilder(SaverDef.Types.CheckpointFormatVersion write_version = SaverDef.Types.CheckpointFormatVersion.V2) : base(write_version) | |||
| { | |||
| } | |||
| @@ -14,7 +14,7 @@ namespace Tensorflow | |||
| bool reshape = false, | |||
| bool sharded = false, | |||
| int max_to_keep = 5, | |||
| double keep_checkpoint_every_n_hours = 10000, | |||
| float keep_checkpoint_every_n_hours = 10000, | |||
| string name = "", | |||
| bool restore_sequentially = false, | |||
| string filename = "model", | |||
| @@ -1,5 +1,6 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.IO; | |||
| using System.Text; | |||
| namespace Tensorflow | |||
| @@ -13,30 +14,33 @@ namespace Tensorflow | |||
| private bool _reshape; | |||
| private bool _sharded; | |||
| private int _max_to_keep; | |||
| private double _keep_checkpoint_every_n_hours; | |||
| private float _keep_checkpoint_every_n_hours; | |||
| private string _name; | |||
| private bool _restore_sequentially; | |||
| private SaverDef _saver_def; | |||
| private ISaverBuilder _builder; | |||
| private bool _allow_empty; | |||
| private bool _is_built; | |||
| private int _write_version; | |||
| private SaverDef.Types.CheckpointFormatVersion _write_version; | |||
| private bool _pad_step_number; | |||
| private string _filename; | |||
| private bool _is_empty; | |||
| private float _next_checkpoint_time; | |||
| private bool _save_relative_paths; | |||
| private bool? _object_restore_saver; | |||
| public Saver(RefVariable[] var_list = null, | |||
| bool reshape = false, | |||
| bool sharded = false, | |||
| int max_to_keep = 5, | |||
| double keep_checkpoint_every_n_hours = 10000, | |||
| float keep_checkpoint_every_n_hours = 10000, | |||
| string name = "", | |||
| bool restore_sequentially = false, | |||
| SaverDef saver_def = null, | |||
| ISaverBuilder builder = null, | |||
| bool defer_build = false, | |||
| bool allow_empty = false, | |||
| int write_version = 2, | |||
| SaverDef.Types.CheckpointFormatVersion write_version = SaverDef.Types.CheckpointFormatVersion.V2, | |||
| bool pad_step_number = false, | |||
| bool save_relative_paths = false, | |||
| string filename = "") | |||
| @@ -56,6 +60,14 @@ namespace Tensorflow | |||
| if (!defer_build) | |||
| build(); | |||
| if(_saver_def != null) | |||
| { | |||
| _check_saver_def(); | |||
| _write_version = _saver_def.Version; | |||
| } | |||
| _save_relative_paths = save_relative_paths; | |||
| _object_restore_saver = null; | |||
| } | |||
| public void build() | |||
| @@ -106,8 +118,56 @@ namespace Tensorflow | |||
| { | |||
| throw new NotImplementedException(""); | |||
| } | |||
| _check_saver_def(); | |||
| _next_checkpoint_time = (float)(DateTime.UtcNow - new DateTime(1970, 1, 1)).TotalSeconds + _saver_def.KeepCheckpointEveryNHours * 3600; | |||
| } | |||
| private void _check_saver_def() | |||
| { | |||
| if (!tf.context.executing_eagerly()) | |||
| { | |||
| if (string.IsNullOrEmpty(_saver_def.SaveTensorName)) | |||
| throw new ValueError($"saver_def must specify the save_tensor_name: {_saver_def}"); | |||
| if (string.IsNullOrEmpty(_saver_def.RestoreOpName)) | |||
| throw new ValueError($"saver_def must specify the restore_op_name: {_saver_def}"); | |||
| } | |||
| } | |||
| public string save(Session sess, | |||
| string save_path, | |||
| string global_step = "", | |||
| string meta_graph_suffix = "meta", | |||
| bool write_meta_graph = true, | |||
| bool write_state = true, | |||
| bool strip_default_attrs = false) | |||
| { | |||
| string latest_filename = "checkpoint"; | |||
| string model_checkpoint_path = ""; | |||
| string checkpoint_file = ""; | |||
| if (!string.IsNullOrEmpty(global_step)) | |||
| { | |||
| } | |||
| else | |||
| { | |||
| checkpoint_file = save_path; | |||
| } | |||
| var save_path_parent = Path.GetDirectoryName(save_path); | |||
| if (!_is_empty) | |||
| { | |||
| /*model_checkpoint_path = sess.run(_saver_def.SaveTensorName, new FeedItem[] { | |||
| new FeedItem(_saver_def.FilenameTensorName, checkpoint_file) | |||
| });*/ | |||
| } | |||
| throw new NotImplementedException(""); | |||
| return model_checkpoint_path; | |||
| } | |||
| } | |||
| } | |||
| @@ -185,18 +185,12 @@ namespace Tensorflow | |||
| /// A `Tensor` that will hold the new value of this variable after | |||
| /// the assignment has completed. | |||
| /// </returns> | |||
| public T assign<T>(Tensor value, bool use_locking = false, string name = "", bool read_value = true) | |||
| where T : ITensorOrOperation | |||
| public ITensorOrOperation assign(Tensor value, bool use_locking = false, string name = "", bool read_value = true) | |||
| { | |||
| var assign = gen_state_ops.assign(_variable, value, use_locking: use_locking, name: name); | |||
| if (read_value) | |||
| return (T)Convert.ChangeType(assign, typeof(T)); | |||
| return (T)Convert.ChangeType(assign.op, typeof(T)); | |||
| } | |||
| public Tensor assign(Tensor value, bool use_locking = false, string name = "") | |||
| { | |||
| return gen_state_ops.assign(_variable, value, use_locking: use_locking, name: name); | |||
| return assign; | |||
| return assign.op; | |||
| } | |||
| public override string ToString() | |||
| @@ -292,6 +292,27 @@ namespace Tensorflow | |||
| return tf.Session(); | |||
| } | |||
| public static void _run_using_default_session(Operation operation, FeedItem[] feed_dict, Graph graph, Session session) | |||
| { | |||
| if (session == null) | |||
| { | |||
| session = get_default_session(); | |||
| if (session == null) | |||
| throw new ValueError("Cannot execute operation using `run()`: No default " + | |||
| "session is registered. Use `with " + | |||
| "sess.as_default():` or pass an explicit session to " + | |||
| "`run(session=sess)`"); | |||
| } | |||
| if (session.graph != graph) | |||
| throw new ValueError("Cannot use the default session to execute operation: " + | |||
| "the operation's graph is different from the " + | |||
| "session's graph. Pass an explicit session to " + | |||
| "run(session=sess)."); | |||
| session.run(operation, feed_dict); | |||
| } | |||
| public static Func<Operation, Tensor, Tensor[]> get_gradient_function(Operation op) | |||
| { | |||
| if (op.inputs == null) return null; | |||
| @@ -27,6 +27,13 @@ namespace TensorFlowNET.UnitTest | |||
| with<Session>(tf.Session(), sess => | |||
| { | |||
| sess.run(init_op); | |||
| // o some work with the model. | |||
| inc_v1.op.run(); | |||
| dec_v2.op.run(); | |||
| // Save the variables to disk. | |||
| var save_path = saver.save(sess, "/tmp/model.ckpt"); | |||
| Console.WriteLine($"Model saved in path: {save_path}"); | |||
| }); | |||
| } | |||
| } | |||
| @@ -46,6 +46,24 @@ namespace TensorFlowNET.UnitTest | |||
| } | |||
| } | |||
| [TestMethod] | |||
| public void Assign() | |||
| { | |||
| var v1 = tf.get_variable("v1", shape: new TensorShape(3), initializer: tf.zeros_initializer); | |||
| var inc_v1 = v1.assign(v1 + 1.0f); | |||
| // Add an op to initialize the variables. | |||
| var init_op = tf.global_variables_initializer(); | |||
| with<Session>(tf.Session(), sess => | |||
| { | |||
| sess.run(init_op); | |||
| // o some work with the model. | |||
| inc_v1.op.run(); | |||
| }); | |||
| } | |||
| /// <summary> | |||
| /// https://databricks.com/tensorflow/variables | |||
| /// </summary> | |||