Browse Source

Hello World works.

tags/v0.12
Oceania2018 6 years ago
parent
commit
4ca080e565
6 changed files with 18 additions and 9 deletions
  1. +1
    -1
      src/TensorFlowNET.Core/Sessions/BaseSession.cs
  2. +0
    -1
      src/TensorFlowNET.Core/TensorFlowNET.Core.csproj
  3. +5
    -4
      src/TensorFlowNET.Core/Tensors/dtypes.cs
  4. +1
    -1
      src/TensorFlowNET.Core/Tensors/tensor_util.cs
  5. +8
    -0
      src/TensorFlowNET.Core/Tensors/tf.constant.cs
  6. +3
    -2
      test/TensorFlowNET.Examples/HelloWorld.cs

+ 1
- 1
src/TensorFlowNET.Core/Sessions/BaseSession.cs View File

@@ -302,7 +302,7 @@ namespace Tensorflow
// wired, don't know why we have to start from offset 9.
// length in the begin
var str = UTF8Encoding.Default.GetString(bytes, 9, bytes[8]);
nd = np.array(str).reshape();
nd = np.array(str);
break;
case TF_DataType.TF_UINT8:
var _bytes = new byte[tensor.size];


+ 0
- 1
src/TensorFlowNET.Core/TensorFlowNET.Core.csproj View File

@@ -63,7 +63,6 @@ Docs: https://tensorflownet.readthedocs.io</Description>

<ItemGroup>
<PackageReference Include="Google.Protobuf" Version="3.9.0" />
<PackageReference Include="System.Runtime.CompilerServices.Unsafe" Version="4.5.2" />
</ItemGroup>

<ItemGroup>


+ 5
- 4
src/TensorFlowNET.Core/Tensors/dtypes.cs View File

@@ -51,12 +51,13 @@ namespace Tensorflow
}

// "sbyte", "byte", "short", "ushort", "int", "uint", "long", "ulong", "float", "double", "Complex"
public static TF_DataType as_dtype(Type type)
public static TF_DataType as_dtype(Type type, TF_DataType? dtype = null)
{
TF_DataType dtype = TF_DataType.DtInvalid;
switch (type.Name)
{
case "Char":
dtype = dtype ?? TF_DataType.TF_UINT8;
break;
case "SByte":
dtype = TF_DataType.TF_INT8;
break;
@@ -100,7 +101,7 @@ namespace Tensorflow
throw new Exception("as_dtype Not Implemented");
}

return dtype;
return dtype.Value;
}

public static DataType as_datatype_enum(this TF_DataType type)


+ 1
- 1
src/TensorFlowNET.Core/Tensors/tensor_util.cs View File

@@ -226,7 +226,7 @@ namespace Tensorflow
}
}

var numpy_dtype = dtypes.as_dtype(nparray.dtype);
var numpy_dtype = dtypes.as_dtype(nparray.dtype, dtype: dtype);
if (numpy_dtype == TF_DataType.DtInvalid)
throw new TypeError($"Unrecognized data type: {nparray.dtype}");



+ 8
- 0
src/TensorFlowNET.Core/Tensors/tf.constant.cs View File

@@ -33,6 +33,14 @@ namespace Tensorflow
verify_shape: verify_shape,
allow_broadcast: false);

public static Tensor constant(string value,
string name = "Const") => constant_op._constant_impl(value,
tf.@string,
new int[] { 1 },
name,
verify_shape: false,
allow_broadcast: false);

public static Tensor constant(float value,
int shape,
string name = "Const") => constant_op._constant_impl(value,


+ 3
- 2
test/TensorFlowNET.Examples/HelloWorld.cs View File

@@ -29,8 +29,9 @@ namespace TensorFlowNET.Examples
{
// Run the op
var result = sess.run(hello);
Console.WriteLine(result.ToString());
return result.ToString().Equals(str);
string result_string = string.Join("", result.GetData<char>());
Console.WriteLine(result_string);
return result_string.Equals(str);
});
}



Loading…
Cancel
Save