From 2ae1891c13953dfb693e927f072ce3801d02b6d8 Mon Sep 17 00:00:00 2001 From: Mihai Date: Wed, 30 Aug 2023 16:18:05 +0300 Subject: [PATCH] Bug fixes after running tests. SymbolIds is now SortedDictionary (although I'm not sure it really needs to be) because the test was failing due to expected value being in another order. The C++ data structure if SymbolIds is std::map so the items are ordered by key. --- LLama.Unittest/GrammarParserTest.cs | 2 +- LLama/Grammar/GrammarParser.cs | 4 ++-- LLama/Grammar/ParseState.cs | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/LLama.Unittest/GrammarParserTest.cs b/LLama.Unittest/GrammarParserTest.cs index 926e51a9..673a7164 100644 --- a/LLama.Unittest/GrammarParserTest.cs +++ b/LLama.Unittest/GrammarParserTest.cs @@ -68,7 +68,7 @@ namespace LLama.Unittest new LLamaGrammarElement(LLamaGrammarElementType.END, 0), new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 1), new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 4), - new LLamaGrammarElement(LLamaGrammarElementType.END, 0), + new LLamaGrammarElement(LLamaGrammarElementType.ALT, 0), new LLamaGrammarElement(LLamaGrammarElementType.RULE_REF, 1), new LLamaGrammarElement(LLamaGrammarElementType.END, 0), new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 45), diff --git a/LLama/Grammar/GrammarParser.cs b/LLama/Grammar/GrammarParser.cs index 3e038992..a59d6e1a 100644 --- a/LLama/Grammar/GrammarParser.cs +++ b/LLama/Grammar/GrammarParser.cs @@ -42,7 +42,7 @@ namespace LLama.Grammar private uint GetSymbolId(ParseState state, ReadOnlySpan src, int len) { uint nextId = (uint)state.SymbolIds.Count; - string key = Encoding.UTF8.GetString(src.Slice(0, len).ToArray()); + string key = Encoding.UTF8.GetString(src.Slice(0, src.Length - len).ToArray()); if (state.SymbolIds.TryGetValue(key, out uint existingId)) { @@ -344,7 +344,7 @@ namespace LLama.Grammar ReadOnlySpan nameEnd = ParseName(src); ReadOnlySpan pos = ParseSpace(nameEnd, false); int nameLen = src.Length - nameEnd.Length; - uint ruleId = GetSymbolId(state, src.Slice(0, nameLen), nameLen); + uint ruleId = GetSymbolId(state, src.Slice(0, nameLen), 0); string name = Encoding.UTF8.GetString(src.Slice(0, nameLen).ToArray()); if (!(pos[0] == ':' && pos[1] == ':' && pos[2] == '=')) diff --git a/LLama/Grammar/ParseState.cs b/LLama/Grammar/ParseState.cs index ff42f527..f85fa032 100644 --- a/LLama/Grammar/ParseState.cs +++ b/LLama/Grammar/ParseState.cs @@ -12,7 +12,7 @@ namespace LLama.Grammar /// public class ParseState { - public Dictionary SymbolIds { get; } = new Dictionary(); + public SortedDictionary SymbolIds { get; } = new SortedDictionary(); public List> Rules { get; } = new List>(); public IEnumerable> CRules()