diff --git a/README.md b/README.md
index 545cea13..8a9df45c 100644
--- a/README.md
+++ b/README.md
@@ -150,6 +150,8 @@ Example runner will download all the required files like training data and model
* [Named Entity Recognition](test/TensorFlowNET.Examples/TextProcess/NER)
* [Transfer Learning for Image Classification in InceptionV3](test/TensorFlowNET.Examples/ImageProcess/RetrainImageClassifier.cs)
+For more troubleshooting tips on running the examples, see [this guide](tensorflowlib/README.md).
+
### Contribute:
Feel like contributing to one of the hottest projects in the Machine Learning field? Want to know how Tensorflow magically creates the computational graph? We appreciate every contribution, however small. There are tasks for novices and experts alike; if everyone tackles only a small task, the sum of contributions will be huge.
diff --git a/docs/_config.yml b/docs/_config.yml
new file mode 100644
index 00000000..c4192631
--- /dev/null
+++ b/docs/_config.yml
@@ -0,0 +1 @@
+theme: jekyll-theme-cayman
\ No newline at end of file
diff --git a/graph/InceptionV3.meta b/graph/InceptionV3.meta
index 0ded6221..2a11b082 100644
Binary files a/graph/InceptionV3.meta and b/graph/InceptionV3.meta differ
diff --git a/src/TensorFlowNET.Core/APIs/tf.array.cs b/src/TensorFlowNET.Core/APIs/tf.array.cs
index cbcfed28..4e733f18 100644
--- a/src/TensorFlowNET.Core/APIs/tf.array.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.array.cs
@@ -36,6 +36,16 @@ namespace Tensorflow
public static Tensor expand_dims(Tensor input, int axis = -1, string name = null, int dim = -1)
=> array_ops.expand_dims(input, axis, name, dim);
+ /// <summary>
+ /// Creates a tensor filled with a scalar value.
+ /// </summary>
+ /// <param name="dims"></param>
+ /// <param name="value"></param>
+ /// <param name="name"></param>
+ /// <returns></returns>
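+ /// <remarks>For example, filling dims [2, 3] with value 9 produces a 2x3 tensor where every element is 9.</remarks>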
+ public static Tensor fill<T>(Tensor dims, T value, string name = null)
+ => gen_array_ops.fill(dims, value, name: name);
+
/// <summary>
/// Return the elements, either from `x` or `y`, depending on the `condition`.
/// </summary>
diff --git a/src/TensorFlowNET.Core/APIs/tf.gradients.cs b/src/TensorFlowNET.Core/APIs/tf.gradients.cs
index 77491d55..0961288c 100644
--- a/src/TensorFlowNET.Core/APIs/tf.gradients.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.gradients.cs
@@ -6,7 +6,7 @@ namespace Tensorflow
{
public static partial class tf
{
- public static object gradients(Tensor[] ys,
+ public static Tensor[] gradients(Tensor[] ys,
Tensor[] xs,
Tensor[] grad_ys = null,
string name = "gradients",
@@ -15,7 +15,7 @@ namespace Tensorflow
int? aggregation_method = null,
Tensor[] stop_gradients = null)
{
- return gradients_impl._GradientsHelper(ys,
+ return gradients_util._GradientsHelper(ys,
xs,
grad_ys,
name,
@@ -33,7 +33,7 @@ namespace Tensorflow
int? aggregation_method = null,
Tensor[] stop_gradients = null)
{
- return gradients_impl._GradientsHelper(new Tensor[] { ys },
+ return gradients_util._GradientsHelper(new Tensor[] { ys },
xs,
grad_ys,
name,
@@ -41,5 +41,23 @@ namespace Tensorflow
gate_gradients,
stop_gradients: stop_gradients);
}
+
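+ /// <summary>
+ /// Computes the gradient of a single `ys` tensor with respect to a single
+ /// `xs` tensor by wrapping each in a one-element array.
+ /// For example, `var dy_dx = tf.gradients(y, x)[0];`.
+ /// </summary>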
+ public static Tensor[] gradients(Tensor ys,
+ Tensor xs,
+ Tensor[] grad_ys = null,
+ string name = "gradients",
+ bool colocate_gradients_with_ops = false,
+ bool gate_gradients = false,
+ int? aggregation_method = null,
+ Tensor[] stop_gradients = null)
+ {
+ return gradients_util._GradientsHelper(new Tensor[] { ys },
+ new Tensor[] { xs },
+ grad_ys,
+ name,
+ colocate_gradients_with_ops,
+ gate_gradients,
+ stop_gradients: stop_gradients);
+ }
}
}
diff --git a/src/TensorFlowNET.Core/APIs/tf.layers.cs b/src/TensorFlowNET.Core/APIs/tf.layers.cs
index faf0d089..ea35e869 100644
--- a/src/TensorFlowNET.Core/APIs/tf.layers.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.layers.cs
@@ -142,6 +142,7 @@ namespace Tensorflow
var layer = new Dense(units, activation,
use_bias: use_bias,
+ bias_initializer: bias_initializer,
kernel_initializer: kernel_initializer);
return layer.apply(inputs);
diff --git a/src/TensorFlowNET.Core/APIs/tf.math.cs b/src/TensorFlowNET.Core/APIs/tf.math.cs
index bad41103..a8ec223a 100644
--- a/src/TensorFlowNET.Core/APIs/tf.math.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.math.cs
@@ -257,6 +257,16 @@ namespace Tensorflow
public static Tensor negative(Tensor x, string name = null)
=> gen_math_ops.neg(x, name);
+ /// <summary>
+ /// Divides x / y elementwise (using Python 2 division operator semantics).
+ /// </summary>
+ /// <param name="x"></param>
+ /// <param name="y"></param>
+ /// <param name="name"></param>
+ /// <returns></returns>
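+ /// <remarks>For integer inputs this is floor division; use `divide` for true division.</remarks>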
+ public static Tensor div(Tensor x, Tensor y, string name = null)
+ => math_ops.div(x, y, name: name);
+
public static Tensor divide<T>(Tensor x, T[] y, string name = null) where T : struct
=> x / ops.convert_to_tensor(y, dtype: x.dtype.as_base_dtype(), name: "y");
diff --git a/src/TensorFlowNET.Core/APIs/tf.variable.cs b/src/TensorFlowNET.Core/APIs/tf.variable.cs
index 266d5799..a1b3e1d8 100644
--- a/src/TensorFlowNET.Core/APIs/tf.variable.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.variable.cs
@@ -23,6 +23,8 @@ namespace Tensorflow
TF_DataType dtype = TF_DataType.DtInvalid,
object initializer = null, // IInitializer or Tensor
bool? trainable = null,
+ bool? use_resource = null,
+ bool validate_shape = true,
VariableSynchronization synchronization = VariableSynchronization.Auto,
VariableAggregation aggregation = VariableAggregation.None)
{
@@ -32,6 +34,8 @@ namespace Tensorflow
name,
shape: shape,
dtype: dtype,
+ use_resource: use_resource,
+ validate_shape: validate_shape,
initializer: initializer,
trainable: trainable);
}
diff --git a/src/TensorFlowNET.Core/Framework/CompositeTensor.cs b/src/TensorFlowNET.Core/Framework/CompositeTensor.cs
new file mode 100644
index 00000000..eac74580
--- /dev/null
+++ b/src/TensorFlowNET.Core/Framework/CompositeTensor.cs
@@ -0,0 +1,13 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Framework
+{
+ /// <summary>
+ /// Abstract base class for Tensor-like objects that are composed from Tensors.
+ /// </summary>
+ public abstract class CompositeTensor
+ {
+ }
+}
diff --git a/src/TensorFlowNET.Core/Framework/IndexedSlices.cs b/src/TensorFlowNET.Core/Framework/IndexedSlices.cs
new file mode 100644
index 00000000..0c4f0c8b
--- /dev/null
+++ b/src/TensorFlowNET.Core/Framework/IndexedSlices.cs
@@ -0,0 +1,48 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Framework
+{
+ /// <summary>
+ /// A sparse representation of a set of tensor slices at given indices.
+ /// </summary>
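+ /// <remarks>
+ /// The dense tensor it represents satisfies
+ /// dense[indices[i], :, :, ...] = values[i, :, :, ...].
+ /// </remarks>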
+ public class IndexedSlices : CompositeTensor
+ {
+ Tensor _values;
+ public Tensor values => _values;
+ Tensor _indices;
+ public Tensor indices => _indices;
+ Tensor _dense_shape;
+ public Tensor dense_shape => _dense_shape;
+
+ public string name => _values.name;
+
+ public string device => _values.Device;
+
+ public Operation op => _values.op;
+
+ public TF_DataType dtype => _values.dtype;
+
+ public Graph graph => _values.graph;
+
+ public IndexedSlices(Tensor values, Tensor indices, Tensor dense_shape = null)
+ {
+ _values = values;
+ _indices = indices;
+ _dense_shape = dense_shape;
+
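+ // Store a back-reference on the values tensor so the implicit
+ // Tensor -> IndexedSlices conversion below can recover this instance.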
+ _values.Tag = this;
+ }
+
+ public static implicit operator Tensor(IndexedSlices indexedSlices)
+ {
+ return indexedSlices.values;
+ }
+
+ public static implicit operator IndexedSlices(Tensor tensor)
+ {
+ return tensor.Tag as IndexedSlices;
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Gradients/array_grad.cs b/src/TensorFlowNET.Core/Gradients/array_grad.cs
index b7c5494a..4896d4dd 100644
--- a/src/TensorFlowNET.Core/Gradients/array_grad.cs
+++ b/src/TensorFlowNET.Core/Gradients/array_grad.cs
@@ -2,6 +2,7 @@
using System.Collections.Generic;
using System.Linq;
using System.Text;
+using Tensorflow.Framework;
using Tensorflow.Operations;
using static Tensorflow.Python;
@@ -42,9 +43,9 @@ namespace Tensorflow.Gradients
return end_value_index <= dim_index ? new Tensor[] { grad, null } : new Tensor[] { null, grad };
var concat_dim = op.inputs[dim_index];
- if (end_value_index == -1)
- end_value_index = op.inputs.Length - 1;
- var input_values = op.inputs._inputs.Skip(start_value_index).Take(end_value_index - start_value_index).ToArray();
+ var input_values = op.inputs._inputs.Skip(start_value_index)
+ .Take(end_value_index == -1 ? op.inputs.Length - 1 : end_value_index - start_value_index)
+ .ToArray();
var out_grads = new List<Tensor>();
if (constant_op.is_constant(concat_dim))
@@ -82,20 +83,26 @@ namespace Tensorflow.Gradients
new Tensor[] { non_neg_concat_dim, tf.constant(0) },
new Tensor[] { tf.constant(1), tf.constant(-1) });
var squeeze_sizes = array_ops.squeeze(slice);
- out_grads = gen_ops.split(grad, squeeze_sizes, non_neg_concat_dim).ToList();
+ out_grads = gen_array_ops.split(grad, squeeze_sizes, non_neg_concat_dim).ToList();
}
else
{
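+ // concat_offset computes where each input starts along the concat axis;
+ // each input's gradient is then the corresponding slice of `grad`.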
- var offset = gen_ops.concat_offset(non_neg_concat_dim, sizes);
+ var offset = gen_array_ops.concat_offset(non_neg_concat_dim, sizes);
foreach (var (begin, size) in zip(offset, sizes))
- out_grads.Add(gen_ops.slice(grad, begin, size));
+ out_grads.Add(gen_array_ops.slice(grad, begin, size));
}
return (end_value_index <= dim_index ?
- out_grads.ToArray().Concat(null) :
+ out_grads.ToArray().Concat(new Tensor[] { null }) :
new Tensor[] { null }.Concat(out_grads)).ToArray();
}
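+ /// <summary>
+ /// Gradient for ExpandDims: reshape the incoming gradient back to the
+ /// input's shape; the inserted axis argument receives no gradient.
+ /// </summary>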
+ [RegisterGradient("ExpandDims")]
+ public static Tensor[] _ExpandDimsGrad(Operation op, Tensor[] grads)
+ {
+ return new Tensor[] { _ReshapeToInput(op, grads[0]), null };
+ }
+
/// <summary>
/// Extract the shapes of a set of input tensors.
/// </summary>
@@ -122,7 +129,46 @@ namespace Tensorflow.Gradients
if (fully_known)
return sizes;
else
- return gen_ops.shape_n(inputs);
+ return gen_array_ops.shape_n(inputs);
+ }
+
+ /// <summary>
+ /// Gradient for GatherV2 op.
+ /// </summary>
+ /// <param name="op"></param>
+ /// <param name="grads"></param>
+ /// <returns></returns>
+ [RegisterGradient("GatherV2")]
+ public static Tensor[] _GatherV2Grad(Operation op, Tensor[] grads)
+ {
+ var grad = grads[0];
+ var @params = op.inputs[0];
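+ // params can be large, so colocate the shape calculation with them.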
+ ops.colocate_with(@params);
+
+ var params_shape = array_ops.shape(@params, out_type: tf.int64);
+ params_shape = math_ops.cast(params_shape, tf.int32);
+
+ var indices = op.inputs[1];
+ var indices_size = array_ops.expand_dims(array_ops.size(indices), 0);
+ var axis = op.inputs[2];
+ var axis_static = tensor_util.constant_value(axis);
+
+ // For axis 0 gathers, build an appropriately shaped IndexedSlices.
+ if ((int)axis_static == 0)
+ {
+ var params_tail_shape = params_shape[new NumSharp.Slice(start:1)];
+ var values_shape = array_ops.concat(new[] { indices_size, params_tail_shape }, 0);
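+ // Reshape grad to [num_indices, params_tail...] so that row i of values
+ // pairs with indices[i] in the resulting IndexedSlices.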
+ var values = array_ops.reshape(grad, values_shape);
+ indices = array_ops.reshape(indices, indices_size);
+ return new Tensor[]
+ {
+ new IndexedSlices(values, indices, params_shape),
+ null,
+ null
+ };
+ }
+
+ // Gradients for a non-zero axis are not handled here; GatherV2 has three
+ // inputs (params, indices, axis), so return one null per input.
+ return new Tensor[] { null, null, null };
}
[RegisterGradient("Reshape")]
diff --git a/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs b/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs
index 18151ac5..8ad4b44e 100644
--- a/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs
+++ b/src/TensorFlowNET.Core/Gradients/gradients_impl.py.cs
@@ -1,5 +1,4 @@
-using NumSharp;
-using System;
+using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
@@ -18,487 +17,7 @@ namespace Tensorflow
bool gate_gradients = false,
int? aggregation_method = null)
{
- return _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops, gate_gradients);
- }
-
- public static Tensor[] _GradientsHelper(Tensor[] ys,
- Tensor[] xs,
- Tensor[] grad_ys = null,
- string name = "gradients",
- bool colocate_gradients_with_ops = false,
- bool gate_gradients = false,
- int aggregation_method = 0,
- Tensor[] stop_gradients = null,
- Graph src_graph = null)
- {
- if (src_graph == null)
- src_graph = ops.get_default_graph();
-
- // If src_graph is a _FuncGraph (i.e. a function body), gather it and all
- // ancestor graphs. This is necessary for correctly handling captured values.
- var curr_graph = src_graph;
-
- if (stop_gradients == null)
- stop_gradients = new Tensor[0];
- if (grad_ys == null)
- grad_ys = new Tensor[ys.Length];
-
- // Iterate over the collected ops.
- /**
- * grads: op => list of gradients received on each output endpoint of the
- * op. The gradients for each endpoint are initially collected as a list.
- * When it is time to call the op's gradient function, for each endpoint we
- * aggregate the list of received gradients into a Add() Operation if there
- * is more than one.
- **/
- var grads = new Dictionary<Operation, List<List<Tensor>>>();
-
- with(ops.name_scope(name, "gradients",
- values: ys.Concat(xs).Concat(stop_gradients).Concat(grad_ys)), scope =>
- {
- string grad_scope = scope;
- // Get a uid for this call to gradients that can be used to help
- // cluster ops for compilation.
- var gradient_uid = ops.get_default_graph().unique_name("uid");
- ys = ops.convert_n_to_tensor_or_indexed_slices(ys, name: "y");
- xs = ops.internal_convert_n_to_tensor_or_indexed_slices(xs, name: "x", as_ref: true);
- grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops, gradient_uid);
-
- /**
- * The approach we take here is as follows: Create a list of all ops in the
- * subgraph between the ys and xs. Visit these ops in reverse order of ids
- * to ensure that when we visit an op the gradients w.r.t its outputs have
- * been collected. Then aggregate these gradients if needed, call the op's
- * gradient function, and add the generated gradients to the gradients for
- * its input.
- **/
-
- // Initialize the pending count for ops in the connected subgraph from ys
- // to the xs.
- var to_ops = ys.Select(x => x.op).ToList();
- var from_ops = xs.Select(x => x.op).ToList();
- var stop_gradient_ops = stop_gradients.Select(x => x.op).ToList();
- (var reachable_to_ops, var pending_count, var loop_state) = _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, new List