From a60b7d26e419c55c81fcb61c572b78cc73eedd16 Mon Sep 17 00:00:00 2001 From: Archana Ramalingam <98564406+archana-ramalingam@users.noreply.github.com> Date: Thu, 7 Nov 2024 20:20:59 -0800 Subject: [PATCH] [sharktank] Fix numerics for perplexity with vmfb (#436) Fix cache update to resolve numerics issue. Update cache directly via IREE's DeviceArray. --- .../sharktank/evaluate/perplexity_vmfb.py | 71 +++--- sharktank/sharktank/utils/load_llm.py | 2 +- .../evaluate/baseline_perplexity_scores.json | 202 +++++++++--------- .../tests/evaluate/perplexity_vmfb_test.py | 2 +- 4 files changed, 143 insertions(+), 134 deletions(-) diff --git a/sharktank/sharktank/evaluate/perplexity_vmfb.py b/sharktank/sharktank/evaluate/perplexity_vmfb.py index fedf7c1c9..4f95ae1bd 100644 --- a/sharktank/sharktank/evaluate/perplexity_vmfb.py +++ b/sharktank/sharktank/evaluate/perplexity_vmfb.py @@ -73,8 +73,8 @@ def __init__( self.iree_hip_target = iree_hip_target self.iree_hal_target_backends = iree_hal_target_backends self.kv_cache_type = kv_cache_type - self.activation_dtype = torch.float32 - self.attention_dtype = torch.float32 + self.activation_dtype = torch.float16 + self.attention_dtype = torch.float16 self.tensor_parallelism_size = tensor_parallelism_size self.attention_kernel = attention_kernel @@ -166,6 +166,8 @@ def load_model(self, weight_path, tokenizer, vmfb_path): external_weight_path=self.weight_path_str, ) + self.haldevice = self.runner.config.device + @timeit def get_prompts(self): test_prompts = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")[ @@ -189,40 +191,19 @@ def get_prompts(self): def prefill_vmfb(self, token_batch, i): - logger.debug(f"Prefill:") - - logger.debug("Input:") - logger.debug(f"{self.generator.tokenizer.decode(token_batch)}") - - token_batch, seq_lens_batch = self.generator.tokenizer.pad_tokens( - token_ids=token_batch.tolist(), - pad_to_multiple_of=self.generator.model.cache.pad_sequence_stride, - ) - - logger.debug(f"{token_batch}") - - token_batch = torch.tensor(token_batch, device=self.torch_device) - self.seq_lens_batch = torch.tensor(seq_lens_batch, device=self.torch_device) - - self.batch = self.generator.begin_eval_batch( - token_batch=token_batch, - seq_lens_batch=self.seq_lens_batch, - bs=self.bs, - ) - seq_block_ids = self.batch.pad_block_ids() prefill_logits = self.runner.ctx.modules.module[f"prefill_bs{self.bs}"]( token_batch, - self.seq_lens_batch, + self.batch.seq_lens, seq_block_ids, - self.batch.cache_state[0].to(torch.float16), + self.cache_state, ) prefill_logits = torch.tensor(prefill_logits[:, :, :]) tokens = torch.tensor( self.generator.model.extract_tokens_from_logits( - prefill_logits, seq_lens_batch + prefill_logits, self.batch.seq_lens ) ).unsqueeze(1) self.batch.add_result_token(tokens) @@ -237,17 +218,17 @@ def decode_vmfb(self, token_batch, i): logger.debug(f"{self.generator.tokenizer.decode(token_batch)}") logger.debug(f"{token_batch.tolist()}") - start_positions = self.seq_lens_batch.clone() - self.seq_lens_batch.add_(1) + start_positions = self.batch.seq_lens.clone() + self.batch.seq_lens.add_(1) self.batch.allocate_seq_block_ids() seq_block_ids = self.batch.pad_block_ids() decode_logits = self.runner.ctx.modules.module[f"decode_bs{self.bs}"]( token_batch, - self.seq_lens_batch, + self.batch.seq_lens, start_positions, seq_block_ids, - self.batch.cache_state[0].to(torch.float16), + self.cache_state, ) decode_logits = torch.tensor(decode_logits[:, :, :]) @@ -287,6 +268,7 @@ def get_logits(self): start = 0 for i in tqdm( range(start, self.max_prompt_length - 1), + mininterval=300, desc="eval: Calculating logits", ): logger.debug(f"Iteration: {i}") @@ -295,8 +277,35 @@ def get_logits(self): token_batch = self.token_ids[:, : i + 1] + logger.debug(f"Prefill:") + + logger.debug("Input:") + logger.debug(f"{self.generator.tokenizer.decode(token_batch)}") + + token_batch, seq_lens_batch = self.generator.tokenizer.pad_tokens( + token_ids=token_batch.tolist(), + pad_to_multiple_of=self.generator.model.cache.pad_sequence_stride, + ) + + logger.debug(f"{token_batch}") + + token_batch = torch.tensor(token_batch, device=self.torch_device) + self.seq_lens_batch = torch.tensor( + seq_lens_batch, device=self.torch_device + ) + + self.batch = self.generator.begin_eval_batch( + token_batch=token_batch, + seq_lens_batch=self.seq_lens_batch, + bs=self.bs, + ) + + self.cache_state = ireert.asdevicearray( + self.haldevice, self.batch.cache_state[0].to("cpu").numpy() + ) + prefill_logits = self.prefill_vmfb(token_batch, i) - self.out_logits = prefill_logits[:, 0:1, :] + self.out_logits = prefill_logits[:, -1:, :] is_first_token = False diff --git a/sharktank/sharktank/utils/load_llm.py b/sharktank/sharktank/utils/load_llm.py index 558653d9b..acf56eb1b 100644 --- a/sharktank/sharktank/utils/load_llm.py +++ b/sharktank/sharktank/utils/load_llm.py @@ -31,9 +31,9 @@ def __init__( self.tokenizer = tokenizer if model.cache.is_paged: self.shared_cache_state = model.cache.paged.allocate(page_cache_size) + self.free_pages = list(range(1, page_cache_size)) else: self.shared_cache_state = None - self.free_pages = list(range(1, 8192)) self.end_token = end_token @property diff --git a/sharktank/tests/evaluate/baseline_perplexity_scores.json b/sharktank/tests/evaluate/baseline_perplexity_scores.json index b621fdf6d..ac2cd7b83 100644 --- a/sharktank/tests/evaluate/baseline_perplexity_scores.json +++ b/sharktank/tests/evaluate/baseline_perplexity_scores.json @@ -212,107 +212,107 @@ }, "llama3_8B_f16_decomposed_vmfb": { "perplexities": [ - 21194.505859, - 19049.068359, - 14214.751953, - 15752.748047, - 8948.568359, - 9867.280273, - 16664.880859, - 10607.53125, - 9715.395508, - 14289.220703, - 25121.929688, - 8545.292969, - 21990.28125, - 8150.422363, - 4658.82666, - 13440.376953, - 11978.756836, - 9100.139648, - 7168.022949, - 14279.970703, - 19406.207031, - 13816.291016, - 14942.27832, - 20922.1875, - 17307.214844, - 10634.068359, - 10968.188477, - 11322.012695, - 7898.733887, - 7532.914062, - 10352.375, - 16628.289062, - 5661.084473, - 6998.464355, - 7167.906738, - 7252.662598, - 7832.401367, - 5824.921875, - 12029.311523, - 13104.125, - 6688.567871, - 7917.172852, - 13455.291992, - 7466.178223, - 8360.422852, - 5765.317383, - 21530.652344, - 13371.045898, - 41826.242188, - 13620.586914, - 13886.725586, - 13105.150391, - 27155.019531, - 8066.837402, - 6860.444824, - 9858.532227, - 7352.963867, - 15839.926758, - 4746.95459, - 8539.133789, - 12957.833008, - 10096.874023, - 6436.333496, - 6488.447754, - 12649.62793, - 9575.267578, - 2897.279785, - 12649.941406, - 14139.443359, - 12061.751953, - 10646.621094, - 15703.19043, - 13080.764648, - 9124.349609, - 14409.989258, - 10726.665039, - 6444.680664, - 10168.352539, - 5474.356934, - 10729.345703, - 4240.486328, - 11856.861328, - 6184.834473, - 16671.128906, - 9840.30957, - 39691.976562, - 21551.833984, - 6072.709961, - 18333.572266, - 6635.820801, - 8460.941406, - 14243.955078, - 34157.90625, - 9565.474609, - 5573.206055, - 9139.364258, - 6077.837402, - 13941.31543, - 10590.963867, - 12113.441406 + 6.651368, + 22.059452, + 15.392176, + 17.418619, + 15.206824, + 7.907998, + 8.829535, + 22.355659, + 8.29262, + 20.958277, + 7.167404, + 14.592677, + 9.060788, + 7.274667, + 16.238981, + 6.666115, + 6.535679, + 7.086256, + 10.676177, + 8.979206, + 10.597121, + 42.038162, + 11.70071, + 65.731316, + 47.42622, + 20.109543, + 18.897541, + 13.781085, + 9.99165, + 5.955308, + 10.175659, + 23.628405, + 14.306578, + 9.719462, + 5.594786, + 14.198979, + 5.711433, + 17.381332, + 9.058512, + 8.286205, + 8.016202, + 18.4515, + 11.600831, + 3.945074, + 13.000222, + 10.373363, + 12.237907, + 21.408463, + 37.858665, + 25.794065, + 15.489001, + 14.004895, + 7.625473, + 10.993184, + 14.698832, + 11.062652, + 5.855446, + 15.625135, + 8.052419, + 14.365479, + 5.927001, + 6.931933, + 2.3014, + 15.769623, + 40.843319, + 8.022024, + 12.544907, + 10.090073, + 9.304819, + 10.679907, + 8.136175, + 21.540607, + 3.736973, + 15.381804, + 24.21562, + 14.385005, + 17.791706, + 16.498833, + 8.753955, + 12.941816, + 12.887664, + 13.725715, + 13.994792, + 10.769128, + 14.734674, + 26.970015, + 17.811842, + 9.847188, + 15.124973, + 15.623392, + 29.147844, + 12.309229, + 32.15152, + 33.225769, + 14.426914, + 17.496277, + 14.7356, + 15.503921, + 12.336852, + 16.469248 ], - "mean_perplexity": 12191.57833 + "mean_perplexity": 14.991893 } } diff --git a/sharktank/tests/evaluate/perplexity_vmfb_test.py b/sharktank/tests/evaluate/perplexity_vmfb_test.py index 58520fb31..93ffbe61c 100644 --- a/sharktank/tests/evaluate/perplexity_vmfb_test.py +++ b/sharktank/tests/evaluate/perplexity_vmfb_test.py @@ -22,7 +22,7 @@ class PerplexityTest(unittest.TestCase): def setUp(self): self.current_perplexity_all = {} - self.delta = 10 + self.delta = 5e-1 self.tensor_parallelism_size = 8 with open(self.baseline_perplexity_scores, "r") as f: self.baseline_perplexity = json.load(f)