Browse Source

Add initial tests + fix bugs. Still WIP since the test is failing.

tags/v0.5.1
Mihai 2 years ago
parent
commit
0bd495276b
4 changed files with 161 additions and 38 deletions
  1. +118
    -0
      LLama.Unittest/GrammarParserTest.cs
  2. +39
    -36
      LLama/Grammar/GrammarParser.cs
  3. +1
    -1
      LLama/Grammar/ParseState.cs
  4. +3
    -1
      LLama/Native/LLamaGrammarElement.cs

+ 118
- 0
LLama.Unittest/GrammarParserTest.cs View File

@@ -0,0 +1,118 @@
using LLama.Common;
using LLama.Native;
using System.Diagnostics;
using LLama.Grammar;
using Newtonsoft.Json.Linq;

namespace LLama.Unittest
{
/// <summary>
/// Unit tests for <see cref="GrammarParser"/>.
/// </summary>
public sealed class GrammarParserTest
{
    /// <summary>
    /// Parses a small arithmetic grammar and verifies both the generated symbol table
    /// (<see cref="ParseState.SymbolIds"/>) and the flattened rule elements
    /// (<see cref="ParseState.Rules"/>) against hand-written expectations.
    /// </summary>
    [Fact]
    public void ParseComplexGrammar()
    {
        GrammarParser parsedGrammar = new GrammarParser();

        // Verbatim string: `""` is an escaped quote and `\n` reaches the parser as the two
        // characters backslash + 'n', i.e. as a grammar-level escape sequence.
        // The continuation lines intentionally start at column 0.
        string grammarBytes = @"root ::= (expr ""="" term ""\n"")+
expr ::= term ([-+*/] term)*
term ::= [0-9]+";

        ParseState state = parsedGrammar.Parse(grammarBytes);

        // Expected symbol table, listed in ordinal order of the symbol name.
        var expected = new List<KeyValuePair<string, uint>>
        {
            new KeyValuePair<string, uint>("expr", 2),
            new KeyValuePair<string, uint>("expr_5", 5),
            new KeyValuePair<string, uint>("expr_6", 6),
            new KeyValuePair<string, uint>("root", 0),
            new KeyValuePair<string, uint>("root_1", 1),
            new KeyValuePair<string, uint>("root_4", 4),
            new KeyValuePair<string, uint>("term", 3),
            new KeyValuePair<string, uint>("term_7", 7),
        };

        // SymbolIds is a Dictionary, whose enumeration order is unspecified. Sort it by key
        // (ordinal) so the positional comparison against the sorted expectation is well-defined.
        var actualSymbols = new List<KeyValuePair<string, uint>>(
            new SortedDictionary<string, uint>(state.SymbolIds, StringComparer.Ordinal));

        int symbolsToCompare = Math.Min(expected.Count, actualSymbols.Count);
        for (int i = 0; i < symbolsToCompare; i++)
        {
            var expectedPair = expected[i];
            var actualPair = actualSymbols[i];

            // Pretty print error message before asserting
            if (expectedPair.Key != actualPair.Key || expectedPair.Value != actualPair.Value)
            {
                Console.Error.WriteLine($"expectedPair: {expectedPair.Key}, {expectedPair.Value}");
                Console.Error.WriteLine($"actualPair: {actualPair.Key}, {actualPair.Value}");
                Console.Error.WriteLine("expectedPair != actualPair");
            }
            Assert.Equal(expectedPair.Key, actualPair.Key);
            Assert.Equal(expectedPair.Value, actualPair.Value);
        }
        // Catches both missing and surplus symbols (also guarantees the table is not empty).
        Assert.Equal(expected.Count, actualSymbols.Count);

        // Expected rule elements, concatenated in rule-id order. Each rule is terminated by END;
        // the rule names in the comments follow from the ids in the symbol table above.
        var expectedRules = new List<LLamaGrammarElement>
        {
            // rule 0: root
            new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 4),
            new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
            // rule 1: root_1
            new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 2),
            new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 61),
            new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 3),
            new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 10),
            new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
            // rule 2: expr
            new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 3),
            new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 6),
            new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
            // rule 3: term
            new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 7),
            new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
            // rule 4: root_4 (two alternatives, so they are separated by ALT, not END;
            // an END here would yield nine rules for only eight symbols)
            new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 1),
            new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 4),
            new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0),
            new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 1),
            new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
            // rule 5: expr_5
            new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 45),
            new LLamaGrammarElement(LLamaGrammarElementType.CHAR_ALT, 43),
            new LLamaGrammarElement(LLamaGrammarElementType.CHAR_ALT, 42),
            new LLamaGrammarElement(LLamaGrammarElementType.CHAR_ALT, 47),
            new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 3),
            new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
            // rule 6: expr_6
            new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 5),
            new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 6),
            new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0),
            new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
            // rule 7: term_7
            new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 48),
            new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, 57),
            new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 7),
            new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0),
            new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 48),
            new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, 57),
            new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
        };

        // Flatten the parsed rules in rule-id order so they can be compared positionally.
        var actualElements = new List<LLamaGrammarElement>();
        foreach (var rule in state.Rules)
        {
            actualElements.AddRange(rule);
        }

        int elementsToCompare = Math.Min(expectedRules.Count, actualElements.Count);
        for (int index = 0; index < elementsToCompare; index++)
        {
            var expectedElement = expectedRules[index];
            var element = actualElements[index];

            // Pretty print error message before asserting
            if (expectedElement.Type != element.Type || expectedElement.Value != element.Value)
            {
                Console.Error.WriteLine($"index: {index}");
                Console.Error.WriteLine($"expected_element: {expectedElement.Type}, {expectedElement.Value}");
                Console.Error.WriteLine($"actual_element: {element.Type}, {element.Value}");
                Console.Error.WriteLine("expected_element != actual_element");
            }
            Assert.Equal(expectedElement.Type, element.Type);
            Assert.Equal(expectedElement.Value, element.Value);
        }
        // Catches both missing and surplus elements (also guarantees the rules are not empty).
        Assert.Equal(expectedRules.Count, actualElements.Count);
    }
}
}

+ 39
- 36
LLama/Grammar/GrammarParser.cs View File

@@ -12,11 +12,11 @@ namespace LLama.Grammar
///
/// The commit hash from URL is the actual commit hash that reflects current C# code.
/// </summary>
internal class GrammarParser
public class GrammarParser
{
// NOTE: assumes valid utf8 (but checks for overrun)
// copied from llama.cpp
public uint DecodeUTF8(ref ReadOnlySpan<byte> src)
private uint DecodeUTF8(ref ReadOnlySpan<byte> src)
{
int[] lookup = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 };

@@ -39,7 +39,7 @@ namespace LLama.Grammar
return value;
}

public uint GetSymbolId(ParseState state, ReadOnlySpan<byte> src, int len)
private uint GetSymbolId(ParseState state, ReadOnlySpan<byte> src, int len)
{
uint nextId = (uint)state.SymbolIds.Count;
string key = Encoding.UTF8.GetString(src.Slice(0, len).ToArray());
@@ -55,7 +55,7 @@ namespace LLama.Grammar
}
}

public uint GenerateSymbolId(ParseState state, string baseName)
private uint GenerateSymbolId(ParseState state, string baseName)
{
uint nextId = (uint)state.SymbolIds.Count;
string key = $"{baseName}_{nextId}";
@@ -63,7 +63,7 @@ namespace LLama.Grammar
return nextId;
}

public void AddRule(ParseState state, uint ruleId, List<LLamaGrammarElement> rule)
private void AddRule(ParseState state, uint ruleId, List<LLamaGrammarElement> rule)
{
while (state.Rules.Count <= ruleId)
{
@@ -73,12 +73,12 @@ namespace LLama.Grammar
state.Rules[(int)ruleId] = rule;
}

public bool IsWordChar(byte c)
private bool IsWordChar(byte c)
{
return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9');
}

public uint ParseHex(ref ReadOnlySpan<byte> src, int size)
private uint ParseHex(ref ReadOnlySpan<byte> src, int size)
{
int pos = 0;
int end = size;
@@ -114,7 +114,7 @@ namespace LLama.Grammar
return value;
}

public ReadOnlySpan<byte> ParseSpace(ReadOnlySpan<byte> src, bool newlineOk)
private ReadOnlySpan<byte> ParseSpace(ReadOnlySpan<byte> src, bool newlineOk)
{
int pos = 0;
while (pos < src.Length &&
@@ -136,7 +136,7 @@ namespace LLama.Grammar
return src.Slice(pos);
}

public ReadOnlySpan<byte> ParseName(ReadOnlySpan<byte> src)
private ReadOnlySpan<byte> ParseName(ReadOnlySpan<byte> src)
{
int pos = 0;
while (pos < src.Length && IsWordChar(src[pos]))
@@ -150,30 +150,31 @@ namespace LLama.Grammar
return src.Slice(pos);
}

public uint ParseChar(ref ReadOnlySpan<byte> src)
private uint ParseChar(ref ReadOnlySpan<byte> src)
{
if (src[0] == '\\')
{
var chr = src[1];
src = src.Slice(2);
switch ((char)src[1])
switch (chr)
{
case 'x':
case (byte)'x':
return ParseHex(ref src, 2);
case 'u':
case (byte)'u':
return ParseHex(ref src, 4);
case 'U':
case (byte)'U':
return ParseHex(ref src, 8);
case 't':
case (byte)'t':
return '\t';
case 'r':
case (byte)'r':
return '\r';
case 'n':
case (byte)'n':
return '\n';
case '\\':
case '"':
case '[':
case ']':
return src[1];
case (byte)'\\':
case (byte)'"':
case (byte)'[':
case (byte)']':
return chr;
default:
throw new Exception("Unknown escape at " + Encoding.UTF8.GetString(src.ToArray()));
}
@@ -186,7 +187,7 @@ namespace LLama.Grammar
throw new Exception("Unexpected end of input");
}

public ReadOnlySpan<byte> ParseSequence(
private ReadOnlySpan<byte> ParseSequence(
ParseState state,
ReadOnlySpan<byte> pos,
string ruleName,
@@ -202,7 +203,7 @@ namespace LLama.Grammar
pos = pos.Slice(1);
lastSymStart = outElements.Count;

while (pos[0] != '"')
while (!pos.IsEmpty && pos[0] != '"')
{
var charPair = ParseChar(ref pos);
outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.CHAR, Value = charPair });
@@ -222,7 +223,7 @@ namespace LLama.Grammar

lastSymStart = outElements.Count;

while (pos[0] != ']')
while (!pos.IsEmpty && pos[0] != ']')
{
var charPair = ParseChar(ref pos);
var type = lastSymStart < outElements.Count ? LLamaGrammarElementType.CHAR_ALT : startType;
@@ -315,7 +316,7 @@ namespace LLama.Grammar
return pos;
}

public ReadOnlySpan<byte> ParseAlternates(
private ReadOnlySpan<byte> ParseAlternates(
ParseState state,
ReadOnlySpan<byte> src,
string ruleName,
@@ -325,7 +326,7 @@ namespace LLama.Grammar
var rule = new List<LLamaGrammarElement>();
ReadOnlySpan<byte> pos = ParseSequence(state, src, ruleName, rule, isNested);

while (pos[0] == '|')
while (!pos.IsEmpty && pos[0] == '|')
{
rule.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.ALT, Value = 0 });
pos = ParseSpace(pos.Slice(1), true);
@@ -338,11 +339,11 @@ namespace LLama.Grammar
return pos;
}

public ReadOnlySpan<byte> ParseRule(ParseState state, ReadOnlySpan<byte> src)
private ReadOnlySpan<byte> ParseRule(ParseState state, ReadOnlySpan<byte> src)
{
ReadOnlySpan<byte> nameEnd = ParseName(src);
ReadOnlySpan<byte> pos = ParseSpace(nameEnd, false);
int nameLen = nameEnd.Length - src.Length;
int nameLen = src.Length - nameEnd.Length;
uint ruleId = GetSymbolId(state, src.Slice(0, nameLen), nameLen);
string name = Encoding.UTF8.GetString(src.Slice(0, nameLen).ToArray());

@@ -354,25 +355,27 @@ namespace LLama.Grammar

pos = ParseAlternates(state, pos, name, ruleId, false);

if (pos[0] == '\r')
if (!pos.IsEmpty && pos[0] == '\r')
{
pos = pos.Slice(pos[1] == '\n' ? 2 : 1);
}
else if (pos[0] == '\n')
else if (!pos.IsEmpty && pos[0] == '\n')
{
pos = pos.Slice(1);
}
else if (pos.Length > 0)
else if (!pos.IsEmpty)
{
throw new Exception($"Expecting newline or end at {Encoding.UTF8.GetString(pos.ToArray())}");
}
return ParseSpace(pos, true);
}

public ParseState Parse(ReadOnlySpan<byte> src)
public ParseState Parse(string input)
{
try
{
byte[] byteArray = Encoding.UTF8.GetBytes(input);
ReadOnlySpan<byte> src = new ReadOnlySpan<byte>(byteArray);
ParseState state = new ParseState();
ReadOnlySpan<byte> pos = ParseSpace(src, true);

@@ -390,7 +393,7 @@ namespace LLama.Grammar
}
}

public void PrintGrammarChar(StreamWriter file, uint c)
private void PrintGrammarChar(StreamWriter file, uint c)
{
if (c >= 0x20 && c <= 0x7F)
{
@@ -403,7 +406,7 @@ namespace LLama.Grammar
}
}

public bool IsCharElement(LLamaGrammarElement elem)
private bool IsCharElement(LLamaGrammarElement elem)
{
switch (elem.Type)
{
@@ -451,7 +454,7 @@ namespace LLama.Grammar
file.WriteLine();
}

public void PrintRule(
private void PrintRule(
StreamWriter file,
uint ruleId,
List<LLamaGrammarElement> rule,


+ 1
- 1
LLama/Grammar/ParseState.cs View File

@@ -10,7 +10,7 @@ namespace LLama.Grammar
///
/// The commit hash from URL is the actual commit hash that reflects current C# code.
/// </summary>
internal class ParseState
public class ParseState
{
public Dictionary<string, uint> SymbolIds { get; } = new Dictionary<string, uint>();
public List<List<LLamaGrammarElement>> Rules { get; } = new List<List<LLamaGrammarElement>>();


+ 3
- 1
LLama/Native/LLamaGrammarElement.cs View File

@@ -1,4 +1,5 @@
using System.Runtime.InteropServices;
using System.Diagnostics;
using System.Runtime.InteropServices;

namespace LLama.Native
{
@@ -49,6 +50,7 @@ namespace LLama.Native
/// An element of a grammar
/// </summary>
[StructLayout(LayoutKind.Sequential)]
[DebuggerDisplay("{Type} {Value}")]
public struct LLamaGrammarElement
{
/// <summary>


Loading…
Cancel
Save