| @@ -25,6 +25,12 @@ namespace Tensorflow | |||
| public class train_internal | |||
| { | |||
| public RefVariable create_global_step(Graph graph) | |||
| => TrainingUtil.create_global_step(graph); | |||
| public RefVariable get_global_step(Graph graph) | |||
| => TrainingUtil.get_global_step(graph); | |||
| public Optimizer GradientDescentOptimizer(float learning_rate) | |||
| => new GradientDescentOptimizer(learning_rate); | |||
| @@ -46,6 +46,7 @@ namespace Tensorflow | |||
| TF_DataType dtype = TF_DataType.DtInvalid, | |||
| object initializer = null, // IInitializer or Tensor | |||
| bool? trainable = null, | |||
| List<string> collections = null, | |||
| bool? use_resource = null, | |||
| bool validate_shape = true, | |||
| VariableSynchronization synchronization = VariableSynchronization.Auto, | |||
| @@ -60,7 +61,8 @@ namespace Tensorflow | |||
| use_resource: use_resource, | |||
| validate_shape: validate_shape, | |||
| initializer: initializer, | |||
| trainable: trainable); | |||
| trainable: trainable, | |||
| collections: collections); | |||
| } | |||
| } | |||
| } | |||
| @@ -1,5 +1,7 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Diagnostics; | |||
| using System.IO; | |||
| using System.Text; | |||
| using static Tensorflow.Binding; | |||
| @@ -34,32 +36,52 @@ namespace Tensorflow.Estimators | |||
| if(max_steps > 0) | |||
| { | |||
| var start_step = _load_global_step_from_checkpoint_dir(_model_dir); | |||
| if (max_steps <= start_step) | |||
| { | |||
| Console.WriteLine("Skipping training since max_steps has already saved."); | |||
| return this; | |||
| } | |||
| } | |||
| _train_model(); | |||
| _train_model(input_fn); | |||
| throw new NotImplementedException(""); | |||
| } | |||
| private int _load_global_step_from_checkpoint_dir(string checkpoint_dir) | |||
| { | |||
| var cp = tf.train.latest_checkpoint(checkpoint_dir); | |||
| // var cp = tf.train.latest_checkpoint(checkpoint_dir); | |||
| // should use NewCheckpointReader (not implemented) | |||
| var cp = tf.train.get_checkpoint_state(checkpoint_dir); | |||
| return 0; | |||
| return cp.AllModelCheckpointPaths.Count - 1; | |||
| } | |||
| private void _train_model() | |||
| private void _train_model(Action input_fn) | |||
| { | |||
| _train_model_default(); | |||
| _train_model_default(input_fn); | |||
| } | |||
| private void _train_model_default() | |||
| private void _train_model_default(Action input_fn) | |||
| { | |||
| using (var g = tf.Graph().as_default()) | |||
| { | |||
| var global_step_tensor = _create_and_assert_global_step(g); | |||
| } | |||
| } | |||
| private Tensor _create_and_assert_global_step(Graph graph) | |||
| { | |||
| var step = _create_global_step(graph); | |||
| Debug.Assert(step == tf.train.get_global_step(graph)); | |||
| Debug.Assert(step.dtype.is_integer()); | |||
| return step; | |||
| } | |||
| private RefVariable _create_global_step(Graph graph) | |||
| { | |||
| return tf.train.create_global_step(graph); | |||
| } | |||
| public void __init__() | |||
| { | |||
| throw new NotImplementedException(); | |||
| @@ -175,6 +175,10 @@ namespace Tensorflow | |||
| if (_nodes_by_name.ContainsKey(op_name)) | |||
| return _nodes_by_name[op_name].outputs[out_n]; | |||
| else | |||
| throw new KeyError($"The name {name} refers to a Tensor which does not " + | |||
| $"exist. The operation, {op_name}, does not exist in the " + | |||
| "graph."); | |||
| } | |||
| else if (!name.Contains(":") & allow_operation) | |||
| { | |||
| @@ -54,6 +54,8 @@ namespace Tensorflow | |||
| return _constant_if_small(0.0D, shape, dtype, name); | |||
| case TF_DataType.TF_FLOAT: | |||
| return _constant_if_small(0.0F, shape, dtype, name); | |||
| case TF_DataType.TF_INT64: | |||
| return _constant_if_small(0l, shape, dtype, name); | |||
| case TF_DataType.TF_INT32: | |||
| return _constant_if_small(0, shape, dtype, name); | |||
| case TF_DataType.TF_INT8: | |||
| @@ -35,5 +35,6 @@ | |||
| DtFloatRef = 101, // DT_FLOAT_REF | |||
| DtDoubleRef = 102, // DT_DOUBLE_REF | |||
| DtInt32Ref = 103, // DT_INT32_REF | |||
| DtInt64Ref = 109 // DT_INT64_REF | |||
| } | |||
| } | |||
| @@ -246,7 +246,8 @@ namespace Tensorflow | |||
| public static bool is_integer(this TF_DataType type) | |||
| { | |||
| return type == TF_DataType.TF_INT8 || type == TF_DataType.TF_INT16 || type == TF_DataType.TF_INT32 || type == TF_DataType.TF_INT64 || | |||
| type == TF_DataType.TF_UINT8 || type == TF_DataType.TF_UINT16 || type == TF_DataType.TF_UINT32 || type == TF_DataType.TF_UINT64; | |||
| type == TF_DataType.TF_UINT8 || type == TF_DataType.TF_UINT16 || type == TF_DataType.TF_UINT32 || type == TF_DataType.TF_UINT64 || | |||
| type == TF_DataType.DtInt32Ref || type == TF_DataType.DtInt64Ref; | |||
| } | |||
| public static bool is_floating(this TF_DataType type) | |||
| @@ -249,7 +249,9 @@ namespace Tensorflow | |||
| { | |||
| _maybe_initialize_trackable(); | |||
| v = variable_scope.default_variable_creator( | |||
| initial_value, name: name, trainable: false, | |||
| initial_value, | |||
| name: name, | |||
| trainable: false, | |||
| use_resource: resource_variable_ops.is_resource_variable( | |||
| colocate_with)); | |||
| @@ -174,8 +174,24 @@ namespace Tensorflow | |||
| var coord_checkpoint_filename = _GetCheckpointFilename(checkpoint_dir, latest_filename); | |||
| if (File.Exists(coord_checkpoint_filename)) | |||
| { | |||
| var file_content = File.ReadAllBytes(coord_checkpoint_filename); | |||
| var ckpt = CheckpointState.Parser.ParseFrom(file_content); | |||
| var file_content = File.ReadAllLines(coord_checkpoint_filename); | |||
| // https://github.com/protocolbuffers/protobuf/issues/6654 | |||
| // var ckpt = CheckpointState.Parser.ParseFrom(file_content); | |||
| var ckpt = new CheckpointState(); | |||
| var field = CheckpointState.Descriptor.FindFieldByName("model_checkpoint_path"); | |||
| ckpt.ModelCheckpointPath = file_content.FirstOrDefault(x => x.StartsWith(field.Name + ":")).Substring(field.Name.Length + 2); | |||
| // remove first and last quote. | |||
| ckpt.ModelCheckpointPath = ckpt.ModelCheckpointPath.Substring(1, ckpt.ModelCheckpointPath.Length - 2); | |||
| field = CheckpointState.Descriptor.FindFieldByName("all_model_checkpoint_paths"); | |||
| file_content.Where(x => x.StartsWith(field.Name + ":")) | |||
| .ToList() | |||
| .ForEach(x => | |||
| { | |||
| string value = x.Substring(field.Name.Length + 2); | |||
| ckpt.AllModelCheckpointPaths.Add(value.Substring(1, value.Length - 2)); | |||
| }); | |||
| if (string.IsNullOrEmpty(ckpt.ModelCheckpointPath)) | |||
| throw new ValueError($"Invalid checkpoint state loaded from {checkpoint_dir}"); | |||
| // For relative model_checkpoint_path and all_model_checkpoint_paths, | |||
| @@ -0,0 +1,51 @@ | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Text; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Train | |||
| { | |||
| public class TrainingUtil | |||
| { | |||
| public static RefVariable create_global_step(Graph graph) | |||
| { | |||
| graph = graph ?? ops.get_default_graph(); | |||
| if (get_global_step(graph) != null) | |||
| throw new ValueError("global_step already exists."); | |||
| // Create in proper graph and base name_scope. | |||
| var g = graph.as_default(); | |||
| g.name_scope(null); | |||
| var v = tf.get_variable(tf.GraphKeys.GLOBAL_STEP, new TensorShape(), dtype: dtypes.int64, | |||
| initializer: tf.zeros_initializer, | |||
| trainable: false, | |||
| aggregation: VariableAggregation.OnlyFirstReplica, | |||
| collections: new List<string> { tf.GraphKeys.GLOBAL_VARIABLES, tf.GraphKeys.GLOBAL_STEP }); | |||
| return v; | |||
| } | |||
| public static RefVariable get_global_step(Graph graph) | |||
| { | |||
| graph = graph ?? ops.get_default_graph(); | |||
| RefVariable global_step_tensor = null; | |||
| var global_step_tensors = graph.get_collection<RefVariable>(tf.GraphKeys.GLOBAL_STEP); | |||
| if (global_step_tensors.Count == 1) | |||
| { | |||
| global_step_tensor = global_step_tensors[0]; | |||
| } | |||
| else | |||
| { | |||
| try | |||
| { | |||
| global_step_tensor = graph.get_tensor_by_name("global_step:0"); | |||
| } | |||
| catch (KeyError) | |||
| { | |||
| return null; | |||
| } | |||
| } | |||
| return global_step_tensor; | |||
| } | |||
| } | |||
| } | |||
| @@ -14,6 +14,7 @@ | |||
| limitations under the License. | |||
| ******************************************************************************/ | |||
| using System.Collections.Generic; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow | |||
| @@ -50,6 +51,7 @@ namespace Tensorflow | |||
| TF_DataType dtype = TF_DataType.DtInvalid, | |||
| object initializer = null, // IInitializer or Tensor | |||
| bool? trainable = null, | |||
| List<string> collections = null, | |||
| bool? use_resource = null, | |||
| bool validate_shape = true, | |||
| VariableSynchronization synchronization = VariableSynchronization.Auto, | |||
| @@ -67,6 +69,7 @@ namespace Tensorflow | |||
| initializer: initializer, | |||
| reuse: resue, | |||
| trainable: trainable, | |||
| collections: collections, | |||
| synchronization: synchronization, | |||
| aggregation: aggregation); | |||
| }); | |||
| @@ -42,6 +42,7 @@ namespace Tensorflow | |||
| object initializer = null, // IInitializer or Tensor | |||
| bool? reuse = null, | |||
| bool? trainable = null, | |||
| List<string> collections = null, | |||
| bool validate_shape = true, | |||
| VariableSynchronization synchronization = VariableSynchronization.Auto, | |||
| VariableAggregation aggregation = VariableAggregation.None) | |||
| @@ -54,6 +55,7 @@ namespace Tensorflow | |||
| dtype: dtype, | |||
| initializer: initializer, | |||
| trainable: trainable, | |||
| collections: collections, | |||
| validate_shape: validate_shape, | |||
| synchronization: synchronization, | |||
| aggregation: aggregation); | |||
| @@ -64,6 +66,7 @@ namespace Tensorflow | |||
| TF_DataType dtype = TF_DataType.TF_FLOAT, | |||
| object initializer = null, | |||
| bool? trainable = null, | |||
| List<string> collections = null, | |||
| bool validate_shape = true, | |||
| VariableSynchronization synchronization = VariableSynchronization.Auto, | |||
| VariableAggregation aggregation = VariableAggregation.None) | |||
| @@ -77,6 +80,7 @@ namespace Tensorflow | |||
| dtype: dtype, | |||
| initializer: init, | |||
| trainable: trainable, | |||
| collections: collections, | |||
| validate_shape: validate_shape, | |||
| synchronization: synchronization, | |||
| aggregation: aggregation); | |||
| @@ -112,6 +116,7 @@ namespace Tensorflow | |||
| IInitializer initializer = null, | |||
| bool reuse = false, | |||
| bool? trainable = null, | |||
| List<string> collections = null, | |||
| bool validate_shape = false, | |||
| bool? use_resource = null, | |||
| VariableSynchronization synchronization = VariableSynchronization.Auto, | |||
| @@ -157,6 +162,7 @@ namespace Tensorflow | |||
| v = variable_scope.default_variable_creator(init_val, | |||
| name: name, | |||
| trainable: trainable, | |||
| collections: collections, | |||
| dtype: variable_dtype, | |||
| validate_shape: validate_shape, | |||
| synchronization: synchronization, | |||
| @@ -175,6 +175,7 @@ namespace Tensorflow | |||
| public static RefVariable default_variable_creator(object initial_value, | |||
| string name = null, | |||
| bool? trainable = null, | |||
| List<string> collections = null, | |||
| TF_DataType dtype = TF_DataType.DtInvalid, | |||
| bool validate_shape = false, | |||
| bool ? use_resource = null, | |||
| @@ -199,6 +200,7 @@ namespace Tensorflow | |||
| return new RefVariable(initial_value, | |||
| trainable: trainable.Value, | |||
| validate_shape: validate_shape, | |||
| collections: collections, | |||
| name: name, | |||
| dtype: dtype); | |||
| } | |||