| @@ -0,0 +1,14 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| namespace Tensorflow | |||||
| { | |||||
| public static partial class tf | |||||
| { | |||||
| public static unsafe Tensor matmul(Tensor a, Tensor b) | |||||
| { | |||||
| return gen_math_ops.mat_mul(a, b); | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -78,6 +78,9 @@ namespace Tensorflow | |||||
| case "type": | case "type": | ||||
| attr_value.Type = _MakeType((TF_DataType)value, attr_def); | attr_value.Type = _MakeType((TF_DataType)value, attr_def); | ||||
| break; | break; | ||||
| case "bool": | |||||
| attr_value.B = (bool)value; | |||||
| break; | |||||
| case "shape": | case "shape": | ||||
| attr_value.Shape = new TensorShapeProto(); | attr_value.Shape = new TensorShapeProto(); | ||||
| break; | break; | ||||
| @@ -30,5 +30,18 @@ namespace Tensorflow | |||||
| return new Tensor(_op, 0, _op.OutputType(0)); | return new Tensor(_op, 0, _op.OutputType(0)); | ||||
| } | } | ||||
| public static Tensor mat_mul(Tensor a, Tensor b, bool transpose_a = false, bool transpose_b = false) | |||||
| { | |||||
| var keywords = new Dictionary<string, object>(); | |||||
| keywords.Add("a", a); | |||||
| keywords.Add("b", b); | |||||
| keywords.Add("transpose_a", transpose_a); | |||||
| keywords.Add("transpose_b", transpose_b); | |||||
| var _op = _op_def_lib._apply_op_helper("MatMul", name: "MatMul", keywords: keywords); | |||||
| return new Tensor(_op, 0, _op.OutputType(0)); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -1,6 +1,7 @@ | |||||
| using Microsoft.VisualStudio.TestTools.UnitTesting; | using Microsoft.VisualStudio.TestTools.UnitTesting; | ||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Runtime.InteropServices; | |||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow; | using Tensorflow; | ||||
| @@ -63,8 +64,9 @@ namespace TensorFlowNET.UnitTest | |||||
| [TestMethod] | [TestMethod] | ||||
| public void String() | public void String() | ||||
| { | { | ||||
| //var desc = init("string"); | |||||
| //c_api.TF_SetAttrString(desc, "v", "bunny", 5); | |||||
| var desc = init("string"); | |||||
| var handle = Marshal.StringToHGlobalAnsi("bunny"); | |||||
| c_api.TF_SetAttrString(desc, "v", handle, 5); | |||||
| //var oper = c_api.TF_FinishOperation(desc, s_); | //var oper = c_api.TF_FinishOperation(desc, s_); | ||||
| //ASSERT_EQ(TF_Code.TF_OK, s_.Code); | //ASSERT_EQ(TF_Code.TF_OK, s_.Code); | ||||
| @@ -23,6 +23,7 @@ | |||||
| </ItemGroup> | </ItemGroup> | ||||
| <ItemGroup> | <ItemGroup> | ||||
| <ProjectReference Include="..\..\..\NumSharp\src\NumSharp.Core\NumSharp.Core.csproj" /> | |||||
| <ProjectReference Include="..\..\src\TensorFlowNET.Core\TensorFlowNET.Core.csproj" /> | <ProjectReference Include="..\..\src\TensorFlowNET.Core\TensorFlowNET.Core.csproj" /> | ||||
| </ItemGroup> | </ItemGroup> | ||||