Browse Source

change constant creation method.

tags/v0.8.0
haiping008 6 years ago
parent
commit
3d7ff13d2c
12 changed files with 240 additions and 28 deletions
  1. +12
    -0
      src/TensorFlowNET.Core/Framework/tf.ops.cs
  2. +10
    -0
      src/TensorFlowNET.Core/Operations/Losses/losses_impl.py.cs
  3. +3
    -2
      src/TensorFlowNET.Core/Operations/array_ops.py.cs
  4. +24
    -16
      src/TensorFlowNET.Core/Sessions/BaseSession.cs
  5. +1
    -1
      src/TensorFlowNET.Core/Tensors/constant_op.cs
  6. +19
    -1
      src/TensorFlowNET.Core/Tensors/tf.constant.cs
  7. +34
    -1
      src/TensorFlowNET.Core/Train/Saving/Saver.cs
  8. +19
    -0
      src/TensorFlowNET.Core/Train/Saving/checkpoint_management.py.cs
  9. +22
    -3
      src/TensorFlowNET.Core/Train/Saving/saver.py.cs
  10. +29
    -0
      test/TensorFlowNET.Examples/MetaGraph.cs
  11. +0
    -4
      test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj
  12. +67
    -0
      test/TensorFlowNET.Examples/python/meta_graph.py

+ 12
- 0
src/TensorFlowNET.Core/Framework/tf.ops.cs View File

@@ -0,0 +1,12 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow
{
public static partial class tf
{
public static object get_collection(string key, string scope = "") => get_default_graph()
.get_collection(key, scope: scope);
}
}

+ 10
- 0
src/TensorFlowNET.Core/Operations/Losses/losses_impl.py.cs View File

@@ -0,0 +1,10 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace Tensorflow.Operations.Losses
{
class losses_impl
{
}
}

+ 3
- 2
src/TensorFlowNET.Core/Operations/array_ops.py.cs View File

@@ -82,9 +82,9 @@ namespace Tensorflow
return shape_internal(input, name, optimize: true, out_type: out_type);
}

public static Tensor size(Tensor input, string name = "", TF_DataType out_type = TF_DataType.TF_INT32)
public static Tensor size(Tensor input, string name = "", bool optimize = true, TF_DataType out_type = TF_DataType.TF_INT32)
{
return size_internal(input, name, optimize: true, out_type: out_type);
return size_internal(input, name, optimize: optimize, out_type: out_type);
}

private static Tensor shape_internal(Tensor input, string name = "", bool optimize = true, TF_DataType out_type = TF_DataType.TF_INT32)
@@ -132,6 +132,7 @@ namespace Tensorflow
else
{
// result = gen_array_ops.shape();
throw new NotImplementedException("array_ops.size_internal");
}

return null;


+ 24
- 16
src/TensorFlowNET.Core/Sessions/BaseSession.cs View File

@@ -46,28 +46,36 @@ namespace Tensorflow
var feed_dict_tensor = new Dictionary<object, object>();
var feed_map = new Dictionary<object, object>();

Func<FeedItem, IEnumerable<(object, object)>> feed_fn = (item) =>
{
return new (object, object)[] { (item.Key, item.Value) };
};

// Validate and process feed_dict.
if (feed_dict != null)
{
foreach(var subfeed in feed_dict)
foreach (var feed in feed_dict)
{
var subfeed_t = _graph.as_graph_element(subfeed.Key, allow_tensor: true, allow_operation: false);
var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype();
switch(subfeed.Value)
foreach (var (subfeed, subfeed_val) in feed_fn(feed))
{
case float floatVal:
feed_dict_tensor[subfeed_t] = (NDArray)floatVal;
break;
case int intVal:
feed_dict_tensor[subfeed_t] = (NDArray)intVal;
break;
case string str:
feed_dict_tensor[subfeed_t] = (NDArray)str;
break;
default:
throw new NotImplementedException("_run subfeed");
var subfeed_t = _graph.as_graph_element(subfeed, allow_tensor: true, allow_operation: false);
var subfeed_dtype = subfeed_t.dtype.as_numpy_datatype();
switch (subfeed_val)
{
case float floatVal:
feed_dict_tensor[subfeed_t] = (NDArray)floatVal;
break;
case int intVal:
feed_dict_tensor[subfeed_t] = (NDArray)intVal;
break;
case string str:
feed_dict_tensor[subfeed_t] = (NDArray)str;
break;
default:
throw new NotImplementedException("_run subfeed");
}
feed_map[subfeed_t.name] = (subfeed_t, subfeed_val);
}
feed_map[subfeed_t.name] = (subfeed_t, subfeed.Value);
}
}



+ 1
- 1
src/TensorFlowNET.Core/Tensors/constant_op.cs View File

@@ -24,7 +24,7 @@ namespace Tensorflow
return _constant_impl(value, dtype, shape, name, verify_shape: false, allow_broadcast: true);
}

private static Tensor _constant_impl(object value, TF_DataType dtype, int[] shape, string name, bool verify_shape, bool allow_broadcast)
public static Tensor _constant_impl(object value, TF_DataType dtype, int[] shape, string name, bool verify_shape, bool allow_broadcast)
{
if (tf.context.executing_eagerly())
{


+ 19
- 1
src/TensorFlowNET.Core/Tensors/tf.constant.cs View File

@@ -7,8 +7,26 @@ namespace Tensorflow
{
public static partial class tf
{
public static Tensor constant(NDArray nd, string name = "Const") => constant_op.constant(nd, name: name);
// public static Tensor constant(NDArray nd, string name = "Const") => constant_op.constant(nd, name: name);

public static Tensor constant(object value,
TF_DataType dtype = TF_DataType.DtInvalid,
int[] shape = null,
string name = "Const",
bool verify_shape = false) => constant_op._constant_impl(value,
dtype,
shape,
name,
verify_shape: verify_shape,
allow_broadcast: false);

public static Tensor zeros(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = "") => array_ops.zeros(shape, dtype, name);

public static Tensor size(Tensor input,
string name = "",
TF_DataType out_type = TF_DataType.TF_INT32) => array_ops.size(input,
name,
optimize: true,
out_type: out_type);
}
}

+ 34
- 1
src/TensorFlowNET.Core/Train/Saving/Saver.cs View File

@@ -55,6 +55,7 @@ namespace Tensorflow
_keep_checkpoint_every_n_hours = keep_checkpoint_every_n_hours;
_name = name;
_restore_sequentially = restore_sequentially;
_saver_def = saver_def;
_builder = builder;
_is_built = false;
_allow_empty = allow_empty;
@@ -122,7 +123,7 @@ namespace Tensorflow
}
else if (_saver_def != null && !string.IsNullOrEmpty(_name))
{
throw new NotImplementedException("");
throw new NotImplementedException("Saver._build");
}

_check_saver_def();
@@ -200,6 +201,38 @@ namespace Tensorflow
return saver._import_meta_graph_with_return_elements(meta_graph_or_file, clear_devices, import_scope);
}

/// <summary>
/// Restores previously saved variables.
///
/// This method runs the ops added by the constructor for restoring variables.
/// It requires a session in which the graph was launched. The variables to
/// restore do not have to have been initialized, as restoring is itself a way
/// to initialize variables.
/// </summary>
/// <param name="sess">A `Session` to use to restore the parameters. None in eager mode.</param>
/// <param name="save_path">Path where parameters were previously saved.</param>
public void restore(Session sess, string save_path)
{
if (_is_empty)
return;

if (string.IsNullOrEmpty(save_path))
throw new ValueError("Can't load save_path when it is None.");

if (!checkpoint_management.checkpoint_exists(save_path))
throw new ValueError($"The passed save_path is not a valid checkpoint: {save_path}");

Console.WriteLine($"Restoring parameters from {save_path}");

if (tf.context.executing_eagerly())
;
else
sess.run(_saver_def.RestoreOpName, new FeedItem[]
{
new FeedItem(_saver_def.FilenameTensorName, save_path)
});
}

/// <summary>
/// Writes `MetaGraphDef` to save_path/filename.
/// </summary>


+ 19
- 0
src/TensorFlowNET.Core/Train/Saving/checkpoint_management.py.cs View File

@@ -3,6 +3,7 @@ using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using static Tensorflow.SaverDef.Types;

namespace Tensorflow
{
@@ -105,5 +106,23 @@ namespace Tensorflow
string suffixed_filename = basename + "." + meta_graph_suffix;
return suffixed_filename;
}

public static bool checkpoint_exists(string checkpoint_prefix)
{
string pathname = _prefix_to_checkpoint_path(checkpoint_prefix, CheckpointFormatVersion.V2);
if (File.Exists(pathname))
return true;
else if (File.Exists(checkpoint_prefix))
return true;
else
return false;
}

private static string _prefix_to_checkpoint_path(string prefix, CheckpointFormatVersion format_version)
{
if (format_version == CheckpointFormatVersion.V2)
return prefix + ".index";
return prefix;
}
}
}

+ 22
- 3
src/TensorFlowNET.Core/Train/Saving/saver.py.cs View File

@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace Tensorflow
@@ -13,25 +14,43 @@ namespace Tensorflow
{
var meta_graph_def = meta_graph.read_meta_graph_file(meta_graph_or_file);

var imported_vars = meta_graph.import_scoped_meta_graph_with_return_elements(
var meta = meta_graph.import_scoped_meta_graph_with_return_elements(
meta_graph_def,
clear_devices: clear_devices,
import_scope: import_scope,
return_elements: return_elements);

var (imported_vars, imported_return_elements) = meta;

var saver = _create_saver_from_imported_meta_graph(
meta_graph_def, import_scope, imported_vars);

return (saver, null);
}

/// <summary>
/// Return a saver for restoring variable values to an imported MetaGraph.
/// </summary>
/// <param name="meta_graph_def"></param>
/// <param name="import_scope"></param>
/// <param name="imported_vars"></param>
/// <returns></returns>
public static Saver _create_saver_from_imported_meta_graph(MetaGraphDef meta_graph_def,
string import_scope,
(Dictionary<string, RefVariable>, ITensorOrOperation[]) imported_vars)
Dictionary<string, RefVariable> imported_vars)
{
if(meta_graph_def.SaverDef != null)
{
throw new NotImplementedException("_create_saver_from_imported_meta_graph");
// Infer the scope that is prepended by `import_scoped_meta_graph`.
string scope = import_scope;
var var_names = imported_vars.Keys.ToArray();
if(var_names.Length > 0)
{
var sample_key = var_names[0];
var sample_var = imported_vars[sample_key];
scope = string.Join("", sample_var.name.Skip(sample_key.Length));
}
return new Saver(saver_def: meta_graph_def.SaverDef, name: scope);
}
else
{


+ 29
- 0
test/TensorFlowNET.Examples/MetaGraph.cs View File

@@ -0,0 +1,29 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow;

namespace TensorFlowNET.Examples
{
public class MetaGraph : Python, IExample
{
public void Run()
{
ImportMetaGraph("my-save-dir/");
}

private void ImportMetaGraph(string dir)
{
with<Session>(tf.Session(), sess =>
{
var new_saver = tf.train.import_meta_graph(dir + "my-model-10000.meta");
new_saver.restore(sess, dir + "my-model-10000");
var labels = tf.constant(0, dtype: tf.int32, shape: new int[] { 100 }, name: "labels");
var batch_size = tf.size(labels);
var logits = (tf.get_collection("logits") as List<ITensorOrOperation>)[0];
var loss = tf.losses.sparse_softmax_cross_entropy(labels = labels,
logits = logits);
});
}
}
}

+ 0
- 4
test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj View File

@@ -14,8 +14,4 @@
<ProjectReference Include="..\..\src\TensorFlowNET.Core\TensorFlowNET.Core.csproj" />
</ItemGroup>

<ItemGroup>
<Folder Include="python\" />
</ItemGroup>

</Project>

+ 67
- 0
test/TensorFlowNET.Examples/python/meta_graph.py View File

@@ -0,0 +1,67 @@

import tensorflow as tf
import math

# Creates an inference graph.
# Hidden 1
images = tf.constant(1.2, tf.float32, shape=[100, 28])
with tf.name_scope("hidden1"):
weights = tf.Variable(
tf.truncated_normal([28, 128],
stddev=1.0 / math.sqrt(float(28))),
name="weights")
biases = tf.Variable(tf.zeros([128]),
name="biases")
hidden1 = tf.nn.relu(tf.matmul(images, weights) + biases)
# Hidden 2
with tf.name_scope("hidden2"):
weights = tf.Variable(
tf.truncated_normal([128, 32],
stddev=1.0 / math.sqrt(float(128))),
name="weights")
biases = tf.Variable(tf.zeros([32]),
name="biases")
hidden2 = tf.nn.relu(tf.matmul(hidden1, weights) + biases)
# Linear
with tf.name_scope("softmax_linear"):
weights = tf.Variable(
tf.truncated_normal([32, 10],
stddev=1.0 / math.sqrt(float(32))),
name="weights")
biases = tf.Variable(tf.zeros([10]),
name="biases")
logits = tf.matmul(hidden2, weights) + biases
tf.add_to_collection("logits", logits)

init_all_op = tf.global_variables_initializer()

with tf.Session() as sess:
# Initializes all the variables.
sess.run(init_all_op)
# Runs to logit.
sess.run(logits)
# Creates a saver.
saver0 = tf.train.Saver()
saver0.save(sess, 'my-save-dir/my-model-10000')
# Generates MetaGraphDef.
saver0.export_meta_graph('my-save-dir/my-model-10000.meta')


# Then later import it and extend it to a training graph.
with tf.Session() as sess:
new_saver = tf.train.import_meta_graph('my-save-dir/my-model-10000.meta')
new_saver.restore(sess, 'my-save-dir/my-model-10000')
# Addes loss and train.
labels = tf.constant(0, tf.int32, shape=[100], name="labels")
batch_size = tf.size(labels)
logits = tf.get_collection("logits")[0]
loss = tf.losses.sparse_softmax_cross_entropy(labels=labels,
logits=logits)

tf.summary.scalar('loss', loss)
# Creates the gradient descent optimizer with the given learning rate.
optimizer = tf.train.GradientDescentOptimizer(0.01)

# Runs train_op.
train_op = optimizer.minimize(loss)
sess.run(train_op)

Loading…
Cancel
Save