| @@ -19,7 +19,7 @@ namespace Tensorflow | |||
| public partial class tensorflow | |||
| { | |||
| public Tensor convert_to_tensor(object value, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, TF_DataType preferred_dtype = TF_DataType.DtInvalid) | |||
| => ops.convert_to_tensor(value, dtype, name, preferred_dtype); | |||
| => ops.convert_to_tensor(value, dtype, name, preferred_dtype: preferred_dtype); | |||
| public Tensor strided_slice(Tensor input, Tensor begin, Tensor end, Tensor strides = null, | |||
| int begin_mask = 0, | |||
| @@ -69,7 +69,7 @@ namespace Tensorflow.Eager | |||
| return placeholder; | |||
| } | |||
| public Tensor AsContatnt(string name = null) | |||
| public Tensor AsConstant(string name = null) | |||
| { | |||
| Tensor constant = null; | |||
| tf_with(ops.control_dependencies(null), delegate | |||
| @@ -29,7 +29,7 @@ namespace Tensorflow.Framework | |||
| { | |||
| indices = ops.convert_to_tensor( | |||
| indices_, name: "indices", dtype: dtypes.int64); | |||
| values = ops.internal_convert_to_tensor(values_, name: "values"); | |||
| values = ops.convert_to_tensor(values_, name: "values"); | |||
| dense_shape = ops.convert_to_tensor( | |||
| dense_shape_, name: "dense_shape", dtype: dtypes.int64); | |||
| }); | |||
| @@ -13,9 +13,6 @@ namespace Tensorflow.Graphs | |||
| /// </summary> | |||
| public class FuncGraph : Graph | |||
| { | |||
| Graph outer_graph; | |||
| public Graph OuterGraph => outer_graph; | |||
| // _handle == IntPtr.Zero ? string.Empty : c_api.StringPiece(c_api.TF_FunctionName(_handle)); | |||
| IntPtr func_handle; | |||
| public string FuncName => _graph_key; | |||
| @@ -42,8 +39,10 @@ namespace Tensorflow.Graphs | |||
| public FuncGraph(string name) : base() | |||
| { | |||
| outer_graph = ops.get_default_graph(); | |||
| while (outer_graph.building_function) | |||
| outer_graph = outer_graph.OuterGraph; | |||
| _graph_key = name; | |||
| building_function = true; | |||
| tf.Context.graph_mode(); | |||
| as_default(); | |||
| } | |||
| @@ -51,7 +50,10 @@ namespace Tensorflow.Graphs | |||
| public FuncGraph(IntPtr handle, string name, Dictionary<string, string> attrs) : base() | |||
| { | |||
| outer_graph = ops.get_default_graph(); | |||
| while (outer_graph.building_function) | |||
| outer_graph = outer_graph.OuterGraph; | |||
| _graph_key = name; | |||
| building_function = true; | |||
| Attrs = attrs; | |||
| // Will to test if FuncGraph has memory leak | |||
| // c_api.TF_DeleteGraph(_handle); | |||
| @@ -108,7 +110,7 @@ namespace Tensorflow.Graphs | |||
| return base.create_op(op_type, inputs, dtypes, input_types, name, attrs, op_def, compute_device); | |||
| } | |||
| Tensor capture(Tensor tensor, string name = null, TF_DataType shape = TF_DataType.DtInvalid) | |||
| public Tensor capture(Tensor tensor, string name = null, TF_DataType shape = TF_DataType.DtInvalid) | |||
| { | |||
| if(tensor is EagerTensor) | |||
| { | |||
| @@ -118,6 +118,9 @@ namespace Tensorflow | |||
| } | |||
| } | |||
| protected Graph outer_graph; | |||
| public Graph OuterGraph => outer_graph; | |||
| public Graph() | |||
| { | |||
| _handle = c_api.TF_NewGraph(); | |||
| @@ -148,7 +148,7 @@ namespace Tensorflow | |||
| else if (default_type_attr_map.ContainsKey(input_arg.TypeAttr)) | |||
| default_dtype = (DataType)default_type_attr_map[input_arg.TypeAttr]; | |||
| var value = ops.internal_convert_to_tensor(values, | |||
| var value = ops.convert_to_tensor(values, | |||
| name: input_name, | |||
| dtype: dtype.as_tf_dtype(), | |||
| as_ref: input_arg.IsRef, | |||
| @@ -66,7 +66,7 @@ namespace Tensorflow | |||
| else | |||
| { | |||
| ops.init_scope(); | |||
| var variable = ops.internal_convert_to_tensor(op, as_ref: true); | |||
| var variable = ops.convert_to_tensor(op, as_ref: true); | |||
| if (variable.dtype.is_ref_dtype()) | |||
| yield return new ReferenceVariableSaveable(variable, "", name); | |||
| else | |||
| @@ -103,7 +103,7 @@ namespace Tensorflow | |||
| if (!var.dtype.is_ref_dtype()) | |||
| tensor = var.GraphElement; | |||
| else | |||
| tensor = ops.internal_convert_to_tensor(var, as_ref: true); | |||
| tensor = ops.convert_to_tensor(var, as_ref: true); | |||
| } | |||
| if (tensor.op.type == "ReadVariableOp") | |||
| @@ -24,6 +24,7 @@ using System.Runtime.InteropServices; | |||
| using System.Threading; | |||
| using Tensorflow.Contexts; | |||
| using Tensorflow.Eager; | |||
| using Tensorflow.Graphs; | |||
| using Tensorflow.Util; | |||
| using static Tensorflow.Binding; | |||
| @@ -101,14 +102,44 @@ namespace Tensorflow | |||
| public static Tensor convert_to_tensor(object value, | |||
| TF_DataType dtype = TF_DataType.DtInvalid, | |||
| string name = null, | |||
| bool as_ref = false, | |||
| TF_DataType preferred_dtype = TF_DataType.DtInvalid, | |||
| Context ctx = null) | |||
| { | |||
| return internal_convert_to_tensor(value, | |||
| dtype: dtype, | |||
| name: name, | |||
| preferred_dtype: preferred_dtype, | |||
| as_ref: false); | |||
| if (dtype == TF_DataType.DtInvalid) | |||
| dtype = preferred_dtype; | |||
| if (value is EagerTensor eager_tensor) | |||
| { | |||
| if (tf.executing_eagerly()) | |||
| return eager_tensor; | |||
| /*else | |||
| { | |||
| var graph = get_default_graph(); | |||
| if (!graph.building_function) | |||
| throw new RuntimeError("Attempting to capture an EagerTensor without building a function."); | |||
| return (graph as FuncGraph).capture(eager_tensor, name: name); | |||
| }*/ | |||
| } | |||
| Tensor ret = value switch | |||
| { | |||
| NDArray nd => constant_op.constant(nd, dtype: dtype, name: name), | |||
| EagerTensor tensor => tensor.dtype == TF_DataType.TF_RESOURCE | |||
| ? tensor.AsPlaceholder(name: name) | |||
| : tensor.AsConstant(name: name), | |||
| Tensor tensor => tensor, | |||
| Tensor[] tensors => array_ops._autopacking_helper(tensors, dtype, name == null ? "packed" : name), | |||
| RefVariable varVal => varVal._TensorConversionFunction(dtype: dtype, name: name, as_ref: as_ref), | |||
| ResourceVariable varVal => varVal._TensorConversionFunction(dtype: dtype, name: name, as_ref: as_ref), | |||
| TensorShape ts => constant_op.constant(ts.dims, dtype: dtype, name: name), | |||
| int[] dims => constant_op.constant(dims, dtype: dtype, name: name), | |||
| string str => constant_op.constant(str, dtype: tf.@string, name: name), | |||
| object[] objects => array_ops._autopacking_conversion_function(objects, dtype: dtype, name: name), | |||
| _ => constant_op.constant(value, dtype: dtype, name: name) | |||
| }; | |||
| return ret; | |||
| } | |||
| @@ -118,9 +149,7 @@ namespace Tensorflow | |||
| } | |||
| public static Tensor internal_convert_to_tensor_or_composite(Tensor value, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool as_ref = false) | |||
| { | |||
| return internal_convert_to_tensor(value, dtype: dtype, name: name, as_ref: as_ref); | |||
| } | |||
| => convert_to_tensor(value, dtype: dtype, name: name, as_ref: as_ref); | |||
| /// <summary> | |||
| /// Wrapper for `Graph.control_dependencies()` using the default graph. | |||
| @@ -460,52 +489,12 @@ namespace Tensorflow | |||
| foreach ((int i, object value) in enumerate(values as object[])) | |||
| { | |||
| string n = string.IsNullOrEmpty(name) ? "" : $"{name}_{i}"; | |||
| ret.Add(internal_convert_to_tensor(value, dtype: dtype, name: n, as_ref: as_ref, preferred_dtype: preferred_dtype)); | |||
| ret.Add(convert_to_tensor(value, dtype: dtype, name: n, as_ref: as_ref, preferred_dtype: preferred_dtype)); | |||
| } | |||
| return ret.ToArray(); | |||
| } | |||
| public static Tensor internal_convert_to_tensor(object value, TF_DataType dtype = TF_DataType.DtInvalid, | |||
| string name = null, TF_DataType preferred_dtype = TF_DataType.DtInvalid, | |||
| bool as_ref = false, | |||
| string scope = null) | |||
| { | |||
| if (dtype == TF_DataType.DtInvalid) | |||
| dtype = preferred_dtype; | |||
| switch (value) | |||
| { | |||
| case NDArray nd: | |||
| return constant_op.constant(nd, dtype: dtype, name: name); | |||
| case EagerTensor tensor: | |||
| if (tf.executing_eagerly()) | |||
| return tensor; | |||
| else | |||
| return tensor.dtype == TF_DataType.TF_RESOURCE | |||
| ? tensor.AsPlaceholder(name: name) | |||
| : tensor.AsContatnt(name: name); | |||
| case Tensor tensor: | |||
| return tensor; | |||
| case Tensor[] tensors: | |||
| return array_ops._autopacking_helper(tensors, dtype, name == null ? "packed" : name); | |||
| case RefVariable varVal: | |||
| return varVal._TensorConversionFunction(dtype: dtype, name: name, as_ref: as_ref); | |||
| case ResourceVariable varVal: | |||
| return varVal._TensorConversionFunction(dtype: dtype, name: name, as_ref: as_ref); | |||
| case TensorShape ts: | |||
| return constant_op.constant(ts.dims, dtype: dtype, name: name); | |||
| case string str: | |||
| return constant_op.constant(value, dtype: tf.@string, name: name); | |||
| case int[] dims: | |||
| return constant_op.constant(dims, dtype: dtype, name: name); | |||
| case object[] objects: | |||
| return array_ops._autopacking_conversion_function(objects, dtype: dtype, name: name); | |||
| default: | |||
| return constant_op.constant(value, dtype: dtype, name: name); | |||
| } | |||
| } | |||
| public static string strip_name_scope(string name, string export_scope = "") | |||
| { | |||
| if (!string.IsNullOrEmpty(export_scope)) | |||
| @@ -17,6 +17,7 @@ | |||
| using NumSharp; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using Tensorflow.Graphs; | |||
| using static Tensorflow.Binding; | |||
| namespace Tensorflow.Keras | |||
| @@ -78,6 +79,12 @@ namespace Tensorflow.Keras | |||
| public Graph get_graph() | |||
| { | |||
| if (tf.Context.executing_eagerly()) | |||
| { | |||
| if (_GRAPH == null) | |||
| _GRAPH = new FuncGraph("keras_graph"); | |||
| return _GRAPH; | |||
| } | |||
| return ops.get_default_graph(); | |||
| } | |||
| @@ -1,6 +1,7 @@ | |||
| using System; | |||
| using Tensorflow.Keras.Utils; | |||
| using static Tensorflow.Binding; | |||
| using static Tensorflow.KerasApi; | |||
| namespace Tensorflow.Keras.Engine | |||
| { | |||
| @@ -22,7 +23,7 @@ namespace Tensorflow.Keras.Engine | |||
| Tensors outputs = null; | |||
| using var ctxManager = CallContext.enter(); | |||
| // using var graph = tf.keras.backend.get_graph().as_default(); | |||
| // using var graph = keras.backend.get_graph(); | |||
| if (!inputs.IsEagerTensor) | |||
| tf.Context.graph_mode(isFunc: true); | |||