[sharktank] Fix numerics for perplexity with vmfb (#436)
Fix the KV-cache update to resolve the perplexity numerics issue:
update the cache directly via IREE's DeviceArray, instead of passing a fresh float16 copy of the cache tensor on every prefill/decode call.
archana-ramalingam authored Nov 8, 2024
1 parent 445511a commit a60b7d2
Showing 4 changed files with 143 additions and 134 deletions.
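The heart of the change: the exported prefill/decode functions update the KV cache in place, but the old code passed them `self.batch.cache_state[0].to(torch.float16)`, a fresh per-call copy, so the module's cache writes were lost between calls (and the f32-to-f16 round trip skewed the values). The fix wraps the cache once as an IREE DeviceArray and hands the same buffer to every invocation. A minimal sketch of the pattern, assuming a compiled VMFB exposing a prefill_bs{bs} entry point; the helper names are illustrative, while the runner and cache calls mirror the diff below:

import iree.runtime as ireert
import torch

def make_cache_device_array(runner, cache_state: torch.Tensor):
    # Wrap the torch KV cache ONCE as an IREE DeviceArray. The VMFB's
    # prefill/decode entry points mutate this buffer in place, so the
    # updates persist across calls instead of dying in a per-call copy.
    haldevice = runner.config.device
    return ireert.asdevicearray(haldevice, cache_state.to("cpu").numpy())

def prefill(runner, bs, token_batch, seq_lens, seq_block_ids, cache_state):
    # cache_state is the DeviceArray built above -- not
    # cache_state.to(torch.float16), which is what the old code passed.
    logits = runner.ctx.modules.module[f"prefill_bs{bs}"](
        token_batch, seq_lens, seq_block_ids, cache_state
    )
    return torch.tensor(logits[:, :, :])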
71 changes: 40 additions & 31 deletions sharktank/sharktank/evaluate/perplexity_vmfb.py
@@ -73,8 +73,8 @@ def __init__(
         self.iree_hip_target = iree_hip_target
         self.iree_hal_target_backends = iree_hal_target_backends
         self.kv_cache_type = kv_cache_type
-        self.activation_dtype = torch.float32
-        self.attention_dtype = torch.float32
+        self.activation_dtype = torch.float16
+        self.attention_dtype = torch.float16
         self.tensor_parallelism_size = tensor_parallelism_size
         self.attention_kernel = attention_kernel
 
@@ -166,6 +166,8 @@ def load_model(self, weight_path, tokenizer, vmfb_path):
             external_weight_path=self.weight_path_str,
         )
 
+        self.haldevice = self.runner.config.device
+
     @timeit
     def get_prompts(self):
         test_prompts = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")[
@@ -189,40 +191,19 @@ def get_prompts(self):
 
     def prefill_vmfb(self, token_batch, i):
 
-        logger.debug(f"Prefill:")
-
-        logger.debug("Input:")
-        logger.debug(f"{self.generator.tokenizer.decode(token_batch)}")
-
-        token_batch, seq_lens_batch = self.generator.tokenizer.pad_tokens(
-            token_ids=token_batch.tolist(),
-            pad_to_multiple_of=self.generator.model.cache.pad_sequence_stride,
-        )
-
-        logger.debug(f"{token_batch}")
-
-        token_batch = torch.tensor(token_batch, device=self.torch_device)
-        self.seq_lens_batch = torch.tensor(seq_lens_batch, device=self.torch_device)
-
-        self.batch = self.generator.begin_eval_batch(
-            token_batch=token_batch,
-            seq_lens_batch=self.seq_lens_batch,
-            bs=self.bs,
-        )
-
         seq_block_ids = self.batch.pad_block_ids()
         prefill_logits = self.runner.ctx.modules.module[f"prefill_bs{self.bs}"](
             token_batch,
-            self.seq_lens_batch,
+            self.batch.seq_lens,
             seq_block_ids,
-            self.batch.cache_state[0].to(torch.float16),
+            self.cache_state,
         )
 
         prefill_logits = torch.tensor(prefill_logits[:, :, :])
 
         tokens = torch.tensor(
             self.generator.model.extract_tokens_from_logits(
-                prefill_logits, seq_lens_batch
+                prefill_logits, self.batch.seq_lens
             )
         ).unsqueeze(1)
         self.batch.add_result_token(tokens)
@@ -237,17 +218,17 @@ def decode_vmfb(self, token_batch, i):
         logger.debug(f"{self.generator.tokenizer.decode(token_batch)}")
         logger.debug(f"{token_batch.tolist()}")
 
-        start_positions = self.seq_lens_batch.clone()
-        self.seq_lens_batch.add_(1)
+        start_positions = self.batch.seq_lens.clone()
+        self.batch.seq_lens.add_(1)
         self.batch.allocate_seq_block_ids()
         seq_block_ids = self.batch.pad_block_ids()
 
         decode_logits = self.runner.ctx.modules.module[f"decode_bs{self.bs}"](
             token_batch,
-            self.seq_lens_batch,
+            self.batch.seq_lens,
             start_positions,
             seq_block_ids,
-            self.batch.cache_state[0].to(torch.float16),
+            self.cache_state,
         )
 
         decode_logits = torch.tensor(decode_logits[:, :, :])
@@ -287,6 +268,7 @@ def get_logits(self):
         start = 0
         for i in tqdm(
             range(start, self.max_prompt_length - 1),
+            mininterval=300,
             desc="eval: Calculating logits",
         ):
             logger.debug(f"Iteration: {i}")
@@ -295,8 +277,35 @@ def get_logits(self):
 
                 token_batch = self.token_ids[:, : i + 1]
 
+                logger.debug(f"Prefill:")
+
+                logger.debug("Input:")
+                logger.debug(f"{self.generator.tokenizer.decode(token_batch)}")
+
+                token_batch, seq_lens_batch = self.generator.tokenizer.pad_tokens(
+                    token_ids=token_batch.tolist(),
+                    pad_to_multiple_of=self.generator.model.cache.pad_sequence_stride,
+                )
+
+                logger.debug(f"{token_batch}")
+
+                token_batch = torch.tensor(token_batch, device=self.torch_device)
+                self.seq_lens_batch = torch.tensor(
+                    seq_lens_batch, device=self.torch_device
+                )
+
+                self.batch = self.generator.begin_eval_batch(
+                    token_batch=token_batch,
+                    seq_lens_batch=self.seq_lens_batch,
+                    bs=self.bs,
+                )
+
+                self.cache_state = ireert.asdevicearray(
+                    self.haldevice, self.batch.cache_state[0].to("cpu").numpy()
+                )
+
                 prefill_logits = self.prefill_vmfb(token_batch, i)
-                self.out_logits = prefill_logits[:, 0:1, :]
+                self.out_logits = prefill_logits[:, -1:, :]
 
                 is_first_token = False
 
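A second, easy-to-miss part of the fix above is the logits slice in get_logits: prefill returns logits for every prompt position, and the prediction for the next token lives at the last position, not the first. A toy illustration with assumed shapes (bs, seq_len, and vocab here are hypothetical):

import torch

bs, seq_len, vocab = 4, 7, 32000
prefill_logits = torch.randn(bs, seq_len, vocab)

# New: logits that predict the token following the prompt.
next_token_logits = prefill_logits[:, -1:, :]
# Old (wrong for perplexity): logits at the first prompt position.
first_pos_logits = prefill_logits[:, 0:1, :]

assert next_token_logits.shape == (bs, 1, vocab)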
2 changes: 1 addition & 1 deletion sharktank/sharktank/utils/load_llm.py
@@ -31,9 +31,9 @@ def __init__(
         self.tokenizer = tokenizer
         if model.cache.is_paged:
             self.shared_cache_state = model.cache.paged.allocate(page_cache_size)
+            self.free_pages = list(range(1, page_cache_size))
         else:
             self.shared_cache_state = None
-        self.free_pages = list(range(1, 8192))
         self.end_token = end_token
 
     @property
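The one-line load_llm.py change matters more than it looks: the free-page list for the paged KV cache is now sized by page_cache_size instead of a hardcoded 8192, so the generator can never hand out page ids beyond what the cache actually allocated. A minimal sketch of the bookkeeping under that invariant -- PagePool, acquire, and release are hypothetical names, not the sharktank API:

class PagePool:
    def __init__(self, page_cache_size: int):
        # Page 0 is reserved; valid ids are 1 .. page_cache_size - 1,
        # matching free_pages = list(range(1, page_cache_size)) above.
        self.free_pages = list(range(1, page_cache_size))

    def acquire(self, n: int) -> list[int]:
        # Handing out ids past page_cache_size (as the old hardcoded
        # 8192 allowed) would index past the end of the allocated cache.
        if n > len(self.free_pages):
            raise RuntimeError("KV cache exhausted: no free pages")
        taken, self.free_pages = self.free_pages[:n], self.free_pages[n:]
        return taken

    def release(self, pages: list[int]) -> None:
        self.free_pages.extend(pages)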
202 changes: 101 additions & 101 deletions sharktank/tests/evaluate/baseline_perplexity_scores.json
@@ -212,107 +212,107 @@
     },
     "llama3_8B_f16_decomposed_vmfb": {
         "perplexities": [
-            21194.505859,
-            19049.068359,
-            14214.751953,
-            15752.748047,
-            8948.568359,
-            9867.280273,
-            16664.880859,
-            10607.53125,
-            9715.395508,
-            14289.220703,
-            25121.929688,
-            8545.292969,
-            21990.28125,
-            8150.422363,
-            4658.82666,
-            13440.376953,
-            11978.756836,
-            9100.139648,
-            7168.022949,
-            14279.970703,
-            19406.207031,
-            13816.291016,
-            14942.27832,
-            20922.1875,
-            17307.214844,
-            10634.068359,
-            10968.188477,
-            11322.012695,
-            7898.733887,
-            7532.914062,
-            10352.375,
-            16628.289062,
-            5661.084473,
-            6998.464355,
-            7167.906738,
-            7252.662598,
-            7832.401367,
-            5824.921875,
-            12029.311523,
-            13104.125,
-            6688.567871,
-            7917.172852,
-            13455.291992,
-            7466.178223,
-            8360.422852,
-            5765.317383,
-            21530.652344,
-            13371.045898,
-            41826.242188,
-            13620.586914,
-            13886.725586,
-            13105.150391,
-            27155.019531,
-            8066.837402,
-            6860.444824,
-            9858.532227,
-            7352.963867,
-            15839.926758,
-            4746.95459,
-            8539.133789,
-            12957.833008,
-            10096.874023,
-            6436.333496,
-            6488.447754,
-            12649.62793,
-            9575.267578,
-            2897.279785,
-            12649.941406,
-            14139.443359,
-            12061.751953,
-            10646.621094,
-            15703.19043,
-            13080.764648,
-            9124.349609,
-            14409.989258,
-            10726.665039,
-            6444.680664,
-            10168.352539,
-            5474.356934,
-            10729.345703,
-            4240.486328,
-            11856.861328,
-            6184.834473,
-            16671.128906,
-            9840.30957,
-            39691.976562,
-            21551.833984,
-            6072.709961,
-            18333.572266,
-            6635.820801,
-            8460.941406,
-            14243.955078,
-            34157.90625,
-            9565.474609,
-            5573.206055,
-            9139.364258,
-            6077.837402,
-            13941.31543,
-            10590.963867,
-            12113.441406
+            6.651368,
+            22.059452,
+            15.392176,
+            17.418619,
+            15.206824,
+            7.907998,
+            8.829535,
+            22.355659,
+            8.29262,
+            20.958277,
+            7.167404,
+            14.592677,
+            9.060788,
+            7.274667,
+            16.238981,
+            6.666115,
+            6.535679,
+            7.086256,
+            10.676177,
+            8.979206,
+            10.597121,
+            42.038162,
+            11.70071,
+            65.731316,
+            47.42622,
+            20.109543,
+            18.897541,
+            13.781085,
+            9.99165,
+            5.955308,
+            10.175659,
+            23.628405,
+            14.306578,
+            9.719462,
+            5.594786,
+            14.198979,
+            5.711433,
+            17.381332,
+            9.058512,
+            8.286205,
+            8.016202,
+            18.4515,
+            11.600831,
+            3.945074,
+            13.000222,
+            10.373363,
+            12.237907,
+            21.408463,
+            37.858665,
+            25.794065,
+            15.489001,
+            14.004895,
+            7.625473,
+            10.993184,
+            14.698832,
+            11.062652,
+            5.855446,
+            15.625135,
+            8.052419,
+            14.365479,
+            5.927001,
+            6.931933,
+            2.3014,
+            15.769623,
+            40.843319,
+            8.022024,
+            12.544907,
+            10.090073,
+            9.304819,
+            10.679907,
+            8.136175,
+            21.540607,
+            3.736973,
+            15.381804,
+            24.21562,
+            14.385005,
+            17.791706,
+            16.498833,
+            8.753955,
+            12.941816,
+            12.887664,
+            13.725715,
+            13.994792,
+            10.769128,
+            14.734674,
+            26.970015,
+            17.811842,
+            9.847188,
+            15.124973,
+            15.623392,
+            29.147844,
+            12.309229,
+            32.15152,
+            33.225769,
+            14.426914,
+            17.496277,
+            14.7356,
+            15.503921,
+            12.336852,
+            16.469248
         ],
"mean_perplexity": 12191.57833
"mean_perplexity": 14.991893
}
}
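For reference, the updated mean_perplexity is consistent with a plain arithmetic mean of the per-prompt values -- an assumption based only on the two fields shown here, which this snippet checks:

import json

with open("sharktank/tests/evaluate/baseline_perplexity_scores.json") as f:
    scores = json.load(f)["llama3_8B_f16_decomposed_vmfb"]

# Arithmetic mean of the 100 per-prompt perplexities above.
mean = sum(scores["perplexities"]) / len(scores["perplexities"])
print(mean, scores["mean_perplexity"])  # expect both to be ~14.991893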
2 changes: 1 addition & 1 deletion sharktank/tests/evaluate/perplexity_vmfb_test.py
@@ -22,7 +22,7 @@
 class PerplexityTest(unittest.TestCase):
     def setUp(self):
         self.current_perplexity_all = {}
-        self.delta = 10
+        self.delta = 5e-1
         self.tensor_parallelism_size = 8
         with open(self.baseline_perplexity_scores, "r") as f:
             self.baseline_perplexity = json.load(f)
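Tightening delta from 10 to 5e-1 only became possible because the baselines above are now O(10) instead of O(10^4). A sketch of how such a tolerance is typically applied -- the actual assertion in perplexity_vmfb_test.py sits outside this hunk, so the test below is illustrative:

import unittest

class PerplexityRegression(unittest.TestCase):
    delta = 5e-1  # was 10 when baseline perplexities sat around 10^4

    def test_mean_perplexity(self):
        baseline = 14.991893  # from baseline_perplexity_scores.json
        current = 15.1        # hypothetical value from a fresh run
        self.assertAlmostEqual(current, baseline, delta=self.delta)

if __name__ == "__main__":
    unittest.main()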
