| @@ -310,17 +310,25 @@ namespace Tensorflow | |||||
| private static Tensor BinaryOpWrapper<Tx, Ty>(string name, Tx x, Ty y) | private static Tensor BinaryOpWrapper<Tx, Ty>(string name, Tx x, Ty y) | ||||
| { | { | ||||
| TF_DataType dtype = TF_DataType.DtInvalid; | TF_DataType dtype = TF_DataType.DtInvalid; | ||||
| bool switchToGraphModeTemp = !tf.executing_eagerly(); | |||||
| if (x is Tensor tl) | if (x is Tensor tl) | ||||
| { | |||||
| dtype = tl.dtype.as_base_dtype(); | dtype = tl.dtype.as_base_dtype(); | ||||
| switchToGraphModeTemp = switchToGraphModeTemp || !tl.IsEagerTensor; | |||||
| } | |||||
| if (y is Tensor tr) | if (y is Tensor tr) | ||||
| { | |||||
| dtype = tr.dtype.as_base_dtype(); | dtype = tr.dtype.as_base_dtype(); | ||||
| if (name == "div") | |||||
| name = div_or_truediv(name, x, y); | |||||
| switchToGraphModeTemp = switchToGraphModeTemp || !tr.IsEagerTensor; | |||||
| } | |||||
| return tf_with(ops.name_scope(null, name, new { x, y }), scope => | return tf_with(ops.name_scope(null, name, new { x, y }), scope => | ||||
| { | { | ||||
| if (switchToGraphModeTemp) | |||||
| tf.Context.graph_mode(); | |||||
| Tensor result; | Tensor result; | ||||
| var x1 = ops.convert_to_tensor(x, dtype: dtype, name: "x"); | var x1 = ops.convert_to_tensor(x, dtype: dtype, name: "x"); | ||||
| var y1 = ops.convert_to_tensor(y, dtype: dtype, name: "y"); | var y1 = ops.convert_to_tensor(y, dtype: dtype, name: "y"); | ||||
| @@ -352,6 +360,9 @@ namespace Tensorflow | |||||
| throw new NotImplementedException($"BinaryOpWrapper: {name} - {typeof(Tx).Name}, {typeof(Ty).Name}"); | throw new NotImplementedException($"BinaryOpWrapper: {name} - {typeof(Tx).Name}, {typeof(Ty).Name}"); | ||||
| } | } | ||||
| if (switchToGraphModeTemp) | |||||
| tf.Context.restore_mode(); | |||||
| return result; | return result; | ||||
| }); | }); | ||||
| } | } | ||||
| @@ -253,6 +253,13 @@ namespace Tensorflow | |||||
| return (int[]) dims.Clone(); | return (int[]) dims.Clone(); | ||||
| } | } | ||||
| public long[] as_list_long() | |||||
| { | |||||
| if (shape.IsEmpty) | |||||
| throw new ValueError("as_list() is not defined on an unknown TensorShape."); | |||||
| return dims.Select(x => Convert.ToInt64(x)).ToArray(); | |||||
| } | |||||
| public int num_elements() | public int num_elements() | ||||
| { | { | ||||
| if(is_fully_defined()) | if(is_fully_defined()) | ||||
| @@ -56,6 +56,9 @@ namespace Tensorflow | |||||
| public static implicit operator Tensors(Tensor[] tensors) | public static implicit operator Tensors(Tensor[] tensors) | ||||
| => new Tensors(tensors); | => new Tensors(tensors); | ||||
| public static implicit operator Tensors(List<Tensor> tensors) | |||||
| => new Tensors(tensors.ToArray()); | |||||
| public static implicit operator Tensor(Tensors tensors) | public static implicit operator Tensor(Tensors tensors) | ||||
| => tensors.FirstOrDefault(); | => tensors.FirstOrDefault(); | ||||
| @@ -100,6 +100,18 @@ namespace Tensorflow | |||||
| variable_accessed(this); | variable_accessed(this); | ||||
| var result = gen_resource_variable_ops.read_variable_op(handle, _dtype); | var result = gen_resource_variable_ops.read_variable_op(handle, _dtype); | ||||
| // _maybe_set_handle_data(_dtype, _handle, result); | // _maybe_set_handle_data(_dtype, _handle, result); | ||||
| // have to set shape when converting to substituent placeholder | |||||
| if (result.TensorShape.ndim == -1) | |||||
| { | |||||
| c_api.TF_GraphSetTensorShape(result.graph, | |||||
| result._as_tf_output(), | |||||
| shape.as_list_long(), | |||||
| shape.ndim, | |||||
| tf.Status.Handle); | |||||
| tf.Status.Check(true); | |||||
| } | |||||
| return result; | return result; | ||||
| } | } | ||||
| @@ -160,15 +172,12 @@ namespace Tensorflow | |||||
| { | { | ||||
| } | } | ||||
| public Tensor AsTensor(bool as_ref = true) | |||||
| public Tensor AsTensor(TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool as_ref = false) | |||||
| { | { | ||||
| if (!as_ref && GraphElement != null) | |||||
| return GraphElement; | |||||
| if (as_ref) | if (as_ref) | ||||
| return tf.executing_eagerly() ? read_value() : GraphElement; | |||||
| return read_value().op.inputs[0]; | |||||
| else | else | ||||
| return _read_variable_op(); | |||||
| return value(); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -49,6 +49,6 @@ namespace Tensorflow | |||||
| public TensorShape shape { get; } | public TensorShape shape { get; } | ||||
| Tensor assign_add<T>(T delta, bool use_locking = false, string name = null, bool read_value = true); | Tensor assign_add<T>(T delta, bool use_locking = false, string name = null, bool read_value = true); | ||||
| Tensor assign<T>(T value, bool use_locking = false, string name = null, bool read_value = true); | Tensor assign<T>(T value, bool use_locking = false, string name = null, bool read_value = true); | ||||
| Tensor AsTensor(bool as_ref = true); | |||||
| Tensor AsTensor(TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool as_ref = false); | |||||
| } | } | ||||
| } | } | ||||
| @@ -222,7 +222,7 @@ namespace Tensorflow | |||||
| public Tensor value() => _snapshot; | public Tensor value() => _snapshot; | ||||
| public Tensor AsTensor(bool as_ref = true) => _snapshot; | |||||
| public Tensor AsTensor(TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool as_ref = false) => _snapshot; | |||||
| public Tensor _as_graph_element() => _variable; | public Tensor _as_graph_element() => _variable; | ||||
| @@ -37,7 +37,7 @@ namespace Tensorflow | |||||
| if (as_ref) | if (as_ref) | ||||
| return handle; | return handle; | ||||
| else | else | ||||
| return tf.executing_eagerly() ? AsTensor() : value(); | |||||
| return AsTensor(); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||