| @@ -31,6 +31,6 @@ namespace Tensorflow | |||||
| public Tensor reshape(Tensor tensor, | public Tensor reshape(Tensor tensor, | ||||
| object[] shape, | object[] shape, | ||||
| string name = null) | string name = null) | ||||
| => gen_array_ops.reshape(tensor, ops.convert_to_tensor(shape), name); | |||||
| => array_ops.reshape(tensor, shape, name); | |||||
| } | } | ||||
| } | } | ||||
| @@ -23,7 +23,7 @@ namespace Tensorflow | |||||
| => gen_array_ops.tile(input, multiples, name); | => gen_array_ops.tile(input, multiples, name); | ||||
| public Tensor tile(Tensor input, object[] multiples, string name = null) | public Tensor tile(Tensor input, object[] multiples, string name = null) | ||||
| => gen_array_ops.tile(input, ops.convert_to_tensor(multiples), name); | |||||
| => array_ops.tile(input, multiples, name); | |||||
| public Tensor tile(Tensor input, Shape multiples, string name = null) | public Tensor tile(Tensor input, Shape multiples, string name = null) | ||||
| { | { | ||||
| @@ -5,4 +5,5 @@ global using System.Collections; | |||||
| global using System.Data; | global using System.Data; | ||||
| global using System.Linq; | global using System.Linq; | ||||
| global using Tensorflow.Keras.Engine; | global using Tensorflow.Keras.Engine; | ||||
| global using Tensorflow.Framework.Models; | |||||
| global using Tensorflow.Framework.Models; | |||||
| global using static Tensorflow.Binding; | |||||
| @@ -30,21 +30,32 @@ public class KerasTensor | |||||
| public static KerasTensor from_tensor(Tensor tensor) | public static KerasTensor from_tensor(Tensor tensor) | ||||
| { | { | ||||
| var type_spec = tensor.ToTensorSpec(); | var type_spec = tensor.ToTensorSpec(); | ||||
| var kt = new KerasTensor(type_spec, name: tensor.name); | |||||
| Shape? inferred_value = default; | |||||
| if (tensor.dtype == TF_DataType.TF_INT32 && tensor.rank < 2) | |||||
| { | |||||
| inferred_value = tf.ones(tensor).shape; | |||||
| } | |||||
| var kt = new KerasTensor(type_spec, inferred_value: inferred_value, name: tensor.name); | |||||
| kt.original_tensors = tensor; | kt.original_tensors = tensor; | ||||
| return kt; | return kt; | ||||
| } | } | ||||
| public KerasTensor this[int idx] | |||||
| => _original_tensors.First()[idx]; | |||||
| public KerasTensor this[params Slice[] slices] | |||||
| => _original_tensors.First()[slices]; | |||||
| public override string ToString() | public override string ToString() | ||||
| => _original_tensors.Length switch | => _original_tensors.Length switch | ||||
| { | { | ||||
| > 1 => "[" + string.Join(", ", _original_tensors.Select(x => $"KerasTensor: shape={x.shape} dtype={x.dtype}")) + "]", | |||||
| 1 => $"KerasTensor: shape={_original_tensors.shape} {GetInferredValueString()} dtype={_original_tensors.dtype}", | |||||
| > 1 => "[" + string.Join(", ", _original_tensors.Select(x => $"KerasTensor: shape={x.shape} dtype={x.dtype.as_numpy_name()}{GetInferredValueString()}")) + "]", | |||||
| 1 => $"KerasTensor: shape={_original_tensors.shape} dtype={_original_tensors.dtype.as_numpy_name()}{GetInferredValueString()}", | |||||
| _ => _original_tensors.ToString(), | _ => _original_tensors.ToString(), | ||||
| }; | }; | ||||
| private string GetInferredValueString() | private string GetInferredValueString() | ||||
| => _inferred_value == null ? "" : ""; | |||||
| => _inferred_value == null ? "" : $" inferred_value={_inferred_value}"; | |||||
| public static implicit operator Tensors(KerasTensor kt) | public static implicit operator Tensors(KerasTensor kt) | ||||
| => kt._original_tensors; | => kt._original_tensors; | ||||
| @@ -137,7 +137,7 @@ namespace Tensorflow | |||||
| if(shape.Length > 1) | if(shape.Length > 1) | ||||
| { | { | ||||
| shapeTensor = ops.convert_to_tensor(shape, dtypes.int32); | shapeTensor = ops.convert_to_tensor(shape, dtypes.int32); | ||||
| if(shapeTensor.ndim > 1) | |||||
| if (shapeTensor.ndim > 1) | |||||
| { | { | ||||
| shapeTensor = array_ops.reshape(shapeTensor, new Shape(-1)); | shapeTensor = array_ops.reshape(shapeTensor, new Shape(-1)); | ||||
| } | } | ||||
| @@ -304,6 +304,10 @@ namespace Tensorflow | |||||
| { | { | ||||
| elems_as_tensors.Add(tensor); | elems_as_tensors.Add(tensor); | ||||
| } | } | ||||
| else if (elem is KerasTensor kt) | |||||
| { | |||||
| elems_as_tensors.Add(kt); | |||||
| } | |||||
| else | else | ||||
| { | { | ||||
| var elem_tensor = constant_op.constant(elem, dtype: dtype, name: i.ToString()); | var elem_tensor = constant_op.constant(elem, dtype: dtype, name: i.ToString()); | ||||
| @@ -404,7 +408,10 @@ namespace Tensorflow | |||||
| => gen_array_ops.reshape(tensor, shape, name: name); | => gen_array_ops.reshape(tensor, shape, name: name); | ||||
| public static Tensor reshape(Tensor tensor, object[] shape, string name = null) | public static Tensor reshape(Tensor tensor, object[] shape, string name = null) | ||||
| => gen_array_ops.reshape(tensor, ops.convert_to_tensor(shape), name: name); | |||||
| { | |||||
| var dims = shape_utils.from_object_array(shape); | |||||
| return gen_array_ops.reshape(tensor, dims, name: name); | |||||
| } | |||||
| private static Tensor ones_like_impl<T>(T tensor, TF_DataType dtype, string name, bool optimize = true) | private static Tensor ones_like_impl<T>(T tensor, TF_DataType dtype, string name, bool optimize = true) | ||||
| { | { | ||||
| @@ -425,6 +432,10 @@ namespace Tensorflow | |||||
| return tf_with(ops.name_scope(name, "ones", new { shape }), scope => | return tf_with(ops.name_scope(name, "ones", new { shape }), scope => | ||||
| { | { | ||||
| name = scope; | name = scope; | ||||
| if (shape._shape_tuple().Length == 0) | |||||
| { | |||||
| shape = reshape(shape, new Shape(-1)); | |||||
| } | |||||
| var output = gen_array_ops.fill(shape, constant_op.constant(1.0f, dtype: dtype), name: name); | var output = gen_array_ops.fill(shape, constant_op.constant(1.0f, dtype: dtype), name: name); | ||||
| return output; | return output; | ||||
| }); | }); | ||||
| @@ -647,6 +658,20 @@ namespace Tensorflow | |||||
| } | } | ||||
| }); | }); | ||||
| public static Tensor tile(Tensor input, object[] multiples, string name = null) | |||||
| { | |||||
| Shape dims = shape_utils.from_object_array(multiples); | |||||
| return tf.Context.ExecuteOp("Tile", name, new ExecuteOpArgs(input, dims) | |||||
| { | |||||
| GetGradientAttrs = (op) => new | |||||
| { | |||||
| T = op.get_attr<TF_DataType>("T"), | |||||
| Tmultiples = op.get_attr<TF_DataType>("Tmultiples") | |||||
| } | |||||
| }); | |||||
| } | |||||
| public static Tensor zeros_like(Tensor tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true) | public static Tensor zeros_like(Tensor tensor, TF_DataType dtype = TF_DataType.DtInvalid, string name = null, bool optimize = true) | ||||
| { | { | ||||
| return tf_with(ops.name_scope(name, "zeros_like", new Tensor[] { tensor }), scope => | return tf_with(ops.name_scope(name, "zeros_like", new Tensor[] { tensor }), scope => | ||||
| @@ -1,5 +1,6 @@ | |||||
| using System; | using System; | ||||
| using System.Linq; | using System.Linq; | ||||
| using Tensorflow.Eager; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace Tensorflow | namespace Tensorflow | ||||
| @@ -13,5 +14,31 @@ namespace Tensorflow | |||||
| throw new NotImplementedException(""); | throw new NotImplementedException(""); | ||||
| } | } | ||||
| public static Shape from_object_array(object[] shape) | |||||
| { | |||||
| var dims = shape.Select(x => | |||||
| { | |||||
| if (x is KerasTensor kt && kt.inferred_value != null) | |||||
| { | |||||
| return kt.inferred_value.as_int_list()[0]; | |||||
| } | |||||
| else if (x is EagerTensor et && et.dtype == TF_DataType.TF_INT32) | |||||
| { | |||||
| return et.ToArray<int>()[0]; | |||||
| } | |||||
| else if (x is int i) | |||||
| { | |||||
| return i; | |||||
| } | |||||
| else if (x is long l) | |||||
| { | |||||
| return l; | |||||
| } | |||||
| throw new NotImplementedException(); | |||||
| }).ToArray(); | |||||
| return new Shape(dims); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -46,6 +46,9 @@ namespace Tensorflow | |||||
| public Tensor ones(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) | public Tensor ones(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) | ||||
| => array_ops.ones(shape, dtype, name); | => array_ops.ones(shape, dtype, name); | ||||
| public Tensor ones(Tensor shape, TF_DataType dtype = TF_DataType.TF_FLOAT, string name = null) | |||||
| => array_ops.ones(shape, dtype, name); | |||||
| public Tensor size(Tensor input, | public Tensor size(Tensor input, | ||||
| string name = null, | string name = null, | ||||
| TF_DataType out_type = TF_DataType.TF_INT32) => array_ops.size(input, | TF_DataType out_type = TF_DataType.TF_INT32) => array_ops.size(input, | ||||
| @@ -144,11 +144,18 @@ namespace Tensorflow | |||||
| } | } | ||||
| if (!graph.building_function) | if (!graph.building_function) | ||||
| { | { | ||||
| throw new RuntimeError("Attempting to capture an EagerTensor without building a function."); | |||||
| // return eager_tensor.AsPlaceholder(name: name); | |||||
| // throw new RuntimeError("Attempting to capture an EagerTensor without building a function."); | |||||
| return eager_tensor.AsPlaceholder(name: name); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| else if (value is KerasTensor kt) | |||||
| { | |||||
| if (kt.inferred_value != null) | |||||
| { | |||||
| return convert_to_tensor(kt.inferred_value, dtype: kt.dtype, name: name); | |||||
| } | |||||
| } | |||||
| // graph mode | // graph mode | ||||
| Tensor ret = value switch | Tensor ret = value switch | ||||
| @@ -141,7 +141,7 @@ Keras is an API designed for human beings, not machines. Keras follows best prac | |||||
| <ItemGroup> | <ItemGroup> | ||||
| <PackageReference Include="HDF5-CSharp" Version="1.17.0" /> | <PackageReference Include="HDF5-CSharp" Version="1.17.0" /> | ||||
| <PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.148" /> | |||||
| <PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.149" /> | |||||
| <PackageReference Include="SharpZipLib" Version="1.4.2" /> | <PackageReference Include="SharpZipLib" Version="1.4.2" /> | ||||
| </ItemGroup> | </ItemGroup> | ||||
| @@ -41,8 +41,8 @@ | |||||
| <ItemGroup> | <ItemGroup> | ||||
| <PackageReference Include="FluentAssertions" Version="5.10.3" /> | <PackageReference Include="FluentAssertions" Version="5.10.3" /> | ||||
| <PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.148" /> | |||||
| <PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.3.2" /> | |||||
| <PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.149" /> | |||||
| <PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.6.3" /> | |||||
| <PackageReference Include="MSTest.TestAdapter" Version="2.2.10" /> | <PackageReference Include="MSTest.TestAdapter" Version="2.2.10" /> | ||||
| <PackageReference Include="MSTest.TestFramework" Version="2.2.10" /> | <PackageReference Include="MSTest.TestFramework" Version="2.2.10" /> | ||||
| </ItemGroup> | </ItemGroup> | ||||