| @@ -24,7 +24,7 @@ namespace Tensorflow | |||||
| public partial class Operation | public partial class Operation | ||||
| { | { | ||||
| public int NumOutputs => c_api.TF_OperationNumOutputs(_handle); | public int NumOutputs => c_api.TF_OperationNumOutputs(_handle); | ||||
| public TF_DataType OutputType(int index) => c_api.TF_OperationOutputType(new TF_Output(_handle, index)); | |||||
| public TF_DataType OutputType(int index) => c_api.TF_OperationOutputType(_tf_output(index)); | |||||
| public int OutputListLength(string name) | public int OutputListLength(string name) | ||||
| { | { | ||||
| @@ -44,7 +44,6 @@ namespace Tensorflow | |||||
| public partial class Operation : ITensorOrOperation | public partial class Operation : ITensorOrOperation | ||||
| { | { | ||||
| private readonly IntPtr _handle; // _c_op in python | private readonly IntPtr _handle; // _c_op in python | ||||
| private readonly IntPtr _operDesc; | |||||
| private readonly Graph _graph; | private readonly Graph _graph; | ||||
| private NodeDef _node_def; | private NodeDef _node_def; | ||||
| @@ -91,7 +90,7 @@ namespace Tensorflow | |||||
| { | { | ||||
| _graph = g; | _graph = g; | ||||
| _operDesc = c_api.TF_NewOperation(g, opType, oper_name); | |||||
| var _operDesc = c_api.TF_NewOperation(g, opType, oper_name); | |||||
| c_api.TF_SetAttrType(_operDesc, "dtype", TF_DataType.TF_INT32); | c_api.TF_SetAttrType(_operDesc, "dtype", TF_DataType.TF_INT32); | ||||
| lock (Locks.ProcessWide) | lock (Locks.ProcessWide) | ||||
| using (var status = new Status()) | using (var status = new Status()) | ||||
| @@ -161,7 +160,7 @@ namespace Tensorflow | |||||
| op_def = g.GetOpDef(node_def.Op); | op_def = g.GetOpDef(node_def.Op); | ||||
| var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr); | var grouped_inputs = _reconstruct_sequence_inputs(op_def, inputs, node_def.Attr); | ||||
| (_handle, _operDesc) = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray()); | |||||
| _handle = ops._create_c_op(g, node_def, grouped_inputs, control_input_ops.ToArray()); | |||||
| // Initialize self._outputs. | // Initialize self._outputs. | ||||
| output_types = new TF_DataType[NumOutputs]; | output_types = new TF_DataType[NumOutputs]; | ||||
| @@ -170,7 +169,7 @@ namespace Tensorflow | |||||
| _outputs = new Tensor[NumOutputs]; | _outputs = new Tensor[NumOutputs]; | ||||
| for (int i = 0; i < NumOutputs; i++) | for (int i = 0; i < NumOutputs; i++) | ||||
| _outputs[i] = new Tensor(this, i, OutputType(i)); | |||||
| _outputs[i] = new Tensor(this, i, output_types[i]); | |||||
| graph._add_op(this); | graph._add_op(this); | ||||
| @@ -275,7 +275,7 @@ namespace Tensorflow | |||||
| /// </returns> | /// </returns> | ||||
| public static Tensor[] _SwitchRefOrTensor(Tensor data, Tensor pred, string name = "Switch") | public static Tensor[] _SwitchRefOrTensor(Tensor data, Tensor pred, string name = "Switch") | ||||
| { | { | ||||
| data = ops.convert_to_tensor_or_indexed_slices(data, name: "data"); | |||||
| data = ops.convert_to_tensor_or_composite(data, name: "data"); | |||||
| // NOTE(vrv): ops.colocate_with(data, ignore_existing=True) below | // NOTE(vrv): ops.colocate_with(data, ignore_existing=True) below | ||||
| // addresses the following scenario. | // addresses the following scenario. | ||||
| // | // | ||||
| @@ -296,9 +296,8 @@ namespace Tensorflow | |||||
| { | { | ||||
| if (data is Tensor) | if (data is Tensor) | ||||
| { | { | ||||
| // TODO: ref_switch | |||||
| //if (data.dtype._is_ref_dtype) | |||||
| // return control_flow_ops.ref_switch(data, pred, name = name); | |||||
| if (data.dtype.is_ref_dtype()) | |||||
| return gen_control_flow_ops.ref_switch(data, pred, name: name); | |||||
| } | } | ||||
| return @switch(data, pred, name: name); | return @switch(data, pred, name: name); | ||||
| } | } | ||||
| @@ -114,6 +114,12 @@ namespace Tensorflow | |||||
| return _op; | return _op; | ||||
| } | } | ||||
| public static Tensor[] ref_switch(Tensor data, Tensor pred, string name = null) | |||||
| { | |||||
| var _op = _op_def_lib._apply_op_helper("RefSwitch", name, new { data, pred }); | |||||
| return _op.outputs; | |||||
| } | |||||
| /// <summary> | /// <summary> | ||||
| /// Forwards `data` to the output port determined by `pred`. | /// Forwards `data` to the output port determined by `pred`. | ||||
| /// | /// | ||||
| @@ -5,7 +5,7 @@ | |||||
| <AssemblyName>TensorFlow.NET</AssemblyName> | <AssemblyName>TensorFlow.NET</AssemblyName> | ||||
| <RootNamespace>Tensorflow</RootNamespace> | <RootNamespace>Tensorflow</RootNamespace> | ||||
| <TargetTensorFlow>1.14.0</TargetTensorFlow> | <TargetTensorFlow>1.14.0</TargetTensorFlow> | ||||
| <Version>0.11.2</Version> | |||||
| <Version>0.11.3</Version> | |||||
| <Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors> | <Authors>Haiping Chen, Meinrad Recheis, Eli Belash</Authors> | ||||
| <Company>SciSharp STACK</Company> | <Company>SciSharp STACK</Company> | ||||
| <GeneratePackageOnBuild>true</GeneratePackageOnBuild> | <GeneratePackageOnBuild>true</GeneratePackageOnBuild> | ||||
| @@ -17,7 +17,7 @@ | |||||
| <PackageTags>TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET, C#</PackageTags> | <PackageTags>TensorFlow, NumSharp, SciSharp, MachineLearning, TensorFlow.NET, C#</PackageTags> | ||||
| <Description>Google's TensorFlow full binding in .NET Standard. | <Description>Google's TensorFlow full binding in .NET Standard. | ||||
| Docs: https://tensorflownet.readthedocs.io</Description> | Docs: https://tensorflownet.readthedocs.io</Description> | ||||
| <AssemblyVersion>0.11.2.0</AssemblyVersion> | |||||
| <AssemblyVersion>0.11.3.0</AssemblyVersion> | |||||
| <PackageReleaseNotes>Changes since v0.10.0: | <PackageReleaseNotes>Changes since v0.10.0: | ||||
| 1. Upgrade NumSharp to v0.20. | 1. Upgrade NumSharp to v0.20. | ||||
| 2. Add DisposableObject class to manage object lifetime. | 2. Add DisposableObject class to manage object lifetime. | ||||
| @@ -30,7 +30,7 @@ Docs: https://tensorflownet.readthedocs.io</Description> | |||||
| 9. MultiThread is safe. | 9. MultiThread is safe. | ||||
| 10. Support n-dim indexing for tensor.</PackageReleaseNotes> | 10. Support n-dim indexing for tensor.</PackageReleaseNotes> | ||||
| <LangVersion>7.3</LangVersion> | <LangVersion>7.3</LangVersion> | ||||
| <FileVersion>0.11.2.0</FileVersion> | |||||
| <FileVersion>0.11.3.0</FileVersion> | |||||
| <PackageLicenseFile>LICENSE</PackageLicenseFile> | <PackageLicenseFile>LICENSE</PackageLicenseFile> | ||||
| <PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance> | <PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance> | ||||
| <SignAssembly>true</SignAssembly> | <SignAssembly>true</SignAssembly> | ||||
| @@ -117,7 +117,7 @@ namespace Tensorflow | |||||
| /// Key to collect Variable objects that are global (shared across machines). | /// Key to collect Variable objects that are global (shared across machines). | ||||
| /// Default collection for all variables, except local ones. | /// Default collection for all variables, except local ones. | ||||
| /// </summary> | /// </summary> | ||||
| public string GLOBAL_VARIABLES = GLOBAL_VARIABLES_; | |||||
| public string GLOBAL_VARIABLES => GLOBAL_VARIABLES_; | |||||
| public string TRAIN_OP => TRAIN_OP_; | public string TRAIN_OP => TRAIN_OP_; | ||||
| @@ -206,7 +206,7 @@ namespace Tensorflow | |||||
| /// </param> | /// </param> | ||||
| /// <param name="control_inputs">A list of `Operation`s to set as control dependencies.</param> | /// <param name="control_inputs">A list of `Operation`s to set as control dependencies.</param> | ||||
| /// <returns>A wrapped TF_Operation*.</returns> | /// <returns>A wrapped TF_Operation*.</returns> | ||||
| public static (IntPtr, IntPtr) _create_c_op<T>(Graph graph, NodeDef node_def, T[] inputs, Operation[] control_inputs) | |||||
| public static IntPtr _create_c_op<T>(Graph graph, NodeDef node_def, T[] inputs, Operation[] control_inputs) | |||||
| { | { | ||||
| lock (Locks.ProcessWide) | lock (Locks.ProcessWide) | ||||
| { | { | ||||
| @@ -249,7 +249,7 @@ namespace Tensorflow | |||||
| status.Check(true); | status.Check(true); | ||||
| return (c_op, op_desc); | |||||
| return c_op; | |||||
| } | } | ||||
| } | } | ||||
| @@ -27,7 +27,7 @@ namespace TensorFlowNET.UnitTest.ops_test | |||||
| using (var g = tf.Graph().as_default()) | using (var g = tf.Graph().as_default()) | ||||
| { | { | ||||
| var x = constant_op.constant(new[,] {{1, 2, 3}, {4, 5, 6}}); | var x = constant_op.constant(new[,] {{1, 2, 3}, {4, 5, 6}}); | ||||
| var (c_op, op_desc) = ops._create_c_op(g, ops._NodeDef("Identity", "myop"), new[] {x}, new Operation[0]); | |||||
| var c_op = ops._create_c_op(g, ops._NodeDef("Identity", "myop"), new[] {x}, new Operation[0]); | |||||
| var op = g._create_op_from_tf_operation(c_op); | var op = g._create_op_from_tf_operation(c_op); | ||||
| Assert.AreEqual("myop", op.name); | Assert.AreEqual("myop", op.name); | ||||
| @@ -68,7 +68,7 @@ namespace TensorFlowNET.UnitTest.ops_test | |||||
| var true_fn = new Func<Tensor>(() => | var true_fn = new Func<Tensor>(() => | ||||
| { | { | ||||
| var (c_op, op_desc) = ops._create_c_op(g, ops._NodeDef("Identity", "cond/myop"), new[] { x }, new Operation[0]); | |||||
| var c_op = ops._create_c_op(g, ops._NodeDef("Identity", "cond/myop"), new[] { x }, new Operation[0]); | |||||
| var new_ops = g._add_new_tf_operations(); | var new_ops = g._add_new_tf_operations(); | ||||
| self.assertEqual(len(new_ops), 1); | self.assertEqual(len(new_ops), 1); | ||||
| return x; | return x; | ||||