diff --git a/LLama.Unittest/GrammarParserTest.cs b/LLama.Unittest/GrammarParserTest.cs new file mode 100644 index 00000000..926e51a9 --- /dev/null +++ b/LLama.Unittest/GrammarParserTest.cs @@ -0,0 +1,118 @@ +using LLama.Common; +using LLama.Native; +using System.Diagnostics; +using LLama.Grammar; +using Newtonsoft.Json.Linq; + +namespace LLama.Unittest +{ + public sealed class GrammarParserTest + { + [Fact] + public void ParseComplexGrammar() + { + GrammarParser parsedGrammar = new GrammarParser(); + string grammarBytes = @"root ::= (expr ""="" term ""\n"")+ + expr ::= term ([-+*/] term)* + term ::= [0-9]+"; + + ParseState state = parsedGrammar.Parse(grammarBytes); + + List> expected = new List> + { + new KeyValuePair("expr", 2), + new KeyValuePair("expr_5", 5), + new KeyValuePair("expr_6", 6), + new KeyValuePair("root", 0), + new KeyValuePair("root_1", 1), + new KeyValuePair("root_4", 4), + new KeyValuePair("term", 3), + new KeyValuePair("term_7", 7), + }; + + uint index = 0; + foreach (var it in state.SymbolIds) + { + string key = it.Key; + uint value = it.Value; + var expectedPair = expected[(int)index]; + + // pretty print error message before asserting + if (expectedPair.Key != key || expectedPair.Value != value) + { + Console.Error.WriteLine($"expectedPair: {expectedPair.Key}, {expectedPair.Value}"); + Console.Error.WriteLine($"actualPair: {key}, {value}"); + Console.Error.WriteLine("expectedPair != actualPair"); + } + Assert.Equal(expectedPair.Key, key); + Assert.Equal(expectedPair.Value, value); + + index++; + } + Assert.NotEmpty(state.SymbolIds); + + + var expectedRules = new List + { + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 4), + new LLamaGrammarElement(LLamaGrammarElementType.END, 0), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 2), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 61), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 3), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 10), + new LLamaGrammarElement(LLamaGrammarElementType.END, 0), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 3), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 6), + new LLamaGrammarElement(LLamaGrammarElementType.END, 0), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 7), + new LLamaGrammarElement(LLamaGrammarElementType.END, 0), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 1), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 4), + new LLamaGrammarElement(LLamaGrammarElementType.END, 0), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 1), + new LLamaGrammarElement(LLamaGrammarElementType.END, 0), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 45), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR_ALT, 43), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR_ALT, 42), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR_ALT, 47), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 3), + new LLamaGrammarElement(LLamaGrammarElementType.END, 0), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 5), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 6), + new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0), + new LLamaGrammarElement(LLamaGrammarElementType.END, 0), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 48), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, 57), + new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 7), + new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 48), + new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, 57), + new LLamaGrammarElement(LLamaGrammarElementType.END, 0), + }; + + index = 0; + foreach (var rule in state.Rules) + { + // compare rule to expected rule + for (uint i = 0; i < rule.Count; i++) + { + var element = rule[(int)i]; + var expectedElement = expectedRules[(int)index]; + + // Pretty print error message before asserting + if (expectedElement.Type != element.Type || expectedElement.Value != element.Value) + { + Console.Error.WriteLine($"index: {index}"); + Console.Error.WriteLine($"expected_element: {expectedElement.Type}, {expectedElement.Value}"); + Console.Error.WriteLine($"actual_element: {element.Type}, {element.Value}"); + Console.Error.WriteLine("expected_element != actual_element"); + } + Assert.Equal(expectedElement.Type, element.Type); + Assert.Equal(expectedElement.Value, element.Value); + index++; + } + } + Assert.NotEmpty(state.Rules); + } + } +} diff --git a/LLama/Grammar/GrammarParser.cs b/LLama/Grammar/GrammarParser.cs index 2a91a855..3e038992 100644 --- a/LLama/Grammar/GrammarParser.cs +++ b/LLama/Grammar/GrammarParser.cs @@ -12,11 +12,11 @@ namespace LLama.Grammar /// /// The commit hash from URL is the actual commit hash that reflects current C# code. /// - internal class GrammarParser + public class GrammarParser { // NOTE: assumes valid utf8 (but checks for overrun) // copied from llama.cpp - public uint DecodeUTF8(ref ReadOnlySpan src) + private uint DecodeUTF8(ref ReadOnlySpan src) { int[] lookup = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; @@ -39,7 +39,7 @@ namespace LLama.Grammar return value; } - public uint GetSymbolId(ParseState state, ReadOnlySpan src, int len) + private uint GetSymbolId(ParseState state, ReadOnlySpan src, int len) { uint nextId = (uint)state.SymbolIds.Count; string key = Encoding.UTF8.GetString(src.Slice(0, len).ToArray()); @@ -55,7 +55,7 @@ namespace LLama.Grammar } } - public uint GenerateSymbolId(ParseState state, string baseName) + private uint GenerateSymbolId(ParseState state, string baseName) { uint nextId = (uint)state.SymbolIds.Count; string key = $"{baseName}_{nextId}"; @@ -63,7 +63,7 @@ namespace LLama.Grammar return nextId; } - public void AddRule(ParseState state, uint ruleId, List rule) + private void AddRule(ParseState state, uint ruleId, List rule) { while (state.Rules.Count <= ruleId) { @@ -73,12 +73,12 @@ namespace LLama.Grammar state.Rules[(int)ruleId] = rule; } - public bool IsWordChar(byte c) + private bool IsWordChar(byte c) { return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9'); } - public uint ParseHex(ref ReadOnlySpan src, int size) + private uint ParseHex(ref ReadOnlySpan src, int size) { int pos = 0; int end = size; @@ -114,7 +114,7 @@ namespace LLama.Grammar return value; } - public ReadOnlySpan ParseSpace(ReadOnlySpan src, bool newlineOk) + private ReadOnlySpan ParseSpace(ReadOnlySpan src, bool newlineOk) { int pos = 0; while (pos < src.Length && @@ -136,7 +136,7 @@ namespace LLama.Grammar return src.Slice(pos); } - public ReadOnlySpan ParseName(ReadOnlySpan src) + private ReadOnlySpan ParseName(ReadOnlySpan src) { int pos = 0; while (pos < src.Length && IsWordChar(src[pos])) @@ -150,30 +150,31 @@ namespace LLama.Grammar return src.Slice(pos); } - public uint ParseChar(ref ReadOnlySpan src) + private uint ParseChar(ref ReadOnlySpan src) { if (src[0] == '\\') { + var chr = src[1]; src = src.Slice(2); - switch ((char)src[1]) + switch (chr) { - case 'x': + case (byte)'x': return ParseHex(ref src, 2); - case 'u': + case (byte)'u': return ParseHex(ref src, 4); - case 'U': + case (byte)'U': return ParseHex(ref src, 8); - case 't': + case (byte)'t': return '\t'; - case 'r': + case (byte)'r': return '\r'; - case 'n': + case (byte)'n': return '\n'; - case '\\': - case '"': - case '[': - case ']': - return src[1]; + case (byte)'\\': + case (byte)'"': + case (byte)'[': + case (byte)']': + return chr; default: throw new Exception("Unknown escape at " + Encoding.UTF8.GetString(src.ToArray())); } @@ -186,7 +187,7 @@ namespace LLama.Grammar throw new Exception("Unexpected end of input"); } - public ReadOnlySpan ParseSequence( + private ReadOnlySpan ParseSequence( ParseState state, ReadOnlySpan pos, string ruleName, @@ -202,7 +203,7 @@ namespace LLama.Grammar pos = pos.Slice(1); lastSymStart = outElements.Count; - while (pos[0] != '"') + while (!pos.IsEmpty && pos[0] != '"') { var charPair = ParseChar(ref pos); outElements.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.CHAR, Value = charPair }); @@ -222,7 +223,7 @@ namespace LLama.Grammar lastSymStart = outElements.Count; - while (pos[0] != ']') + while (!pos.IsEmpty && pos[0] != ']') { var charPair = ParseChar(ref pos); var type = lastSymStart < outElements.Count ? LLamaGrammarElementType.CHAR_ALT : startType; @@ -315,7 +316,7 @@ namespace LLama.Grammar return pos; } - public ReadOnlySpan ParseAlternates( + private ReadOnlySpan ParseAlternates( ParseState state, ReadOnlySpan src, string ruleName, @@ -325,7 +326,7 @@ namespace LLama.Grammar var rule = new List(); ReadOnlySpan pos = ParseSequence(state, src, ruleName, rule, isNested); - while (pos[0] == '|') + while (!pos.IsEmpty && pos[0] == '|') { rule.Add(new LLamaGrammarElement { Type = LLamaGrammarElementType.ALT, Value = 0 }); pos = ParseSpace(pos.Slice(1), true); @@ -338,11 +339,11 @@ namespace LLama.Grammar return pos; } - public ReadOnlySpan ParseRule(ParseState state, ReadOnlySpan src) + private ReadOnlySpan ParseRule(ParseState state, ReadOnlySpan src) { ReadOnlySpan nameEnd = ParseName(src); ReadOnlySpan pos = ParseSpace(nameEnd, false); - int nameLen = nameEnd.Length - src.Length; + int nameLen = src.Length - nameEnd.Length; uint ruleId = GetSymbolId(state, src.Slice(0, nameLen), nameLen); string name = Encoding.UTF8.GetString(src.Slice(0, nameLen).ToArray()); @@ -354,25 +355,27 @@ namespace LLama.Grammar pos = ParseAlternates(state, pos, name, ruleId, false); - if (pos[0] == '\r') + if (!pos.IsEmpty && pos[0] == '\r') { pos = pos.Slice(pos[1] == '\n' ? 2 : 1); } - else if (pos[0] == '\n') + else if (!pos.IsEmpty && pos[0] == '\n') { pos = pos.Slice(1); } - else if (pos.Length > 0) + else if (!pos.IsEmpty) { throw new Exception($"Expecting newline or end at {Encoding.UTF8.GetString(pos.ToArray())}"); } return ParseSpace(pos, true); } - public ParseState Parse(ReadOnlySpan src) + public ParseState Parse(string input) { try { + byte[] byteArray = Encoding.UTF8.GetBytes(input); + ReadOnlySpan src = new ReadOnlySpan(byteArray); ParseState state = new ParseState(); ReadOnlySpan pos = ParseSpace(src, true); @@ -390,7 +393,7 @@ namespace LLama.Grammar } } - public void PrintGrammarChar(StreamWriter file, uint c) + private void PrintGrammarChar(StreamWriter file, uint c) { if (c >= 0x20 && c <= 0x7F) { @@ -403,7 +406,7 @@ namespace LLama.Grammar } } - public bool IsCharElement(LLamaGrammarElement elem) + private bool IsCharElement(LLamaGrammarElement elem) { switch (elem.Type) { @@ -451,7 +454,7 @@ namespace LLama.Grammar file.WriteLine(); } - public void PrintRule( + private void PrintRule( StreamWriter file, uint ruleId, List rule, diff --git a/LLama/Grammar/ParseState.cs b/LLama/Grammar/ParseState.cs index bd5eaf41..ff42f527 100644 --- a/LLama/Grammar/ParseState.cs +++ b/LLama/Grammar/ParseState.cs @@ -10,7 +10,7 @@ namespace LLama.Grammar /// /// The commit hash from URL is the actual commit hash that reflects current C# code. /// - internal class ParseState + public class ParseState { public Dictionary SymbolIds { get; } = new Dictionary(); public List> Rules { get; } = new List>(); diff --git a/LLama/Native/LLamaGrammarElement.cs b/LLama/Native/LLamaGrammarElement.cs index d097628f..7c321c5d 100644 --- a/LLama/Native/LLamaGrammarElement.cs +++ b/LLama/Native/LLamaGrammarElement.cs @@ -1,4 +1,5 @@ -using System.Runtime.InteropServices; +using System.Diagnostics; +using System.Runtime.InteropServices; namespace LLama.Native { @@ -49,6 +50,7 @@ namespace LLama.Native /// An element of a grammar /// [StructLayout(LayoutKind.Sequential)] + [DebuggerDisplay("{Type} {Value}")] public struct LLamaGrammarElement { ///