diff --git a/src/TensorFlowNET.Core/APIs/tf.math.cs b/src/TensorFlowNET.Core/APIs/tf.math.cs
index d3b24373..5586840c 100644
--- a/src/TensorFlowNET.Core/APIs/tf.math.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.math.cs
@@ -348,6 +348,9 @@ namespace Tensorflow
public static Tensor cast(Tensor x, TF_DataType dtype = TF_DataType.DtInvalid, string name = null)
=> math_ops.cast(x, dtype, name);
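+ /// <summary>
+ /// Computes the cumulative sum of `x` along `axis`; `exclusive` excludes the current element
+ /// from each partial sum and `reverse` accumulates from the end of the axis.
+ /// </summary>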
+ public static Tensor cumsum(Tensor x, int axis = 0, bool exclusive = false, bool reverse = false, string name = null)
+ => math_ops.cumsum(x, axis: axis, exclusive: exclusive, reverse: reverse, name: name);
+
public static Tensor argmax(Tensor input, int axis = -1, string name = null, int? dimension = null, TF_DataType output_type = TF_DataType.TF_INT64)
=> gen_math_ops.arg_max(input, axis, name: name, output_type: output_type);
diff --git a/src/TensorFlowNET.Core/APIs/tf.nn.cs b/src/TensorFlowNET.Core/APIs/tf.nn.cs
index 9ec1ef58..0bc9d0e5 100644
--- a/src/TensorFlowNET.Core/APIs/tf.nn.cs
+++ b/src/TensorFlowNET.Core/APIs/tf.nn.cs
@@ -75,7 +75,7 @@ namespace Tensorflow
///
/// <returns>A pair (outputs, state)</returns>
public static (Tensor, Tensor) dynamic_rnn(RNNCell cell, Tensor inputs,
- int? sequence_length = null, TF_DataType dtype = TF_DataType.DtInvalid,
+ Tensor sequence_length = null, TF_DataType dtype = TF_DataType.DtInvalid,
int? parallel_iterations = null, bool swap_memory = false, bool time_major = false)
=> rnn.dynamic_rnn(cell, inputs, sequence_length: sequence_length, dtype: dtype,
parallel_iterations: parallel_iterations, swap_memory: swap_memory,
diff --git a/src/TensorFlowNET.Core/Operations/BasicRNNCell.cs b/src/TensorFlowNET.Core/Operations/BasicRNNCell.cs
index 61da061f..554e9f1a 100644
--- a/src/TensorFlowNET.Core/Operations/BasicRNNCell.cs
+++ b/src/TensorFlowNET.Core/Operations/BasicRNNCell.cs
@@ -24,7 +24,8 @@ namespace Tensorflow
int _num_units;
Func _activation;
- protected override int state_size => _num_units;
+ public override int state_size => _num_units;
+ public override int output_size => _num_units;
public BasicRNNCell(int num_units,
Func activation = null,
diff --git a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs
index 79d1df89..3200e13f 100644
--- a/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs
+++ b/src/TensorFlowNET.Core/Operations/NnOps/rnn.cs
@@ -24,15 +24,15 @@ namespace Tensorflow.Operations
{
internal class rnn
{
- public static (Tensor, Tensor) dynamic_rnn(RNNCell cell, Tensor inputs,
- int? sequence_length = null, Tensor initial_state = null,
+ public static (Tensor, Tensor) dynamic_rnn(RNNCell cell, Tensor inputs_tensor,
+ Tensor sequence_length = null, Tensor initial_state = null,
TF_DataType dtype = TF_DataType.DtInvalid,
int? parallel_iterations = null, bool swap_memory = false, bool time_major = false)
{
with(tf.variable_scope("rnn"), scope =>
{
VariableScope varscope = scope;
- var flat_input = nest.flatten(inputs);
+ var flat_input = nest.flatten(inputs_tensor);
if (!time_major)
{
@@ -42,24 +42,146 @@ namespace Tensorflow.Operations
parallel_iterations = parallel_iterations ?? 32;
- if (sequence_length.HasValue)
+ if (sequence_length != null)
throw new NotImplementedException("dynamic_rnn sequence_length has value");
var batch_size = _best_effort_input_batch_size(flat_input);
+ Tensor state = null;
if (initial_state != null)
- {
- var state = initial_state;
- }
+ state = initial_state;
else
+ state = cell.get_initial_state(batch_size: batch_size, dtype: dtype);
+
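+ // Re-pack the flattened inputs into their original nested structure before running the loop.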
+ var inputs = nest.pack_sequence_as(structure: inputs_tensor, flat_sequence: flat_input);
+
+ var (outputs, final_state) = _dynamic_rnn_loop(
+ cell,
+ inputs as Tensor,
+ state,
+ parallel_iterations: parallel_iterations.Value,
+ swap_memory: swap_memory,
+ sequence_length: sequence_length,
+ dtype: dtype);
+ });
+
+ throw new NotImplementedException("");
+ }
+
+ /// <summary>
+ /// Internal implementation of Dynamic RNN.
+ /// </summary>
+ /// <param name="cell"></param>
+ /// <param name="inputs"></param>
+ /// <param name="initial_state"></param>
+ /// <param name="parallel_iterations"></param>
+ /// <param name="swap_memory"></param>
+ /// <param name="sequence_length"></param>
+ /// <param name="dtype"></param>
+ /// <returns></returns>
+ private static (Tensor, Tensor) _dynamic_rnn_loop(RNNCell cell, Tensor inputs, Tensor initial_state,
+ int parallel_iterations, bool swap_memory, Tensor sequence_length = null, TF_DataType dtype = TF_DataType.DtInvalid)
+ {
+ var state = initial_state;
+ var state_size = cell.state_size;
+
+ var flat_input = nest.flatten(inputs);
+ var flat_output_size = nest.flatten(cell.output_size);
+
+ // Construct an initial output
+ var input_shape = array_ops.shape(flat_input[0]);
+ var time_steps = input_shape.slice(0);
+ var batch_size = _best_effort_input_batch_size(flat_input);
+ var inputs_got_shape = flat_input.Select(input_ => input_.TensorShape.with_rank_at_least(3)).ToArray();
+
+ var dims = inputs_got_shape[0].Dimensions.Take(2).ToArray();
+ var (const_time_steps, const_batch_size) = (dims[0], dims[1]);
+
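+ // Check that every flattened input agrees on the leading (time_steps, batch_size) dimensions.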
+ foreach(var shape in inputs_got_shape)
+ {
+ if (shape[2] == -1)
+ throw new ValueError("Input size (depth of inputs) must be accessible via shape inference," +
+ " but saw value None.");
+
+ var got_time_steps = shape.dims[0];
+ var got_batch_size = shape.dims[1];
+
+ if (const_time_steps != got_time_steps)
+ throw new ValueError("Time steps is not the same for all the elements in the input in a " +
+ "batch.");
+
+ if (const_batch_size != got_batch_size)
+ throw new ValueError("Batch_size is not the same for all the elements in the input.");
+ }
+
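+ // Builds a zero-filled tensor of shape [batch_size, size_] for the zero output.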
+ Func<int, Tensor> _create_zero_arrays = (size_) =>
+ {
+ var size = rnn_cell_impl._concat(batch_size, size_);
+ return array_ops.zeros(
+ array_ops.stack(size), dtype: _infer_state_dtype(dtype, state));
+ };
+
+ // Prepare dynamic conditional copying of state & output
+ var flat_zero_output = flat_output_size.Select(output => _create_zero_arrays(output)).ToArray();
+ var zero_output = nest.pack_sequence_as(structure: cell.output_size, flat_sequence: flat_zero_output);
+
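+ // With an explicit sequence_length, its min/max bound the loop; otherwise run over all time steps.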
+ Tensor min_sequence_length = null, max_sequence_length = null;
+ if (sequence_length != null)
+ {
+ min_sequence_length = math_ops.reduce_min(sequence_length);
+ max_sequence_length = math_ops.reduce_max(sequence_length);
+ }
+ else
+ {
+ max_sequence_length = time_steps;
+ }
+
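+ // Scalar counter that drives the iteration over time steps.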
+ var time = array_ops.constant(0, dtype: dtypes.int32, name: "time");
+
+ string base_name = null;
+ with(ops.name_scope("dynamic_rnn"), scope => base_name = scope);
+
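+ // Factory for the per-output TensorArrays that accumulate outputs across time steps.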
+ Func<string, TensorShape, TF_DataType, TensorArray> _create_ta = (name, element_shape, dtype_) =>
+ {
+ new TensorArray(dtype: dtype_,
+ size: time_steps,
+ element_shape: element_shape,
+ tensor_array_name: base_name + name);
+ throw new NotImplementedException("");
+ };
+
+ bool in_graph_mode = true;
+ if (in_graph_mode)
+ {
+ foreach(var (i, out_size) in enumerate(flat_output_size))
{
- cell.get_initial_state(batch_size: batch_size, dtype: dtype);
+ _create_ta($"output_{i}",
+ new TensorShape(const_batch_size).concatenate(
+ _maybe_tensor_shape_from_tensor(out_size)),
+ _infer_state_dtype(dtype, state));
+
+
+
}
- });
+ }
throw new NotImplementedException("");
}
+ private static TensorShape _maybe_tensor_shape_from_tensor(Tensor shape)
+ => shape.TensorShape;
+
+ private static TensorShape _maybe_tensor_shape_from_tensor(int shape)
+ => new TensorShape(shape);
+
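+ /// <summary>
+ /// Returns the explicitly requested dtype; inferring the dtype from the state is not implemented yet.
+ /// </summary>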
+ private static TF_DataType _infer_state_dtype(TF_DataType explicit_dtype, Tensor state)
+ {
+ if (explicit_dtype != TF_DataType.DtInvalid)
+ return explicit_dtype;
+
+ throw new NotImplementedException("_infer_state_dtype");
+ }
+
/// <summary>
/// Transposes the batch and time dimensions of a Tensor.
/// </summary>
diff --git a/src/TensorFlowNET.Core/Operations/NnOps/rnn_cell_impl.cs b/src/TensorFlowNET.Core/Operations/NnOps/rnn_cell_impl.cs
new file mode 100644
index 00000000..bd210ecd
--- /dev/null
+++ b/src/TensorFlowNET.Core/Operations/NnOps/rnn_cell_impl.cs
@@ -0,0 +1,57 @@
+/*****************************************************************************
+ Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+******************************************************************************/
+
+using System;
+
+namespace Tensorflow.Operations
+{
+ public class rnn_cell_impl
+ {
+ public BasicRNNCell BasicRNNCell(int num_units)
+ => new BasicRNNCell(num_units);
+
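+ /// <summary>
+ /// Concatenates a batch-size prefix with a suffix dimension into a rank-1 shape tensor;
+ /// the @static path that would return a static TensorShape is not implemented yet.
+ /// </summary>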
+ public static Tensor _concat(Tensor prefix, int suffix, bool @static = false)
+ {
+ var p = prefix;
+ var p_static = tensor_util.constant_value(prefix);
+ if (p.NDims == 0)
+ p = array_ops.expand_dims(p, 0);
+ else if (p.NDims != 1)
+ throw new ValueError($"prefix tensor must be either a scalar or vector, but saw tensor: {p}");
+
+ var s_tensor_shape = new TensorShape(suffix);
+ var s_static = s_tensor_shape.NDim > -1 ?
+ s_tensor_shape.Dimensions :
+ null;
+ var s = s_tensor_shape.is_fully_defined() ?
+ constant_op.constant(s_tensor_shape.Dimensions, dtype: dtypes.int32) :
+ null;
+
+ if (@static)
+ {
+ if (p_static is null) return null;
+ var shape = new TensorShape(p_static).concatenate(s_static);
+ throw new NotImplementedException("RNNCell _concat");
+ }
+ else
+ {
+ if (p is null || s is null)
+ throw new ValueError($"Provided a prefix or suffix of None: {prefix} and {suffix}");
+ return array_ops.concat(new[] { p, s }, 0);
+ }
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Operations/RNNCell.cs b/src/TensorFlowNET.Core/Operations/RNNCell.cs
index 3b841087..57f46e7b 100644
--- a/src/TensorFlowNET.Core/Operations/RNNCell.cs
+++ b/src/TensorFlowNET.Core/Operations/RNNCell.cs
@@ -15,6 +15,7 @@
******************************************************************************/
using System;
+using Tensorflow.Operations;
using Tensorflow.Util;
using static Tensorflow.Python;
@@ -48,7 +49,9 @@ namespace Tensorflow
/// difference between TF and Keras RNN cell.
/// </summary>
protected bool _is_tf_rnn_cell = false;
- protected virtual int state_size { get; }
+ public virtual int state_size { get; }
+
+ public virtual int output_size { get; }
public RNNCell(bool trainable = true,
string name = null,
@@ -89,12 +92,18 @@ namespace Tensorflow
private Tensor _zero_state_tensors(int state_size, Tensor batch_size, TF_DataType dtype)
{
- nest.map_structure(x =>
+ var output = nest.map_structure(s =>
{
- throw new NotImplementedException("");
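+ // For each state size `s`, build zeros of shape [batch_size, s] and pin the static shape when known.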
+ var c = rnn_cell_impl._concat(batch_size, s);
+ var size = array_ops.zeros(c, dtype: dtype);
+
+ var c_static = rnn_cell_impl._concat(batch_size, s, @static: true);
+ size.set_shape(c_static);
+
+ return size;
}, state_size);
- throw new NotImplementedException("");
+ return output;
}
}
}
diff --git a/src/TensorFlowNET.Core/Operations/TensorArray.cs b/src/TensorFlowNET.Core/Operations/TensorArray.cs
new file mode 100644
index 00000000..858dac47
--- /dev/null
+++ b/src/TensorFlowNET.Core/Operations/TensorArray.cs
@@ -0,0 +1,54 @@
+/*****************************************************************************
+ Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+******************************************************************************/
+
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Tensorflow.Operations
+{
+ /// <summary>
+ /// TensorArray is designed to hide an underlying implementation object
+ /// and as such accesses many of that object's hidden fields.
+ ///
+ /// Class wrapping dynamic-sized, per-time-step, write-once Tensor arrays.
+ /// This class is meant to be used with dynamic iteration primitives such as
+ /// `while_loop` and `map_fn`. It supports gradient back-propagation via special
+ /// "flow" control flow dependencies.
+ /// </summary>
+ public class TensorArray
+ {
+ _GraphTensorArray _implementation;
+
+ public TensorArray(TF_DataType dtype, Tensor size = null, bool? clear_after_read = null, bool? dynamic_size = null,
+ string tensor_array_name = null, Tensor handle = null, Tensor flow = null,
+ bool infer_shape = true, TensorShape element_shape = null,
+ bool colocate_with_first_write_call = true, string name = null)
+ {
+ _implementation = new _GraphTensorArray(dtype,
+ size: size,
+ dynamic_size: dynamic_size,
+ clear_after_read: clear_after_read,
+ tensor_array_name: tensor_array_name,
+ handle: handle,
+ flow: flow,
+ infer_shape: infer_shape,
+ element_shape: element_shape,
+ colocate_with_first_write_call: colocate_with_first_write_call,
+ name: name);
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs b/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs
new file mode 100644
index 00000000..b4619c05
--- /dev/null
+++ b/src/TensorFlowNET.Core/Operations/_GraphTensorArray.cs
@@ -0,0 +1,102 @@
+/*****************************************************************************
+ Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+******************************************************************************/
+
+using System;
+using System.Collections.Generic;
+using System.Text;
+using static Tensorflow.Python;
+
+namespace Tensorflow.Operations
+{
+ internal class _GraphTensorArray
+ {
+ TF_DataType _dtype;
+
+ /// <summary>
+ /// Used to keep track of what tensors the TensorArray should be
+ /// colocated with. We choose to colocate the TensorArray with the
+ /// first tensor written to it.
+ /// </summary>
+ bool _colocate_with_first_write_call;
+
+ bool _infer_shape;
+ List<TensorShape> _element_shape;
+
+ object _colocate_with;
+
+ public _GraphTensorArray(TF_DataType dtype, Tensor size, bool? dynamic_size = null,
+ bool? clear_after_read = null, string tensor_array_name = null, Tensor handle = null, Tensor flow = null,
+ bool infer_shape = true, TensorShape element_shape = null,
+ bool colocate_with_first_write_call = true, string name = null)
+ {
+ clear_after_read = clear_after_read ?? true;
+ dynamic_size = dynamic_size ?? false;
+
+ _dtype = dtype;
+
+ _colocate_with_first_write_call = colocate_with_first_write_call;
+ if (colocate_with_first_write_call)
+ _colocate_with = new Tensor[0];
+
+ // Record the current static shape for the array elements. The element
+ // shape is defined either by `element_shape` or the shape of the tensor
+ // of the first write. If `infer_shape` is true, all writes are checked for
+ // shape equality.
+ if(element_shape == null)
+ {
+ _infer_shape = infer_shape;
+ _element_shape = new List<TensorShape> { };
+ }
+ else
+ {
+ _infer_shape = true;
+ _element_shape = new List<TensorShape> { };
+ }
+
+ with(ops.name_scope(name, "", new { handle, size, flow }), scope =>
+ {
+ if(handle != null)
+ {
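+ // Wrapping an existing TensorArray handle is not implemented yet.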
+
+ }
+ else
+ {
+ Func<(Tensor, Tensor)> create = () => gen_data_flow_ops.tensor_array_v3(size,
+ dtype: dtype,
+ element_shape: element_shape,
+ identical_element_shapes: infer_shape,
+ dynamic_size: dynamic_size.Value,
+ clear_after_read: clear_after_read.Value,
+ tensor_array_name: tensor_array_name,
+ name: scope);
+
+ // Construct the TensorArray with an empty device. The first
+ // write into the TensorArray from a Tensor with a set device
+ // will retroactively set the device value of this op.
+ if (colocate_with_first_write_call)
+ {
+ ops.colocate_with(ignore_existing: true);
+ create();
+ }
+ else
+ {
+
+ }
+ }
+ });
+ }
+ }
+}
diff --git a/src/TensorFlowNET.Core/Operations/array_ops.py.cs b/src/TensorFlowNET.Core/Operations/array_ops.py.cs
index 98b36bc6..c3f52cb8 100644
--- a/src/TensorFlowNET.Core/Operations/array_ops.py.cs
+++ b/src/TensorFlowNET.Core/Operations/array_ops.py.cs
@@ -29,6 +29,17 @@ namespace Tensorflow
public static Tensor prevent_gradient(Tensor input, string message = "", string name = null)
=> gen_array_ops.prevent_gradient(input, message: message, name: name);
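+ /// <summary>
+ /// Internal constant helper that forwards to constant_op._constant_impl with broadcasting disabled.
+ /// </summary>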
+ internal static Tensor constant(object value,
+ TF_DataType dtype = TF_DataType.DtInvalid,
+ int[] shape = null,
+ string name = "Const",
+ bool verify_shape = false) => constant_op._constant_impl(value,
+ dtype,
+ shape,
+ name,
+ verify_shape: verify_shape,
+ allow_broadcast: false);
+
public static Tensor zeros(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null)
{
dtype = dtype.as_base_dtype();
diff --git a/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.py.cs b/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs
similarity index 63%
rename from src/TensorFlowNET.Core/Operations/gen_data_flow_ops.py.cs
rename to src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs
index 17f64a29..2cb9aac6 100644
--- a/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.py.cs
+++ b/src/TensorFlowNET.Core/Operations/gen_data_flow_ops.cs
@@ -27,5 +27,23 @@ namespace Tensorflow
return _op.outputs[0];
}
+
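+ /// <summary>
+ /// Creates a "TensorArrayV3" op; the resulting (handle, flow) outputs are not returned yet.
+ /// </summary>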
+ public static (Tensor, Tensor) tensor_array_v3(Tensor size, TF_DataType dtype = TF_DataType.DtInvalid,
+ int[] element_shape = null, bool dynamic_size = false, bool clear_after_read = true,
+ bool identical_element_shapes = false, string tensor_array_name = "tensor_array_name", string name = null)
+ {
+ var _op = _op_def_lib._apply_op_helper("TensorArrayV3", name, new
+ {
+ size,
+ dtype,
+ element_shape,
+ dynamic_size,
+ clear_after_read,
+ identical_element_shapes,
+ tensor_array_name
+ });
+
+ return (null, null);
+ }
}
}
diff --git a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs
index 7ab6f858..8ec7e253 100644
--- a/src/TensorFlowNET.Core/Operations/gen_math_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/gen_math_ops.cs
@@ -238,6 +238,13 @@ namespace Tensorflow
return _op.outputs[0];
}
+ public static Tensor cumsum(Tensor x, int axis = 0, bool exclusive = false, bool reverse = false, string name = null)
+ {
+ var _op = _op_def_lib._apply_op_helper("Cumsum", name, args: new { x, axis, exclusive, reverse });
+
+ return _op.outputs[0];
+ }
+
/// <summary>
/// Computes the sum along segments of a tensor.
/// </summary>
diff --git a/src/TensorFlowNET.Core/Operations/math_ops.cs b/src/TensorFlowNET.Core/Operations/math_ops.cs
index 486fea8d..fc8f08ac 100644
--- a/src/TensorFlowNET.Core/Operations/math_ops.cs
+++ b/src/TensorFlowNET.Core/Operations/math_ops.cs
@@ -80,6 +80,17 @@ namespace Tensorflow
});
}
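+ /// <summary>
+ /// Cumulative sum along `axis`, wrapped in a "Cumsum" name scope after converting `x` to a tensor.
+ /// </summary>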
+ public static Tensor cumsum(Tensor x, int axis = 0, bool exclusive = false, bool reverse = false, string name = null)
+ {
+ return with(ops.name_scope(name, "Cumsum", new {x}), scope =>
+ {
+ name = scope;
+ x = ops.convert_to_tensor(x, name: "x");
+
+ return gen_math_ops.cumsum(x, axis: axis, exclusive: exclusive, reverse: reverse, name: name);
+ });
+ }
+
/// <summary>
/// Computes Psi, the derivative of Lgamma (the log of the absolute value of
/// `Gamma(x)`), element-wise.
diff --git a/src/TensorFlowNET.Core/Operations/rnn_cell_impl.cs b/src/TensorFlowNET.Core/Operations/rnn_cell_impl.cs
deleted file mode 100644
index 72f4b866..00000000
--- a/src/TensorFlowNET.Core/Operations/rnn_cell_impl.cs
+++ /dev/null
@@ -1,8 +0,0 @@
-namespace Tensorflow.Operations
-{
- public class rnn_cell_impl
- {
- public BasicRNNCell BasicRNNCell(int num_units)
- => new BasicRNNCell(num_units);
- }
-}
diff --git a/src/TensorFlowNET.Core/Sessions/BaseSession.cs b/src/TensorFlowNET.Core/Sessions/BaseSession.cs
index 593c5aba..deb82b51 100644
--- a/src/TensorFlowNET.Core/Sessions/BaseSession.cs
+++ b/src/TensorFlowNET.Core/Sessions/BaseSession.cs
@@ -1,329 +1,329 @@
-/*****************************************************************************
- Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
-******************************************************************************/
-
-using NumSharp;
-using System;
-using System.Collections;
-using System.Collections.Generic;
-using System.Linq;
-using System.Numerics;
-using System.Text;
-
-namespace Tensorflow
-{
- public class BaseSession
- {
- protected Graph _graph;
- protected bool _opened;
- protected bool _closed;
- protected int _current_version;
- protected byte[] _target;
- protected IntPtr _session;
- public Status Status;
- public Graph graph => _graph;
-
- public BaseSession(string target = "", Graph g = null, SessionOptions opts = null)
- {
- _graph = g is null ? ops.get_default_graph() : g;
-
- _target = UTF8Encoding.UTF8.GetBytes(target);
-
- SessionOptions newOpts = null;
- if (opts == null)
- newOpts = c_api.TF_NewSessionOptions();
-
- Status = new Status();
-
- _session = c_api.TF_NewSession(_graph, opts ?? newOpts, Status);
-
- // dispose newOpts
- if (opts == null)
- c_api.TF_DeleteSessionOptions(newOpts);
-
- Status.Check(true);
- }
-
- public virtual NDArray run(object fetches, params FeedItem[] feed_dict)
- {
- return _run(fetches, feed_dict);
- }
-
- public virtual NDArray run(object fetches, Hashtable feed_dict = null)
- {
- var feed_items = feed_dict == null ? new FeedItem[0] :
- feed_dict.Keys.OfType