Option to not use DFA mask store
shubhamugare committed Oct 17, 2024
1 parent af13eef commit 41d9cc8
Showing 4 changed files with 81 additions and 12 deletions.
18 changes: 11 additions & 7 deletions syncode/dfa_mask_store.py
@@ -291,7 +291,8 @@ def __init__(self,
                  special_token_ids: Iterable=[],
                  indentation: bool=True,
                  mode='grammar_strict', # 'grammar_strict' or 'grammar_mask'
-                 ignore_terminals: Iterable[str]=[]
+                 ignore_terminals: Iterable[str]=[],
+                 use_mask_store: bool=True
                  ):
         self._vocab = vocab
         self.special_token_ids = special_token_ids
@@ -305,7 +306,9 @@ def __init__(self,
         # Iterate through each pair of DFA state and next terminals and store the overapproximate tokens
         self._lookup_table = LookupTable(vocab, special_token_ids, indentation=indentation, mode=mode)
         terminal_names = [terminal.name for terminal in terminals]
-        self._store_overapproximate_tokens(terminal_names, vocab)
+
+        if use_mask_store:
+            self._store_overapproximate_tokens(terminal_names, vocab)
 
         self.indentation = indentation
 
@@ -322,7 +325,7 @@ def set_ignore_whitespace(self, terminals: Iterable[TerminalDef], ignore_termina
         return ignore_whitespace
 
     @staticmethod
-    def load_dfa_mask_store(grammar: Grammar, tokenizer, use_cache=True, logger=None, mode='grammar_strict'):
+    def load_dfa_mask_store(grammar: Grammar, tokenizer, use_cache=True, logger=None, mode='grammar_strict', use_mask_store=True):
         '''
         Loads the dfa for the given language and tokenizer. If the dfa is not cached, it is created and cached.
         '''
@@ -331,24 +334,25 @@ def load_dfa_mask_store(grammar: Grammar, tokenizer, use_cache=True, logger=None
         grammar_hash = grammar.hash()
 
         # TODO: Hashing using the tokenizer vocab size; this may be problematic if we have two fine-tuned models with the same tokenizer and vocab size but different vocab
-        dfa_path = f'{dfa_dir}{mode}_{grammar_hash}_{tokenizer.vocab_size}.pkl'
+        dfa_path = f'{dfa_dir}{mode}_{grammar_hash}_{tokenizer.vocab_size}_lookup_table:{use_mask_store}.pkl'
 
         if use_cache and os.path.exists(dfa_path):
             try:
                 mask_store = pickle.load(open(dfa_path, 'rb'))
                 return mask_store
             except: # If we cannot load the file, we will create the dfa from scratch
                 pass
 
-        print(f"Creating DFA mask store for {tokenizer_name} and {grammar}, may take more than 10 minutes. Caching at {os.path.abspath(dfa_path)}.", flush=True)
+        if use_mask_store:
+            print(f"Creating DFA mask store for {tokenizer_name} and {grammar}, may take more than 10 minutes. Caching at {os.path.abspath(dfa_path)}.", flush=True)
         vocab = common.get_vocab_from_tokenizer(tokenizer)
 
         base_parser = create_base_parser(grammar)
 
         simplifications = grammar.simplifications()
         os.makedirs(dfa_dir, exist_ok=True)
 
-        mask_store = DFAMaskStore(base_parser.terminals, vocab, simplifications=simplifications, special_token_ids=[tokenizer.eos_token_id], mode=mode, ignore_terminals=base_parser.ignore_tokens)
+        mask_store = DFAMaskStore(base_parser.terminals, vocab, simplifications=simplifications, special_token_ids=[tokenizer.eos_token_id], mode=mode, ignore_terminals=base_parser.ignore_tokens, use_mask_store=use_mask_store)
 
         pickle.dump(mask_store, open(dfa_path, 'wb'))
         return mask_store
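A minimal sketch of loading a store with the new flag disabled; the `Grammar` construction, import paths, and tokenizer choice below are illustrative assumptions, not part of this commit:

import torch
from transformers import AutoTokenizer
from syncode.dfa_mask_store import DFAMaskStore
from syncode.parsers.grammars import Grammar  # assumed import path

tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")  # any HF tokenizer
grammar = Grammar("python")

# use_mask_store=False skips _store_overapproximate_tokens, so construction is
# fast; the cache file name now records which variant was pickled, so the two
# variants do not overwrite each other.
mask_store = DFAMaskStore.load_dfa_mask_store(
    grammar=grammar,
    tokenizer=tokenizer,
    use_cache=True,
    mode='grammar_strict',
    use_mask_store=False,
)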
29 changes: 26 additions & 3 deletions syncode/grammar_decoder.py
@@ -30,7 +30,9 @@ def __init__(self,
                  num_samples=1,
                  dev_mode=False,
                  parser='lalr',
-                 mode='grammar_mask'):
+                 mode='grammar_mask',
+                 use_mask_store=True
+                 ):
 
         self.tokenizer = tokenizer
         self.grammar = grammar
@@ -50,14 +52,16 @@ def __init__(self,
         self._ignore_whitespace = self._get_ignore_whitespace(self.grammar)
 
         # Load dfa mask store
+        self.use_mask_store = use_mask_store
         self.dfa_mask_store = DFAMaskStore.load_dfa_mask_store(
             grammar=self.grammar,
             tokenizer=self.tokenizer,
             use_cache=use_cache,
             logger=self.logger,
             mode=mode,
+            use_mask_store=use_mask_store
         )
 
         # Create parser
         self.inc_parser: IncrementalParser = create_parser(self.grammar, logger=self.logger, parser=parser, ignore_whitespace=self._ignore_whitespace)
 
@@ -136,6 +140,15 @@ def is_valid(self, input_ids: torch.LongTensor, next_token: torch.LongTensor) ->
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         # start_from is used for choosing where the parsing should start
         debug = True
+
+        if self.use_mask_store:
+            self._use_mask_store_to_update_scores(input_ids, scores, debug)
+        else:
+            self._use_simple_update_scores(input_ids, scores, debug)
+
+        return scores
+
+    def _use_mask_store_to_update_scores(self, input_ids, scores, debug):
         partial_codes = self._get_partial_codes(input_ids)
 
         for idx, partial_code in enumerate(partial_codes):
@@ -169,7 +182,17 @@ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> to
             # For debugging - remove later
             if debug: self._debug_greedy(scores, idx, partial_code, r, greedy_token)
 
-        return scores
+    def _use_simple_update_scores(self, input_ids, scores, debug):
+        """
+        A simple fallback that iterates over every token in the vocabulary and
+        invalidates the ones that do not lead to a grammatically valid continuation.
+        """
+        for idx in range(len(input_ids)):
+            for token_id in range(len(scores[idx])):
+                # Mask out token_id if appending it to this sequence is invalid
+                if not self.is_valid(input_ids[idx:idx+1], torch.tensor([token_id]).to(input_ids.device)):
+                    scores[idx, token_id] = -float("inf")
 
     def _get_partial_codes(self, input_ids: torch.LongTensor):
         assert self.start_from <= input_ids.size(1), "Make sure that the decoder is reset for new prompt."
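For intuition, the fallback path above amounts to a generic logits-masking loop around a validity oracle. A minimal self-contained sketch, assuming only that `is_valid(prefix_ids, next_token)` wraps an incremental-parser check; the helper name `mask_invalid_tokens` is illustrative, not part of this commit:

import torch

def mask_invalid_tokens(input_ids: torch.LongTensor,
                        scores: torch.FloatTensor,
                        is_valid) -> torch.FloatTensor:
    # is_valid(prefix_ids, next_token) -> bool, e.g. an incremental parser check
    for idx in range(input_ids.size(0)):
        for token_id in range(scores.size(1)):
            next_token = torch.tensor([token_id], device=input_ids.device)
            if not is_valid(input_ids[idx:idx+1], next_token):
                scores[idx, token_id] = -float("inf")
    return scores

This costs one parser check per vocabulary entry per decoding step, roughly O(batch_size x |V|), whereas the precomputed mask store reduces each step to a table lookup; that speed difference is the tradeoff the new use_mask_store flag exposes.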
10 changes: 8 additions & 2 deletions syncode/infer.py
@@ -15,9 +15,9 @@
 from syncode.evaluation.fol_eval import FOLEval
 
 
-def compile_and_run(model, mode="grammar_strict", quantize=True, device="cuda", num_samples=1, grammar=None, dataset="input", num_few_shot=0, chat_mode=False, dev_mode=False, log_level=1, new_mask_store=False, parser="lalr", task_id=None, seed=None, **kwargs):
+def compile_and_run(model, mode="grammar_strict", quantize=True, device="cuda", num_samples=1, grammar=None, dataset="input", num_few_shot=0, chat_mode=False, dev_mode=False, log_level=1, new_mask_store=False, parser="lalr", task_id=None, seed=None, use_mask_store=True, **kwargs):
 
-    syncode = Syncode(model, mode=mode, quantize=quantize, device=device, num_samples=num_samples, grammar=grammar, chat_mode=chat_mode, dev_mode=dev_mode, log_level=log_level, new_mask_store=new_mask_store, parser=parser, seed=seed, **kwargs)
+    syncode = Syncode(model, mode=mode, quantize=quantize, device=device, num_samples=num_samples, grammar=grammar, chat_mode=chat_mode, dev_mode=dev_mode, log_level=log_level, new_mask_store=new_mask_store, parser=parser, seed=seed, use_mask_store=use_mask_store, **kwargs)
 
     if dataset == "input":
         syncode.infer()
@@ -57,6 +57,10 @@ class Syncode:
         log_level (int, optional): Log level. Defaults to 2. 0 for no logs, 1 for minimal logs, 2 for all logs including time.
 
         parser (str, optional): Parser to use. Defaults to "lalr".
+
+        use_mask_store (bool, optional): Use the DFA mask store. Defaults to True.
+
+        seed (int, optional): Seed for random number generator. Defaults to None.
     """
     def __init__(
         self,
@@ -73,6 +77,7 @@ def __init__(
         new_mask_store: bool = False,
         parser: Literal["lr", "lalr"] = "lalr",
         seed: Optional[int] = None,
+        use_mask_store: bool = True,
         **kwargs
     ):
         # Check inputs
@@ -122,6 +127,7 @@ def __init__(
             dev_mode=dev_mode,
             parser=parser,
             mode=mode,
+            use_mask_store=use_mask_store,
         )
 
         # Set LLM max new tokens to 200 by default
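End to end, the new flag threads from Syncode down to the mask store. A usage sketch, assuming the package-level import shown in the project README; the model name and prompt are arbitrary choices for illustration:

from syncode import Syncode

# Disabling the mask store trades slower decoding (per-token validity checks)
# for skipping the expensive one-time mask-store construction.
syncode = Syncode(model="microsoft/phi-2", grammar="python", mode="grammar_strict", use_mask_store=False)
output = syncode.infer("def fibonacci(n):")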
36 changes: 36 additions & 0 deletions syncode/temp.py
@@ -0,0 +1,36 @@

+# Test speeds of various LLMs on CPU and GPU
+import time
+import common
+
+# Load model
+def run_model(model_name, prompt: str, device: str, quantize: bool):
+    model = common.load_model(model_name, device, quantize)
+    tokenizer = common.load_tokenizer(model_name)
+
+    model = model.to(device)
+
+    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
+
+    start_time = time.time()
+    output = model.generate(input_ids, max_new_tokens=100)
+    end_time = time.time()
+
+    total_time = end_time - start_time
+    num_tokens = output.shape[1] - input_ids.shape[1]  # count generated tokens, not characters
+    tokens_per_sec = num_tokens / total_time
+
+    print("Results for:", model_name, "device =", device, "with quantize =", quantize, flush=True)
+    print(f"Time taken: {total_time:.2f}s", flush=True)
+    print(f"Number of tokens: {num_tokens}", flush=True)
+    print(f"Tokens per second: {tokens_per_sec:.2f}", flush=True)
+
+
+models = ["Qwen/Qwen2.5-0.5B", "Qwen/Qwen2.5-1.5B", "Qwen/Qwen2.5-Coder-1.5B", "Qwen/Qwen2.5-1.5B-Instruct", "Qwen/Qwen2.5-Coder-7B", "meta-llama/Llama-3.2-1B", "meta-llama/Llama-3.2-3B"]
+devices = ["cpu", "cuda"]
+quantizations = [True, False]
+
+for model_name in models:
+    for device in devices:
+        for quantize in quantizations:
+            run_model(model_name, "1, 2, 3, 4", device, quantize)