|
|
@@ -47,16 +47,23 @@ namespace Tensorflow |
|
|
var tensor_dtype = tensor.Dtype.as_numpy_dtype(); |
|
|
var tensor_dtype = tensor.Dtype.as_numpy_dtype(); |
|
|
|
|
|
|
|
|
if (tensor.TensorContent.Length > 0) |
|
|
if (tensor.TensorContent.Length > 0) |
|
|
return np.frombuffer(tensor.TensorContent.ToByteArray(), tensor_dtype) |
|
|
|
|
|
.reshape(shape); |
|
|
|
|
|
|
|
|
{ |
|
|
|
|
|
return np.frombuffer(tensor.TensorContent.ToByteArray(), tensor_dtype).reshape(shape); |
|
|
|
|
|
} |
|
|
else if (tensor.Dtype == DataType.DtHalf || tensor.Dtype == DataType.DtBfloat16) |
|
|
else if (tensor.Dtype == DataType.DtHalf || tensor.Dtype == DataType.DtBfloat16) |
|
|
; |
|
|
; |
|
|
else if (tensor.Dtype == DataType.DtFloat) |
|
|
else if (tensor.Dtype == DataType.DtFloat) |
|
|
; |
|
|
; |
|
|
else if (new DataType[] { DataType.DtInt32, DataType.DtUint8 }.Contains(tensor.Dtype)) |
|
|
else if (new DataType[] { DataType.DtInt32, DataType.DtUint8 }.Contains(tensor.Dtype)) |
|
|
|
|
|
{ |
|
|
if (tensor.IntVal.Count == 1) |
|
|
if (tensor.IntVal.Count == 1) |
|
|
return np.repeat(np.array(tensor.IntVal[0]), Convert.ToInt32(num_elements)) |
|
|
|
|
|
.reshape(shape); |
|
|
|
|
|
|
|
|
return np.repeat(np.array(tensor.IntVal[0]), num_elements).reshape(shape); |
|
|
|
|
|
} |
|
|
|
|
|
else if (tensor.Dtype == DataType.DtBool) |
|
|
|
|
|
{ |
|
|
|
|
|
if (tensor.BoolVal.Count == 1) |
|
|
|
|
|
return np.repeat(np.array(tensor.BoolVal[0]), num_elements).reshape(shape); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
throw new NotImplementedException("MakeNdarray"); |
|
|
throw new NotImplementedException("MakeNdarray"); |
|
|
} |
|
|
} |
|
|
|