Browse Source

NewOperation

tags/v0.1.0-Tensor
Oceania2018 6 years ago
parent
commit
458b3965ba
5 changed files with 52 additions and 18 deletions
  1. +0
    -0
      src/TensorFlowNET.Core/APIs/c_api.cs
  2. +6
    -1
      src/TensorFlowNET.Core/APIs/tf.constant.cs
  3. +38
    -0
      src/TensorFlowNET.Core/Graphs/Graph.Operation.cs
  4. +0
    -15
      src/TensorFlowNET.Core/Graphs/Graph.cs
  5. +8
    -2
      src/TensorFlowNET.Core/Tensors/Tensor.cs

src/TensorFlowNET.Core/c_api.cs → src/TensorFlowNET.Core/APIs/c_api.cs View File


src/TensorFlowNET.Core/Tensors/tf.constant.cs → src/TensorFlowNET.Core/APIs/tf.constant.cs View File

@@ -9,7 +9,12 @@ namespace Tensorflow
{
public static Tensor constant(NDArray nd, string name = "Const", bool verify_shape = false)
{
return constant_op.Create(nd, name, verify_shape);
//constant_op.Create(nd, name, verify_shape);
var graph = tf.get_default_graph();
var tensor = new Tensor(nd);
var op = graph.NewOperation("Const", name, tensor);
return null;
}
}
}

+ 38
- 0
src/TensorFlowNET.Core/Graphs/Graph.Operation.cs View File

@@ -0,0 +1,38 @@
using NumSharp.Core;
using System;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using System.Text;

namespace Tensorflow
{
public partial class Graph
{
public Operation NewOperation(string opType, string opName, Tensor tensor)
{
var desc = c_api.TF_NewOperation(_handle, opType, opName);

if (tensor.dtype == TF_DataType.TF_STRING)
{
var value = "Hello World!";
var bytes = Encoding.UTF8.GetBytes(value);
var buf = Marshal.AllocHGlobal(bytes.Length + 1);
Marshal.Copy(bytes, 0, buf, bytes.Length);
c_api.TF_SetAttrString(desc, "value", buf, (uint)value.Length);
}
else
{
c_api.TF_SetAttrTensor(desc, "value", tensor, Status);
}
Status.Check();

c_api.TF_SetAttrType(desc, "dtype", tensor.dtype);

var op = c_api.TF_FinishOperation(desc, Status);
Status.Check();

return op;
}
}
}

+ 0
- 15
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -41,21 +41,6 @@ namespace Tensorflow
_names_in_use = new Dictionary<string, int>();
}

public Operation NewOperation(string opType, string opName, Tensor t)
{
var desc = c_api.TF_NewOperation(_handle, opType, opName);
c_api.TF_SetAttrTensor(desc, "value", t, Status);
Status.Check();

c_api.TF_SetAttrType(desc, "dtype", t.dtype);

var op = c_api.TF_FinishOperation(desc, Status);
Status.Check();

return op;
}

public T as_graph_element<T>(T obj, bool allow_tensor = true, bool allow_operation = true)
{
return _as_graph_element_locked(obj, allow_tensor, allow_operation);


+ 8
- 2
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -83,8 +83,14 @@ namespace Tensorflow
Marshal.Copy(nd.Data<double>(), 0, dotHandle, nd.size);
break;
case "String":
dotHandle = Marshal.StringToHGlobalAuto(nd.Data<string>()[0]);
size = (ulong)nd.Data<string>()[0].Length;
var value = nd.Data<string>()[0];
var bytes = Encoding.UTF8.GetBytes(value);
var buf = Marshal.AllocHGlobal(bytes.Length + 1);
Marshal.Copy(bytes, 0, buf, bytes.Length);

//c_api.TF_SetAttrString(op, "value", buf, (uint)bytes.Length);

size = (ulong)bytes.Length;
break;
default:
throw new NotImplementedException("Marshal.Copy failed.");


Loading…
Cancel
Save