Browse Source

#175

tags/v0.8.0
Oceania2018 6 years ago
parent
commit
8dfc3b7cad
13 changed files with 105 additions and 60 deletions
  1. +16
    -19
      src/TensorFlowNET.Core/Gradients/math_grad.py.cs
  2. +7
    -0
      src/TensorFlowNET.Core/Graphs/Graph.cs
  3. +37
    -0
      src/TensorFlowNET.Core/Operations/array_ops.py.cs
  4. +4
    -4
      src/TensorFlowNET.Core/Operations/math_ops.py.cs
  5. +0
    -25
      src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs
  6. +2
    -0
      src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs
  7. +6
    -1
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  8. +1
    -1
      src/TensorFlowNET.Core/Train/GradientDescentOptimizer.cs
  9. +9
    -2
      src/TensorFlowNET.Core/Train/Optimizer.cs
  10. +1
    -1
      src/TensorFlowNET.Core/Train/tf.optimizers.cs
  11. +2
    -0
      src/TensorFlowNET.Core/ops.GraphKeys.cs
  12. +5
    -0
      src/TensorFlowNET.Core/ops.py.cs
  13. +15
    -7
      test/TensorFlowNET.Examples/LinearRegression.cs

+ 16
- 19
src/TensorFlowNET.Core/Gradients/math_grad.py.cs View File

@@ -72,10 +72,7 @@ namespace Tensorflow


public static bool _ShapesFullySpecifiedAndEqual(Tensor x, Tensor y, Tensor grad) public static bool _ShapesFullySpecifiedAndEqual(Tensor x, Tensor y, Tensor grad)
{ {
if (x.NDims == 0 && y.NDims == 0 && grad.NDims == 0) return true;

return string.Join(",", x.shape).Equals(string.Join(",", y.shape)) &&
string.Join(",", x.shape).Equals(string.Join(",", grad.shape));
return x.NDims == y.NDims && y.NDims == grad.NDims && x.NDims > -1;
} }


public static (Tensor, Tensor) _SumGrad(Operation op, Tensor grad) public static (Tensor, Tensor) _SumGrad(Operation op, Tensor grad)
@@ -110,14 +107,15 @@ namespace Tensorflow
x = math_ops.conj(x); x = math_ops.conj(x);
y = math_ops.conj(y); y = math_ops.conj(y);


var realdiv1 = gen_math_ops.real_div(grad, y);
var reduce_sum1 = math_ops.reduce_sum(realdiv1, rx);
var realdiv2 = gen_math_ops.real_div(-x, y);
var realdiv3 = gen_math_ops.real_div(realdiv2, y);
var mul = grad * realdiv3;
var reduce_sum2 = math_ops.reduce_sum(mul, ry);
var realdiv1 = gen_math_ops.real_div(-x, y);
var realdiv2 = gen_math_ops.real_div(realdiv1, y);
var reduce_sum1 = math_ops.reduce_sum(grad * realdiv2, ry);
var reshape1 = gen_array_ops.reshape(reduce_sum1, sy);
var realdiv3 = gen_math_ops.real_div(grad, y);
var reduce_sum2 = math_ops.reduce_sum(realdiv3, rx);
var reshape2 = gen_array_ops.reshape(reduce_sum2, sx);


return (gen_array_ops.reshape(reduce_sum1, sx), gen_array_ops.reshape(reduce_sum2, sy));
return (reshape2, reshape1);
} }


public static (Tensor, Tensor) _PowGrad(Operation op, Tensor grad) public static (Tensor, Tensor) _PowGrad(Operation op, Tensor grad)
@@ -135,17 +133,16 @@ namespace Tensorflow
var gx = gen_array_ops.reshape(math_ops.reduce_sum(grad * y * gen_math_ops.pow(x, y - 1.0), rx), sx); var gx = gen_array_ops.reshape(math_ops.reduce_sum(grad * y * gen_math_ops.pow(x, y - 1.0), rx), sx);
Tensor log_x = null; Tensor log_x = null;
// Avoid false singularity at x = 0 // Avoid false singularity at x = 0
Tensor mask = null;
if (x.dtype.is_complex()) if (x.dtype.is_complex())
{
throw new NotImplementedException("x.dtype.is_complex()"); throw new NotImplementedException("x.dtype.is_complex()");
}
else else
{
var x1 = gen_array_ops.log(x);
var y1 = array_ops.zeros_like(x);
log_x = array_ops.where(x > 0.0, x1, y1);
}
mask = x > 0.0f;
var ones = array_ops.ones_like(x);
var safe_x = array_ops.where(mask, x, ones);
var x1 = gen_array_ops.log(safe_x);
var y1 = array_ops.zeros_like(x);
log_x = array_ops.where(mask, x1, y1);
var gy = gen_array_ops.reshape(math_ops.reduce_sum(grad * z * log_x, ry), sy); var gy = gen_array_ops.reshape(math_ops.reduce_sum(grad * z * log_x, ry), sy);


return (gx, gy); return (gx, gy);


+ 7
- 0
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -357,6 +357,13 @@ namespace Tensorflow
return _collections.ContainsKey(name) ? _collections[name] : null; return _collections.ContainsKey(name) ? _collections[name] : null;
} }


public object get_collection_ref(string name)
{
if (!_collections.ContainsKey(name))
_collections[name] = new List<object>();
return _collections[name];
}

public void Dispose() public void Dispose()
{ {
c_api.TF_DeleteGraph(_handle); c_api.TF_DeleteGraph(_handle);


+ 37
- 0
src/TensorFlowNET.Core/Operations/array_ops.py.cs View File

@@ -55,6 +55,43 @@ namespace Tensorflow
return math_ops.rank_internal(input, name, optimize: true); return math_ops.rank_internal(input, name, optimize: true);
} }


/// <summary>
/// Creates a tensor with all elements set to 1.
/// </summary>
/// <param name="tensor"></param>
/// <param name="dtype"></param>
/// <param name="name"></param>
/// <param name="optimize"></param>
/// <returns></returns>
public static Tensor ones_like<T>(T tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = "", bool optimize = true)
=> ones_like_impl(tensor, dtype, name, optimize);

private static Tensor ones_like_impl<T>(T tensor, TF_DataType dtype, string name, bool optimize = true)
{
return Python.with<ops.name_scope, Tensor>(new ops.name_scope(name, "ones_like", new { tensor }), scope =>
{
name = scope;
var tensor1 = ops.convert_to_tensor(tensor, name: "tensor");
var ones_shape = shape_internal(tensor1, optimize: optimize);
if (dtype == TF_DataType.DtInvalid)
dtype = tensor1.dtype;
var ret = ones(ones_shape, dtype: dtype, name: name);
ret.shape = tensor1.shape;
return ret;
});
}

public static Tensor ones(Tensor shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = "")
{
dtype = dtype.as_base_dtype();
return Python.with<ops.name_scope, Tensor>(new ops.name_scope(name, "ones", new { shape }), scope =>
{
name = scope;
var output = gen_array_ops.fill(shape, constant_op.constant(1.0f, dtype: dtype), name: name);
return output;
});
}

public static Tensor where(Tensor condition, Tensor x = null, Tensor y = null, string name = "") public static Tensor where(Tensor condition, Tensor x = null, Tensor y = null, string name = "")
{ {
if( x == null && y == null) if( x == null && y == null)


+ 4
- 4
src/TensorFlowNET.Core/Operations/math_ops.py.cs View File

@@ -111,7 +111,7 @@ namespace Tensorflow
if (delta == null) if (delta == null)
delta = 1; delta = 1;


return Python.with<ops.name_scope, Tensor>(new ops.name_scope(name, "Range", new object[] { start, limit, delta }), scope =>
return with<ops.name_scope, Tensor>(new ops.name_scope(name, "Range", new object[] { start, limit, delta }), scope =>
{ {
name = scope; name = scope;
var start1 = ops.convert_to_tensor(start, name: "start"); var start1 = ops.convert_to_tensor(start, name: "start");
@@ -124,15 +124,15 @@ namespace Tensorflow


public static Tensor floordiv(Tensor x, Tensor y, string name = "") public static Tensor floordiv(Tensor x, Tensor y, string name = "")
{ {
return Python.with<ops.name_scope, Tensor>(new ops.name_scope(name, "floordiv", new object[] { }), scope =>
return with<ops.name_scope, Tensor>(new ops.name_scope("", "floordiv", new { x, y }), scope =>
{ {
return gen_math_ops.floor_div(x, y, name);
return gen_math_ops.floor_div(x, y, scope);
}); });
} }


public static Tensor rank_internal(Tensor input, string name = "", bool optimize = true) public static Tensor rank_internal(Tensor input, string name = "", bool optimize = true)
{ {
return Python.with<ops.name_scope, Tensor>(new ops.name_scope(name, "Rank", new List<Tensor> { input }), scope =>
return with<ops.name_scope, Tensor>(new ops.name_scope(name, "Rank", new List<Tensor> { input }), scope =>
{ {
name = scope; name = scope;
var input_tensor = ops.convert_to_tensor(input); var input_tensor = ops.convert_to_tensor(input);


+ 0
- 25
src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs View File

@@ -63,31 +63,6 @@ namespace Tensorflow
break; break;
case "Single": case "Single":
Marshal.Copy(nd1.Data<float>(), 0, dotHandle, nd.size); Marshal.Copy(nd1.Data<float>(), 0, dotHandle, nd.size);
/*if (nd.size > 1)
{
var bb = nd.Data<byte>();
var bytes = Marshal.AllocHGlobal(bb.Length);
Marshal.Copy(bb, 0, bytes, bb.Length);
ulong bytes_len = c_api.TF_StringEncodedSize((ulong)bb.Length);
var dataTypeByte = ToTFDataType(nd.dtype);
// shape
var dims2 = nd.shape.Select(x => (long)x).ToArray();

var tfHandle2 = c_api.TF_AllocateTensor(dataTypeByte,
dims2,
nd.ndim,
bytes_len + sizeof(Int64));

dotHandle = c_api.TF_TensorData(tfHandle2);
Marshal.WriteInt64(dotHandle, 0);
c_api.TF_StringEncode(bytes, (ulong)bb.Length, dotHandle + sizeof(Int64), bytes_len, status);
return tfHandle2;
}
else
{
Marshal.Copy(nd1.Data<float>(), 0, dotHandle, nd.size);
}*/
break; break;
case "Double": case "Double":
Marshal.Copy(nd1.Data<double>(), 0, dotHandle, nd.size); Marshal.Copy(nd1.Data<double>(), 0, dotHandle, nd.size);


+ 2
- 0
src/TensorFlowNET.Core/Tensors/Tensor.Operators.cs View File

@@ -27,8 +27,10 @@ namespace Tensorflow
public static Tensor operator %(Tensor x, Tensor y) => BinaryOpWrapper("mod", x, y); public static Tensor operator %(Tensor x, Tensor y) => BinaryOpWrapper("mod", x, y);


public static Tensor operator >(Tensor x, int y) => gen_array_ops.greater(x, y); public static Tensor operator >(Tensor x, int y) => gen_array_ops.greater(x, y);
public static Tensor operator >(Tensor x, float y) => gen_array_ops.greater(x, y);
public static Tensor operator >(Tensor x, double y) => gen_array_ops.greater(x, y); public static Tensor operator >(Tensor x, double y) => gen_array_ops.greater(x, y);
public static Tensor operator <(Tensor x, int y) => gen_array_ops.less(x, y); public static Tensor operator <(Tensor x, int y) => gen_array_ops.less(x, y);
public static Tensor operator <(Tensor x, float y) => gen_array_ops.less(x, y);
public static Tensor operator <(Tensor x, double y) => gen_array_ops.less(x, y); public static Tensor operator <(Tensor x, double y) => gen_array_ops.less(x, y);


private static Tensor BinaryOpWrapper<Tx, Ty>(string name, Tx x, Ty y) private static Tensor BinaryOpWrapper<Tx, Ty>(string name, Tx x, Ty y)


+ 6
- 1
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -68,7 +68,12 @@ namespace Tensorflow
c_api.TF_GraphSetTensorShape(this.Graph, this._as_tf_output(), value, value.Length, status); c_api.TF_GraphSetTensorShape(this.Graph, this._as_tf_output(), value, value.Length, status);
} }
} }

public int[] _shape_tuple()
{
return null;
}

/// <summary> /// <summary>
/// number of dimensions /// number of dimensions
/// 0 Scalar (magnitude only) /// 0 Scalar (magnitude only)


+ 1
- 1
src/TensorFlowNET.Core/Train/GradientDescentOptimizer.cs View File

@@ -6,7 +6,7 @@ namespace Tensorflow
{ {
public class GradientDescentOptimizer : Optimizer public class GradientDescentOptimizer : Optimizer
{ {
public GradientDescentOptimizer(double learning_rate, bool use_locking = false, string name = "GradientDescent")
public GradientDescentOptimizer(float learning_rate, bool use_locking = false, string name = "GradientDescent")
: base(learning_rate, use_locking, name) : base(learning_rate, use_locking, name)
{ {
LearningRate = learning_rate; LearningRate = learning_rate;


+ 9
- 2
src/TensorFlowNET.Core/Train/Optimizer.cs View File

@@ -20,14 +20,14 @@ namespace Tensorflow
public static int GATE_GRAPH = 2; public static int GATE_GRAPH = 2;


public string Name { get; set; } public string Name { get; set; }
public double LearningRate { get; set; }
public float LearningRate { get; set; }
public Tensor LearningRateTensor { get; set; } public Tensor LearningRateTensor { get; set; }
public bool _use_locking; public bool _use_locking;
public Dictionary<string, object> _slots; public Dictionary<string, object> _slots;
public Dictionary<string, object> _non_slot_dict; public Dictionary<string, object> _non_slot_dict;
public Dictionary<string, object> _deferred_slot_restorations; public Dictionary<string, object> _deferred_slot_restorations;


public Optimizer(double learning_rate, bool use_locking, string name = "")
public Optimizer(float learning_rate, bool use_locking, string name = "")
{ {
if (String.IsNullOrEmpty(name)) if (String.IsNullOrEmpty(name))
throw new NotImplementedException("Must specify the optimizer name"); throw new NotImplementedException("Must specify the optimizer name");
@@ -114,6 +114,13 @@ namespace Tensorflow


} }


if (!tf.context.executing_eagerly())
{
var train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP) as List<object>;
if (!train_op.Contains(apply_updates))
train_op.Add(apply_updates);
}

return apply_updates; return apply_updates;
}); });
} }


+ 1
- 1
src/TensorFlowNET.Core/Train/tf.optimizers.cs View File

@@ -9,7 +9,7 @@ namespace Tensorflow
{ {
public static class train public static class train
{ {
public static Optimizer GradientDescentOptimizer(double learning_rate) => new GradientDescentOptimizer(learning_rate);
public static Optimizer GradientDescentOptimizer(float learning_rate) => new GradientDescentOptimizer(learning_rate);


public static Saver Saver() => new Saver(); public static Saver Saver() => new Saver();




+ 2
- 0
src/TensorFlowNET.Core/ops.GraphKeys.cs View File

@@ -33,6 +33,8 @@ namespace Tensorflow
/// </summary> /// </summary>
public static string GLOBAL_VARIABLES = "variables"; public static string GLOBAL_VARIABLES = "variables";


public static string TRAIN_OP = "train_op";

public static string[] _VARIABLE_COLLECTIONS = new string[] { "variables", "trainable_variables" }; public static string[] _VARIABLE_COLLECTIONS = new string[] { "variables", "trainable_variables" };
/// <summary> /// <summary>
/// Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing. /// Key to collect BaseSaverBuilder.SaveableObject instances for checkpointing.


+ 5
- 0
src/TensorFlowNET.Core/ops.py.cs View File

@@ -45,6 +45,11 @@ namespace Tensorflow
return get_default_graph().get_collection(key, scope); return get_default_graph().get_collection(key, scope);
} }


public static object get_collection_ref(string key)
{
return get_default_graph().get_collection_ref(key);
}

private static Graph default_graph; private static Graph default_graph;
public static Graph get_default_graph() public static Graph get_default_graph()
{ {


+ 15
- 7
test/TensorFlowNET.Examples/LinearRegression.cs View File

@@ -21,7 +21,7 @@ namespace TensorFlowNET.Examples
// Parameters // Parameters
float learning_rate = 0.01f; float learning_rate = 0.01f;
int training_epochs = 1000; int training_epochs = 1000;
int display_step = 1;
int display_step = 10;


// Training Data // Training Data
var train_X = np.array(3.3f, 4.4f, 5.5f, 6.71f, 6.93f, 4.168f, 9.779f, 6.182f, 7.59f, 2.167f, var train_X = np.array(3.3f, 4.4f, 5.5f, 6.71f, 6.93f, 4.168f, 9.779f, 6.182f, 7.59f, 2.167f,
@@ -29,9 +29,9 @@ namespace TensorFlowNET.Examples
var train_Y = np.array(1.7f, 2.76f, 2.09f, 3.19f, 1.694f, 1.573f, 3.366f, 2.596f, 2.53f, 1.221f, var train_Y = np.array(1.7f, 2.76f, 2.09f, 3.19f, 1.694f, 1.573f, 3.366f, 2.596f, 2.53f, 1.221f,
2.827f, 3.465f, 1.65f, 2.904f, 2.42f, 2.94f, 1.3f); 2.827f, 3.465f, 1.65f, 2.904f, 2.42f, 2.94f, 1.3f);
var n_samples = train_X.shape[0]; var n_samples = train_X.shape[0];
// tf Graph Input // tf Graph Input
var X = tf.placeholder(tf.float32);
/*var X = tf.placeholder(tf.float32);
var Y = tf.placeholder(tf.float32); var Y = tf.placeholder(tf.float32);


// Set model weights // Set model weights
@@ -55,7 +55,14 @@ namespace TensorFlowNET.Examples
// radient descent // radient descent
// Note, minimize() knows to modify W and b because Variable objects are trainable=True by default // Note, minimize() knows to modify W and b because Variable objects are trainable=True by default
var grad = tf.train.GradientDescentOptimizer(learning_rate); var grad = tf.train.GradientDescentOptimizer(learning_rate);
var optimizer = grad.minimize(cost);
var optimizer = grad.minimize(cost);*/

var new_saver = tf.train.import_meta_graph("save_model.meta", import_scope: "import");

var X = graph.OperationByName("Placeholder");
var Y = graph.OperationByName("Placeholder_1");
var W = graph.OperationByName("weight");
var optimizer = graph.OperationByName("GradientDescent");


// Initialize the variables (i.e. assign their default value) // Initialize the variables (i.e. assign their default value)
var init = tf.global_variables_initializer(); var init = tf.global_variables_initializer();
@@ -71,14 +78,15 @@ namespace TensorFlowNET.Examples
{ {
foreach (var (x, y) in zip<float>(train_X, train_Y)) foreach (var (x, y) in zip<float>(train_X, train_Y))
{ {
var w = sess.run(W);
sess.run(optimizer, sess.run(optimizer,
new FeedItem(X, x), new FeedItem(X, x),
new FeedItem(Y, y)); new FeedItem(Y, y));
var w = sess.run(W);
w = sess.run(W);
} }


// Display logs per epoch step // Display logs per epoch step
if ((epoch + 1) % display_step == 0)
/*if ((epoch + 1) % display_step == 0)
{ {
var c = sess.run(cost, var c = sess.run(cost,
new FeedItem(X, train_X), new FeedItem(X, train_X),
@@ -86,7 +94,7 @@ namespace TensorFlowNET.Examples
var rW = sess.run(W); var rW = sess.run(W);
Console.WriteLine($"Epoch: {epoch + 1} cost={c} " + Console.WriteLine($"Epoch: {epoch + 1} cost={c} " +
$"W={rW} b={sess.run(b)}"); $"W={rW} b={sess.run(b)}");
}
}*/
} }


Console.WriteLine("Optimization Finished!"); Console.WriteLine("Optimization Finished!");


Loading…
Cancel
Save