| @@ -8,7 +8,8 @@ namespace Tensorflow | |||||
| /// in order to limit function return value | /// in order to limit function return value | ||||
| /// is Tensor or Operation | /// is Tensor or Operation | ||||
| /// </summary> | /// </summary> | ||||
| public interface IReturnTensorOrOperation | |||||
| public interface ITensorOrOperation | |||||
| { | { | ||||
| string Device { get; } | |||||
| } | } | ||||
| } | } | ||||
| @@ -7,7 +7,7 @@ using System.Text; | |||||
| namespace Tensorflow | namespace Tensorflow | ||||
| { | { | ||||
| public partial class Operation : IReturnTensorOrOperation | |||||
| public partial class Operation : ITensorOrOperation | |||||
| { | { | ||||
| private readonly IntPtr _handle; // _c_op in python | private readonly IntPtr _handle; // _c_op in python | ||||
| @@ -7,20 +7,20 @@ namespace Tensorflow | |||||
| { | { | ||||
| public class control_flow_ops | public class control_flow_ops | ||||
| { | { | ||||
| public static Operation group(Operation[] inputs, string name = "") | |||||
| public static Operation group<T>(T[] inputs, string name = "") where T : ITensorOrOperation | |||||
| { | { | ||||
| return Python.with<ops.name_scope, Operation>(new ops.name_scope(name, "group_deps", inputs), scope => | return Python.with<ops.name_scope, Operation>(new ops.name_scope(name, "group_deps", inputs), scope => | ||||
| { | { | ||||
| name = scope; | name = scope; | ||||
| // Sorts *inputs according to their devices. | // Sorts *inputs according to their devices. | ||||
| var ops_on_device = new Dictionary<string, List<Operation>>(); | |||||
| var ops_on_device = new Dictionary<string, List<T>>(); | |||||
| foreach (var inp in inputs) | foreach (var inp in inputs) | ||||
| { | { | ||||
| if (ops_on_device.ContainsKey(inp.Device)) | if (ops_on_device.ContainsKey(inp.Device)) | ||||
| ops_on_device[inp.Device].Add(inp); | ops_on_device[inp.Device].Add(inp); | ||||
| else | else | ||||
| ops_on_device[inp.Device] = new List<Operation> { inp }; | |||||
| ops_on_device[inp.Device] = new List<T> { inp }; | |||||
| } | } | ||||
| // 1-level tree. The root node is the returned NoOp node. | // 1-level tree. The root node is the returned NoOp node. | ||||
| @@ -28,12 +28,15 @@ namespace Tensorflow | |||||
| { | { | ||||
| var dev = ops_on_device.Keys.First(); | var dev = ops_on_device.Keys.First(); | ||||
| var deps = ops_on_device.Values.First(); | var deps = ops_on_device.Values.First(); | ||||
| return _GroupControlDeps(dev, deps.ToArray(), name); | |||||
| if (typeof(T).Name == "Operation") | |||||
| return _GroupControlDeps(dev, deps.Select(x => x as Operation).ToArray(), name); | |||||
| else | |||||
| throw new NotImplementedException("control_flow_ops.group"); | |||||
| } | } | ||||
| // 2-level tree. The root node is the returned NoOp node. | // 2-level tree. The root node is the returned NoOp node. | ||||
| // deps contains 1 NoOp node for each device. | // deps contains 1 NoOp node for each device. | ||||
| return null; | |||||
| throw new NotImplementedException("control_flow_ops.group"); | |||||
| }); | }); | ||||
| } | } | ||||
| @@ -14,5 +14,12 @@ namespace Tensorflow | |||||
| return _op; | return _op; | ||||
| } | } | ||||
| public static Tensor[] restore_v2(Tensor prefix, string[] tensor_names, string[] shape_and_slices, TF_DataType[] dtypes, string name = "") | |||||
| { | |||||
| var _op = _op_def_lib._apply_op_helper("RestoreV2", name: name, args: new { prefix, tensor_names, shape_and_slices, dtypes }); | |||||
| return _op.outputs; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -12,7 +12,7 @@ namespace Tensorflow | |||||
| /// A tensor is a generalization of vectors and matrices to potentially higher dimensions. | /// A tensor is a generalization of vectors and matrices to potentially higher dimensions. | ||||
| /// Internally, TensorFlow represents tensors as n-dimensional arrays of base datatypes. | /// Internally, TensorFlow represents tensors as n-dimensional arrays of base datatypes. | ||||
| /// </summary> | /// </summary> | ||||
| public partial class Tensor : IDisposable, IReturnTensorOrOperation | |||||
| public partial class Tensor : IDisposable, ITensorOrOperation | |||||
| { | { | ||||
| private readonly IntPtr _handle; | private readonly IntPtr _handle; | ||||
| @@ -175,6 +175,9 @@ namespace Tensorflow | |||||
| } | } | ||||
| public Operation[] Consumers => consumers(); | public Operation[] Consumers => consumers(); | ||||
| public string Device => op.Device; | |||||
| public Operation[] consumers() | public Operation[] consumers() | ||||
| { | { | ||||
| var output = _as_tf_output(); | var output = _as_tf_output(); | ||||
| @@ -42,7 +42,18 @@ namespace Tensorflow | |||||
| public virtual Tensor[] bulk_restore(Tensor filename_tensor, SaveableObject[] saveables, int preferred_shard, bool restore_sequentially) | public virtual Tensor[] bulk_restore(Tensor filename_tensor, SaveableObject[] saveables, int preferred_shard, bool restore_sequentially) | ||||
| { | { | ||||
| throw new NotImplementedException(); | |||||
| var names = new List<string>(); | |||||
| var slices = new List<string>(); | |||||
| var dtypes = new List<TF_DataType>(); | |||||
| foreach (var saveable in saveables) | |||||
| foreach (var spec in saveable.specs) | |||||
| { | |||||
| names.Add(spec.name); | |||||
| slices.Add(spec.slice_spec); | |||||
| dtypes.Add(spec.dtype); | |||||
| } | |||||
| return gen_io_ops.restore_v2(filename_tensor, names.ToArray(), slices.ToArray(), dtypes.ToArray()); | |||||
| } | } | ||||
| public virtual SaverDef _build_internal(RefVariable[] names_to_saveables, | public virtual SaverDef _build_internal(RefVariable[] names_to_saveables, | ||||
| @@ -83,6 +94,9 @@ namespace Tensorflow | |||||
| { | { | ||||
| if (build_save) | if (build_save) | ||||
| _AddSaveOps(filename_tensor, saveables); | _AddSaveOps(filename_tensor, saveables); | ||||
| if (build_restore) | |||||
| _AddRestoreOps(filename_tensor, saveables, restore_sequentially, reshape); | |||||
| } | } | ||||
| }); | }); | ||||
| @@ -94,5 +108,42 @@ namespace Tensorflow | |||||
| var save = save_op(filename_tensor, saveables); | var save = save_op(filename_tensor, saveables); | ||||
| return control_flow_ops.with_dependencies(new Operation[] { save }, filename_tensor); | return control_flow_ops.with_dependencies(new Operation[] { save }, filename_tensor); | ||||
| } | } | ||||
| /// <summary> | |||||
| /// Add operations to restore saveables. | |||||
| /// </summary> | |||||
| /// <param name="filename_tensor"></param> | |||||
| /// <param name="saveables"></param> | |||||
| /// <param name="restore_sequentially"></param> | |||||
| /// <param name="reshape"></param> | |||||
| /// <param name="preferred_shard"></param> | |||||
| /// <param name="name"></param> | |||||
| /// <returns>An Operation that restores the variables.</returns> | |||||
| public Operation _AddRestoreOps(Tensor filename_tensor, | |||||
| SaveableObject[] saveables, | |||||
| bool restore_sequentially, | |||||
| bool reshape, | |||||
| int preferred_shard = -1, | |||||
| string name = "restore_all") | |||||
| { | |||||
| var all_tensors = bulk_restore(filename_tensor, saveables, preferred_shard, restore_sequentially); | |||||
| var assign_ops = new List<Tensor>(); | |||||
| int idx = 0; | |||||
| foreach(var saveable in saveables) | |||||
| { | |||||
| List<TensorShape> shapes = null; | |||||
| if (reshape) | |||||
| { | |||||
| throw new NotImplementedException("_AddRestoreOps"); | |||||
| } | |||||
| var saveable_tensors = all_tensors.Skip(idx).Take(saveable.specs.Length); | |||||
| idx += saveable.specs.Length; | |||||
| assign_ops.Add(saveable.restore(saveable_tensors.ToArray(), shapes == null ? null : shapes.ToArray())); | |||||
| } | |||||
| return control_flow_ops.group(assign_ops.ToArray(), name: name); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -27,5 +27,13 @@ namespace Tensorflow | |||||
| this.specs = specs; | this.specs = specs; | ||||
| this.name = name; | this.name = name; | ||||
| } | } | ||||
| public virtual Tensor restore(Tensor[] restored_tensors, TensorShape[] restored_shapes = null) | |||||
| { | |||||
| var restored_tensor = restored_tensors[0]; | |||||
| return gen_state_ops.assign(op, | |||||
| restored_tensor, | |||||
| validate_shape: restored_shapes == null && tensor_util.to_shape(op.shape).is_fully_defined()); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -186,7 +186,7 @@ namespace Tensorflow | |||||
| /// the assignment has completed. | /// the assignment has completed. | ||||
| /// </returns> | /// </returns> | ||||
| public T assign<T>(Tensor value, bool use_locking = false, string name = "", bool read_value = true) | public T assign<T>(Tensor value, bool use_locking = false, string name = "", bool read_value = true) | ||||
| where T : IReturnTensorOrOperation | |||||
| where T : ITensorOrOperation | |||||
| { | { | ||||
| var assign = gen_state_ops.assign(_variable, value, use_locking: use_locking, name: name); | var assign = gen_state_ops.assign(_variable, value, use_locking: use_locking, name: name); | ||||
| if (read_value) | if (read_value) | ||||