From cee8c0fd471b4e65bb60f8ee0f3a8d35b64f3524 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Mon, 22 Jan 2024 15:46:47 -0500 Subject: [PATCH 01/16] Add draft model param to llama class, implement basic prompt lookup decoding draft model --- llama_cpp/llama.py | 17 +++++++-- llama_cpp/llama_speculative.py | 62 +++++++++++++++++++++++++++++++++ tests/test_llama_speculative.py | 16 +++++++++ 3 files changed, 92 insertions(+), 3 deletions(-) create mode 100644 llama_cpp/llama_speculative.py create mode 100644 tests/test_llama_speculative.py diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 25abf36cb3..84f3f5ab6a 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -30,6 +30,8 @@ import llama_cpp.llama_cpp as llama_cpp import llama_cpp.llama_chat_format as llama_chat_format +from llama_cpp.llama_speculative import LlamaDraftModel + import numpy as np import numpy.typing as npt @@ -89,6 +91,8 @@ def __init__( # Chat Format Params chat_format: str = "llama-2", chat_handler: Optional[llama_chat_format.LlamaChatCompletionHandler] = None, + # Speculative Decoding + draft_model: Optional[LlamaDraftModel] = None, # Misc verbose: bool = True, # Extra Params @@ -152,6 +156,7 @@ def __init__( numa: Enable NUMA support. (NOTE: The initial value of this parameter is used for the remainder of the program as this value is set in llama_backend_init) chat_format: String specifying the chat format to use when calling create_chat_completion. chat_handler: Optional chat handler to use when calling create_chat_completion. + draft_model: Optional draft model to use for speculative decoding. verbose: Print verbose output to stderr. Raises: @@ -197,7 +202,9 @@ def __init__( self.kv_overrides = kv_overrides if kv_overrides is not None: n_overrides = len(kv_overrides) - self._kv_overrides_array = llama_cpp.llama_model_kv_override * (n_overrides + 1) + self._kv_overrides_array = llama_cpp.llama_model_kv_override * ( + n_overrides + 1 + ) self._kv_overrides_array_keys = [] for k, v in kv_overrides.items(): @@ -216,10 +223,12 @@ def __init__( else: raise ValueError(f"Unknown value type for {k}: {v}") - self._kv_overrides_array_sentinel_key = b'\0' + self._kv_overrides_array_sentinel_key = b"\0" # null array sentinel - self._kv_overrides_array[n_overrides].key = self._kv_overrides_array_sentinel_key + self._kv_overrides_array[ + n_overrides + ].key = self._kv_overrides_array_sentinel_key self.model_params.kv_overrides = self._kv_overrides_array self.n_batch = min(n_ctx, n_batch) # ??? @@ -315,6 +324,8 @@ def __init__( self.chat_format = chat_format self.chat_handler = chat_handler + self.draft_model = draft_model + self._n_vocab = self.n_vocab() self._n_ctx = self.n_ctx() diff --git a/llama_cpp/llama_speculative.py b/llama_cpp/llama_speculative.py new file mode 100644 index 0000000000..5df1937e65 --- /dev/null +++ b/llama_cpp/llama_speculative.py @@ -0,0 +1,62 @@ +import abc + +import numpy as np +import numpy.typing as npt + + +class LlamaDraftModel(abc.ABC): + @abc.abstractmethod + def __call__(self, input_ids: npt.NDArray[np.intc]) -> npt.NDArray[np.intc]: + raise NotImplementedError() + + +class LlamaPromptLookupDecoding(LlamaDraftModel): + """Based on https://github.com/apoorvumang/prompt-lookup-decoding""" + + def __init__(self, max_ngram_size: int = 3, num_pred_tokens: int = 10): + self.max_ngram_size = max_ngram_size + self.num_pred_tokens = num_pred_tokens + + @staticmethod + def find_candidate_pred_tokens( + input_ids: npt.NDArray[np.intc], + max_ngram_size: int = 3, + num_pred_tokens: int = 10, + ): + input_length = input_ids.shape[1] + + for ngram_size in range(max_ngram_size, 0, -1): + # Extract the last n tokens as our search ngram + ngram = input_ids[0, -ngram_size:] + + # Create sliding windows of size ngram_size + windows = np.lib.stride_tricks.sliding_window_view( + input_ids, (1, ngram_size) + ) + + # Convert ngram to an array for comparison + ngram_array = np.array(ngram).reshape(1, -1) + + # Find where the windows match the ngram + matches = np.all(windows == ngram_array, axis=2) + + # Get the indices of matches + match_indices = np.nonzero(matches)[1] + + # Iterate through match indices to find a valid continuation + for idx in match_indices: + start_idx = idx + ngram_size + end_idx = start_idx + num_pred_tokens + # Ensure we don't go beyond the length of input_ids and avoid self-match + if end_idx <= input_length and start_idx < input_length - ngram_size: + return input_ids[0, start_idx:end_idx] + + # If no match is found, return an empty array + return np.array([], dtype=np.intc) + + def __call__(self, input_ids: npt.NDArray[np.intc]) -> npt.NDArray[np.intc]: + return self.find_candidate_pred_tokens( + input_ids=input_ids, + max_ngram_size=self.max_ngram_size, + num_pred_tokens=self.num_pred_tokens, + ) diff --git a/tests/test_llama_speculative.py b/tests/test_llama_speculative.py new file mode 100644 index 0000000000..e2070518d3 --- /dev/null +++ b/tests/test_llama_speculative.py @@ -0,0 +1,16 @@ +import numpy as np + +from llama_cpp.llama_speculative import LlamaPromptLookupDecoding + +def test_find_candidate_pred_tokens(): + find_candidate_pred_tokens = LlamaPromptLookupDecoding.find_candidate_pred_tokens + + # Test Case 1: Matching ngram is found + input_ids1 = np.array([[1, 2, 3, 1, 2, 3, 1, 2, 3]]) + result1 = find_candidate_pred_tokens(input_ids1, max_ngram_size=3, num_pred_tokens=2) + assert np.array_equal(result1, np.array([1, 2])) + + # Test Case 2: Matching ngram is not found + input_ids2 = np.array([[1, 2, 3, 4, 5, 6, 7, 8, 9]]) + result2 = find_candidate_pred_tokens(input_ids2, max_ngram_size=3, num_pred_tokens=2) + assert np.array_equal(result2, np.array([])) From be688daee8ada08d4b4fb996e3bdfd03823204e7 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Mon, 22 Jan 2024 16:23:40 -0500 Subject: [PATCH 02/16] Use samplingcontext for sampling --- llama_cpp/llama.py | 89 ++++++++++++++-------------------------------- 1 file changed, 27 insertions(+), 62 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 3c6822cf75..ba77d3da38 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -41,6 +41,8 @@ _LlamaContext, # type: ignore _LlamaBatch, # type: ignore _LlamaTokenDataArray, # type: ignore + _LlamaSamplingParams, # type: ignore + _LlamaSamplingContext, # type: ignore ) @@ -491,79 +493,42 @@ def sample( """ assert self._ctx is not None assert self.n_tokens > 0 - last_n_tokens_data = [llama_cpp.llama_token(0)] * max( - 0, self.last_n_tokens_size - self.n_tokens - ) + self._input_ids[-self.last_n_tokens_size :].tolist() - last_n_tokens_size = len(last_n_tokens_data) - n_vocab = self._n_vocab - n_ctx = self._n_ctx - top_k = n_vocab if top_k <= 0 else top_k - last_n_tokens_size = n_ctx if last_n_tokens_size < 0 else last_n_tokens_size - last_n_tokens_data_c = (llama_cpp.llama_token * last_n_tokens_size)( - *last_n_tokens_data - ) + logits: npt.NDArray[np.single] = self._scores[-1, :] if logits_processor is not None: logits[:] = logits_processor(self._input_ids, logits) - nl_logit = logits[self._token_nl] - self._candidates.copy_logits(logits) - self._ctx.sample_repetition_penalties( - candidates=self._candidates, - last_tokens_data=last_n_tokens_data_c, - penalty_last_n=last_n_tokens_size, + sampling_params = _LlamaSamplingParams( + top_k=top_k, + top_p=top_p, + min_p=min_p, + tfs_z=tfs_z, + typical_p=typical_p, + temp=temp, + penalty_last_n=self.last_n_tokens_size, penalty_repeat=repeat_penalty, penalty_freq=frequency_penalty, penalty_present=presence_penalty, + mirostat=mirostat_mode, + mirostat_tau=mirostat_tau, + mirostat_eta=mirostat_eta, + penalize_nl=penalize_nl, + ) + sampling_context = _LlamaSamplingContext( + params=sampling_params, + grammar=grammar, + ) + sampling_context.prev = list(self.eval_tokens) + id = sampling_context.sample(ctx_main=self._ctx, logits_array=logits) + sampling_context.accept( + ctx_main=self._ctx, + id=id, + apply_grammar=grammar is not None, ) - if not penalize_nl: - self._candidates.candidates.data[self._token_nl].logit = llama_cpp.c_float( - nl_logit - ) - - if grammar is not None: - self._ctx.sample_grammar( - candidates=self._candidates, - grammar=grammar, - ) - - if temp < 0.0: - self._ctx.sample_softmax(candidates=self._candidates) - id = self._candidates.candidates.data[0].id - elif temp == 0.0: - id = self._ctx.sample_token_greedy(candidates=self._candidates) - elif mirostat_mode == 1: - self._ctx.sample_temp(candidates=self._candidates, temp=temp) - id = self._ctx.sample_token_mirostat( - candidates=self._candidates, - tau=mirostat_tau, - eta=mirostat_eta, - mu=ctypes.pointer(self._mirostat_mu), - m=100, - ) - elif mirostat_mode == 2: - self._ctx.sample_temp(candidates=self._candidates, temp=temp) - id = self._ctx.sample_token_mirostat_v2( - candidates=self._candidates, - tau=mirostat_tau, - eta=mirostat_eta, - mu=ctypes.pointer(self._mirostat_mu) - ) - else: - self._ctx.sample_top_k(candidates=self._candidates, k=top_k, min_keep=1) - self._ctx.sample_tail_free(candidates=self._candidates, z=tfs_z, min_keep=1) - self._ctx.sample_typical( - candidates=self._candidates, p=typical_p, min_keep=1 - ) - self._ctx.sample_top_p(candidates=self._candidates, p=top_p, min_keep=1) - self._ctx.sample_min_p(candidates=self._candidates, p=min_p, min_keep=1) - self._ctx.sample_temp(candidates=self._candidates, temp=temp) - id = self._ctx.sample_token(candidates=self._candidates) - if grammar is not None: - self._ctx.grammar_accept_token(grammar=grammar, token=id) return id + def generate( self, tokens: Sequence[int], From 8fe1c48b40b80a9fb1e9f77a22b4483e09f0efee Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Tue, 23 Jan 2024 12:07:00 -0500 Subject: [PATCH 03/16] Use 1d array --- llama_cpp/llama_speculative.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/llama_cpp/llama_speculative.py b/llama_cpp/llama_speculative.py index 5df1937e65..4f21c0fc6f 100644 --- a/llama_cpp/llama_speculative.py +++ b/llama_cpp/llama_speculative.py @@ -23,25 +23,28 @@ def find_candidate_pred_tokens( max_ngram_size: int = 3, num_pred_tokens: int = 10, ): - input_length = input_ids.shape[1] + input_length = input_ids.shape[0] + + if input_length < max_ngram_size: + return np.array([], dtype=np.intc) for ngram_size in range(max_ngram_size, 0, -1): # Extract the last n tokens as our search ngram - ngram = input_ids[0, -ngram_size:] + ngram = input_ids[-ngram_size:] # Create sliding windows of size ngram_size windows = np.lib.stride_tricks.sliding_window_view( - input_ids, (1, ngram_size) + input_ids, (ngram_size,) ) # Convert ngram to an array for comparison - ngram_array = np.array(ngram).reshape(1, -1) + ngram_array = np.array(ngram)#.reshape(1, -1) # Find where the windows match the ngram - matches = np.all(windows == ngram_array, axis=2) + matches = np.all(windows == ngram_array, axis=1) # Get the indices of matches - match_indices = np.nonzero(matches)[1] + match_indices = np.nonzero(matches)[0] # Iterate through match indices to find a valid continuation for idx in match_indices: @@ -49,7 +52,7 @@ def find_candidate_pred_tokens( end_idx = start_idx + num_pred_tokens # Ensure we don't go beyond the length of input_ids and avoid self-match if end_idx <= input_length and start_idx < input_length - ngram_size: - return input_ids[0, start_idx:end_idx] + return input_ids[start_idx:end_idx] # If no match is found, return an empty array return np.array([], dtype=np.intc) From 92cf2c409b5d7109644b9f8184a63fc225adf065 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Tue, 23 Jan 2024 12:07:38 -0500 Subject: [PATCH 04/16] Use draft model for sampling --- llama_cpp/llama.py | 100 ++++++++++++++++++++++---------- llama_cpp/server/model.py | 7 +++ llama_cpp/server/settings.py | 5 ++ tests/test_llama_speculative.py | 4 +- 4 files changed, 82 insertions(+), 34 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index ba77d3da38..6eca58c4cf 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -41,8 +41,8 @@ _LlamaContext, # type: ignore _LlamaBatch, # type: ignore _LlamaTokenDataArray, # type: ignore - _LlamaSamplingParams, # type: ignore - _LlamaSamplingContext, # type: ignore + _LlamaSamplingParams, # type: ignore + _LlamaSamplingContext, # type: ignore ) @@ -342,7 +342,9 @@ def __init__( (n_ctx, self._n_vocab), dtype=np.single ) - self._mirostat_mu = ctypes.c_float(2.0 * 5.0) # TODO: Move this to sampling context + self._mirostat_mu = ctypes.c_float( + 2.0 * 5.0 + ) # TODO: Move this to sampling context try: self.metadata = self._model.metadata() @@ -350,7 +352,7 @@ def __init__( self.metadata = {} if self.verbose: print(f"Failed to load metadata: {e}", file=sys.stderr) - + if self.verbose: print(f"Model metadata: {self.metadata}", file=sys.stderr) @@ -479,6 +481,7 @@ def sample( penalize_nl: bool = True, logits_processor: Optional[LogitsProcessorList] = None, grammar: Optional[LlamaGrammar] = None, + idx: Optional[int] = None, ): """Sample a token from the model. @@ -494,10 +497,17 @@ def sample( assert self._ctx is not None assert self.n_tokens > 0 - logits: npt.NDArray[np.single] = self._scores[-1, :] + if idx is None: + logits: npt.NDArray[np.single] = self._scores[-1, :] + else: + logits = self._scores[idx, :] if logits_processor is not None: - logits[:] = logits_processor(self._input_ids, logits) + logits[:] = ( + logits_processor(self._input_ids, logits) + if idx is None + else logits_processor(self._input_ids[:idx], logits) + ) sampling_params = _LlamaSamplingParams( top_k=top_k, @@ -528,7 +538,6 @@ def sample( ) return id - def generate( self, tokens: Sequence[int], @@ -595,34 +604,61 @@ def generate( if grammar is not None: grammar.reset() + sample_idx = self.n_tokens + len(tokens) - 1 + draft_model = self.draft_model + candidates: List[int] = [] + candidates_idx = 0 + tokens = list(tokens) + # Eval and sample while True: self.eval(tokens) - token = self.sample( - top_k=top_k, - top_p=top_p, - min_p=min_p, - typical_p=typical_p, - temp=temp, - repeat_penalty=repeat_penalty, - frequency_penalty=frequency_penalty, - presence_penalty=presence_penalty, - tfs_z=tfs_z, - mirostat_mode=mirostat_mode, - mirostat_tau=mirostat_tau, - mirostat_eta=mirostat_eta, - logits_processor=logits_processor, - grammar=grammar, - penalize_nl=penalize_nl, - ) - if stopping_criteria is not None and stopping_criteria( - self._input_ids, self._scores[-1, :] - ): - return - tokens_or_none = yield token - tokens = [token] - if tokens_or_none is not None: - tokens.extend(tokens_or_none) + while sample_idx < self.n_tokens: + token = self.sample( + top_k=top_k, + top_p=top_p, + min_p=min_p, + typical_p=typical_p, + temp=temp, + repeat_penalty=repeat_penalty, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + tfs_z=tfs_z, + mirostat_mode=mirostat_mode, + mirostat_tau=mirostat_tau, + mirostat_eta=mirostat_eta, + logits_processor=logits_processor, + grammar=grammar, + penalize_nl=penalize_nl, + idx=sample_idx, + ) + + sample_idx += 1 + if stopping_criteria is not None and stopping_criteria( + self._input_ids, self._scores[-1, :] + ): + return + tokens_or_none = yield token + tokens = [token] + if tokens_or_none is not None: + tokens.extend(tokens_or_none) + + if candidates and token != candidates[candidates_idx]: + candidates = [] + candidates_idx = 0 + self.n_tokens = sample_idx + self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1) + break + else: + candidates_idx += 1 + + if draft_model is not None and sample_idx >= self.n_tokens: + candidates = [int(i) for i in draft_model(self._input_ids[:sample_idx])[ + : self._n_ctx - self.n_tokens + ]] + candidates_idx = 0 + tokens.extend(candidates) + print(candidates) def create_embedding( self, input: Union[str, List[str]], model: Optional[str] = None diff --git a/llama_cpp/server/model.py b/llama_cpp/server/model.py index bbb68069d5..32e841686e 100644 --- a/llama_cpp/server/model.py +++ b/llama_cpp/server/model.py @@ -5,6 +5,7 @@ from typing import Dict, Optional, Union, List import llama_cpp +import llama_cpp.llama_speculative as llama_speculative from llama_cpp.server.settings import ModelSettings @@ -92,6 +93,10 @@ def load_llama_from_model_settings(settings: ModelSettings) -> llama_cpp.Llama: ) ) + draft_model = None + if settings.draft_model is not None: + draft_model = llama_speculative.LlamaPromptLookupDecoding() + kv_overrides: Optional[Dict[str, Union[bool, int, float]]] = None if settings.kv_overrides is not None: assert isinstance(settings.kv_overrides, list) @@ -147,6 +152,8 @@ def load_llama_from_model_settings(settings: ModelSettings) -> llama_cpp.Llama: # Chat Format Params chat_format=settings.chat_format, chat_handler=chat_handler, + # Speculative Decoding + draft_model=draft_model, # Misc verbose=settings.verbose, ) diff --git a/llama_cpp/server/settings.py b/llama_cpp/server/settings.py index 9f0dc8a73c..dfa48aee65 100644 --- a/llama_cpp/server/settings.py +++ b/llama_cpp/server/settings.py @@ -143,6 +143,11 @@ class ModelSettings(BaseSettings): default=None, description="The model name or path to a pretrained HuggingFace tokenizer model. Same as you would pass to AutoTokenizer.from_pretrained().", ) + # Speculative Decoding + draft_model: Optional[str] = Field( + default=None, + description="Method to use for speculative decoding. One of (prompt-lookup-decoding).", + ) # Misc verbose: bool = Field( default=True, description="Whether to print debug information." diff --git a/tests/test_llama_speculative.py b/tests/test_llama_speculative.py index e2070518d3..b5d450567b 100644 --- a/tests/test_llama_speculative.py +++ b/tests/test_llama_speculative.py @@ -6,11 +6,11 @@ def test_find_candidate_pred_tokens(): find_candidate_pred_tokens = LlamaPromptLookupDecoding.find_candidate_pred_tokens # Test Case 1: Matching ngram is found - input_ids1 = np.array([[1, 2, 3, 1, 2, 3, 1, 2, 3]]) + input_ids1 = np.array([1, 2, 3, 1, 2, 3, 1, 2, 3]) result1 = find_candidate_pred_tokens(input_ids1, max_ngram_size=3, num_pred_tokens=2) assert np.array_equal(result1, np.array([1, 2])) # Test Case 2: Matching ngram is not found - input_ids2 = np.array([[1, 2, 3, 4, 5, 6, 7, 8, 9]]) + input_ids2 = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9]) result2 = find_candidate_pred_tokens(input_ids2, max_ngram_size=3, num_pred_tokens=2) assert np.array_equal(result2, np.array([])) From b4976daa3c7554af35df8a7146400f30845c527c Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Tue, 23 Jan 2024 17:11:47 -0500 Subject: [PATCH 05/16] Fix dumb mistake --- llama_cpp/llama.py | 28 ++++++++++------------------ 1 file changed, 10 insertions(+), 18 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 6eca58c4cf..9372b27737 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -606,8 +606,6 @@ def generate( sample_idx = self.n_tokens + len(tokens) - 1 draft_model = self.draft_model - candidates: List[int] = [] - candidates_idx = 0 tokens = list(tokens) # Eval and sample @@ -643,22 +641,16 @@ def generate( if tokens_or_none is not None: tokens.extend(tokens_or_none) - if candidates and token != candidates[candidates_idx]: - candidates = [] - candidates_idx = 0 - self.n_tokens = sample_idx - self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1) - break - else: - candidates_idx += 1 - - if draft_model is not None and sample_idx >= self.n_tokens: - candidates = [int(i) for i in draft_model(self._input_ids[:sample_idx])[ - : self._n_ctx - self.n_tokens - ]] - candidates_idx = 0 - tokens.extend(candidates) - print(candidates) + if sample_idx < self.n_tokens: + if token != self._input_ids[sample_idx]: + self.n_tokens = sample_idx + self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1) + break + + if draft_model is not None: + input_ids = np.concatenate([self._input_ids[:self.n_tokens], np.array(tokens)]) + tokens.extend([int(i) for i in draft_model(input_ids)]) + tokens = tokens[: self._n_ctx - self.n_tokens] def create_embedding( self, input: Union[str, List[str]], model: Optional[str] = None From fae83f21423d0a7ecf754aa311b1067100982c7f Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Tue, 23 Jan 2024 17:12:24 -0500 Subject: [PATCH 06/16] Allow for later extensions to the LlamaDraftModel api --- llama_cpp/llama_speculative.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/llama_cpp/llama_speculative.py b/llama_cpp/llama_speculative.py index 4f21c0fc6f..c6a850a3a1 100644 --- a/llama_cpp/llama_speculative.py +++ b/llama_cpp/llama_speculative.py @@ -1,12 +1,16 @@ import abc +from typing import Any + import numpy as np import numpy.typing as npt class LlamaDraftModel(abc.ABC): @abc.abstractmethod - def __call__(self, input_ids: npt.NDArray[np.intc]) -> npt.NDArray[np.intc]: + def __call__( + self, input_ids: npt.NDArray[np.intc], /, **kwargs: Any + ) -> npt.NDArray[np.intc]: raise NotImplementedError() @@ -33,12 +37,10 @@ def find_candidate_pred_tokens( ngram = input_ids[-ngram_size:] # Create sliding windows of size ngram_size - windows = np.lib.stride_tricks.sliding_window_view( - input_ids, (ngram_size,) - ) + windows = np.lib.stride_tricks.sliding_window_view(input_ids, (ngram_size,)) # Convert ngram to an array for comparison - ngram_array = np.array(ngram)#.reshape(1, -1) + ngram_array = np.array(ngram) # .reshape(1, -1) # Find where the windows match the ngram matches = np.all(windows == ngram_array, axis=1) @@ -57,7 +59,9 @@ def find_candidate_pred_tokens( # If no match is found, return an empty array return np.array([], dtype=np.intc) - def __call__(self, input_ids: npt.NDArray[np.intc]) -> npt.NDArray[np.intc]: + def __call__( + self, input_ids: npt.NDArray[np.intc], /, **kwargs: Any + ) -> npt.NDArray[np.intc]: return self.find_candidate_pred_tokens( input_ids=input_ids, max_ngram_size=self.max_ngram_size, From 346a6c524425ce4ba1fb25317f4f75acd1360904 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Tue, 23 Jan 2024 21:59:39 -0500 Subject: [PATCH 07/16] Cleanup --- llama_cpp/llama.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 9372b27737..f0eabb44ff 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -637,10 +637,11 @@ def generate( ): return tokens_or_none = yield token - tokens = [token] + tokens.clear() + tokens.append(token) if tokens_or_none is not None: tokens.extend(tokens_or_none) - + if sample_idx < self.n_tokens: if token != self._input_ids[sample_idx]: self.n_tokens = sample_idx @@ -648,7 +649,9 @@ def generate( break if draft_model is not None: - input_ids = np.concatenate([self._input_ids[:self.n_tokens], np.array(tokens)]) + input_ids = np.concatenate( + [self._input_ids[: self.n_tokens], np.array(tokens)] + ) tokens.extend([int(i) for i in draft_model(input_ids)]) tokens = tokens[: self._n_ctx - self.n_tokens] From 84158376963f272a69292c5d65a05d0335139cd7 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Tue, 23 Jan 2024 23:45:20 -0500 Subject: [PATCH 08/16] Adaptive candidate prediction --- llama_cpp/llama.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 6014b37de7..deda8fb8ae 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -603,6 +603,8 @@ def generate( sample_idx = self.n_tokens + len(tokens) - 1 draft_model = self.draft_model tokens = list(tokens) + candidates_to_predict = 10 + candidates_all_correct = True # Eval and sample while True: @@ -642,14 +644,23 @@ def generate( if token != self._input_ids[sample_idx]: self.n_tokens = sample_idx self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1) + candidates_all_correct = False break if draft_model is not None: - input_ids = np.concatenate( - [self._input_ids[: self.n_tokens], np.array(tokens)] + if candidates_all_correct: + candidates_to_predict = min(10, candidates_to_predict + 2) + else: + candidates_to_predict = max(1, candidates_to_predict - 1) + self._input_ids[self.n_tokens : len(tokens)] = tokens + draft_tokens = draft_model(self._input_ids)[:candidates_to_predict] + candidates_to_predict = len(draft_tokens) + candidates_all_correct = True + tokens.extend( + draft_tokens.astype(int)[ + : self._n_ctx - self.n_tokens - len(tokens) + ] ) - tokens.extend([int(i) for i in draft_model(input_ids)]) - tokens = tokens[: self._n_ctx - self.n_tokens] def create_embedding( self, input: Union[str, List[str]], model: Optional[str] = None From c363eee99b0b0e1157716d41c7b9c42786ab36c8 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Wed, 24 Jan 2024 11:40:19 -0500 Subject: [PATCH 09/16] Update implementation to match hf transformers --- llama_cpp/llama_speculative.py | 21 ++++++++------------- 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/llama_cpp/llama_speculative.py b/llama_cpp/llama_speculative.py index c6a850a3a1..39dfb903ba 100644 --- a/llama_cpp/llama_speculative.py +++ b/llama_cpp/llama_speculative.py @@ -17,30 +17,24 @@ def __call__( class LlamaPromptLookupDecoding(LlamaDraftModel): """Based on https://github.com/apoorvumang/prompt-lookup-decoding""" - def __init__(self, max_ngram_size: int = 3, num_pred_tokens: int = 10): + def __init__(self, max_ngram_size: int = 2, num_pred_tokens: int = 10): self.max_ngram_size = max_ngram_size self.num_pred_tokens = num_pred_tokens @staticmethod def find_candidate_pred_tokens( input_ids: npt.NDArray[np.intc], - max_ngram_size: int = 3, - num_pred_tokens: int = 10, + max_ngram_size: int, + num_pred_tokens: int, ): input_length = input_ids.shape[0] - if input_length < max_ngram_size: - return np.array([], dtype=np.intc) - - for ngram_size in range(max_ngram_size, 0, -1): - # Extract the last n tokens as our search ngram - ngram = input_ids[-ngram_size:] - + for ngram_size in range(min(max_ngram_size, input_length - 1), 0, -1): # Create sliding windows of size ngram_size windows = np.lib.stride_tricks.sliding_window_view(input_ids, (ngram_size,)) # Convert ngram to an array for comparison - ngram_array = np.array(ngram) # .reshape(1, -1) + ngram_array = input_ids[-ngram_size:] # Find where the windows match the ngram matches = np.all(windows == ngram_array, axis=1) @@ -52,8 +46,9 @@ def find_candidate_pred_tokens( for idx in match_indices: start_idx = idx + ngram_size end_idx = start_idx + num_pred_tokens - # Ensure we don't go beyond the length of input_ids and avoid self-match - if end_idx <= input_length and start_idx < input_length - ngram_size: + end_idx = min(end_idx, input_length) + + if start_idx < end_idx: return input_ids[start_idx:end_idx] # If no match is found, return an empty array From 5ab599911925d3e3eaffaa887915eec50540435b Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Wed, 24 Jan 2024 11:40:29 -0500 Subject: [PATCH 10/16] Tuning --- llama_cpp/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 8292aa73d3..c06e35d05e 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -649,7 +649,7 @@ def generate( if draft_model is not None: if candidates_all_correct: - candidates_to_predict = min(10, candidates_to_predict + 2) + candidates_to_predict = candidates_to_predict + 2 else: candidates_to_predict = max(1, candidates_to_predict - 1) self._input_ids[self.n_tokens : len(tokens)] = tokens From f39690c601acbf80e234ed5b3ff9a1f4d20f689b Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Fri, 26 Jan 2024 11:31:20 -0500 Subject: [PATCH 11/16] Fix bug where last token was not used for ngram prediction --- llama_cpp/llama.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index c06e35d05e..05bb42a6ab 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -652,8 +652,8 @@ def generate( candidates_to_predict = candidates_to_predict + 2 else: candidates_to_predict = max(1, candidates_to_predict - 1) - self._input_ids[self.n_tokens : len(tokens)] = tokens - draft_tokens = draft_model(self._input_ids)[:candidates_to_predict] + self.input_ids[self.n_tokens : self.n_tokens + len(tokens)] = tokens + draft_tokens = draft_model(self.input_ids[:self.n_tokens + len(tokens)])[:candidates_to_predict] candidates_to_predict = len(draft_tokens) candidates_all_correct = True tokens.extend( From c6013e2cb4cc6bf49df17e467ad16eff958a7376 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Fri, 26 Jan 2024 16:28:02 -0500 Subject: [PATCH 12/16] Remove heuristic for num_pred_tokens (no benefit) --- llama_cpp/llama.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 05bb42a6ab..01bbb88a40 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -604,7 +604,6 @@ def generate( draft_model = self.draft_model tokens = list(tokens) candidates_to_predict = 10 - candidates_all_correct = True # Eval and sample while True: @@ -644,18 +643,12 @@ def generate( if token != self._input_ids[sample_idx]: self.n_tokens = sample_idx self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1) - candidates_all_correct = False break if draft_model is not None: - if candidates_all_correct: - candidates_to_predict = candidates_to_predict + 2 - else: - candidates_to_predict = max(1, candidates_to_predict - 1) self.input_ids[self.n_tokens : self.n_tokens + len(tokens)] = tokens draft_tokens = draft_model(self.input_ids[:self.n_tokens + len(tokens)])[:candidates_to_predict] candidates_to_predict = len(draft_tokens) - candidates_all_correct = True tokens.extend( draft_tokens.astype(int)[ : self._n_ctx - self.n_tokens - len(tokens) From 4f946b08becee637733519a027f77e6d2a280c86 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Wed, 31 Jan 2024 13:42:44 -0500 Subject: [PATCH 13/16] fix: n_candidates bug. --- llama_cpp/llama.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 415510a71a..2fb8984ec1 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -638,7 +638,6 @@ def generate( sample_idx = self.n_tokens + len(tokens) - 1 draft_model = self.draft_model tokens = list(tokens) - candidates_to_predict = 10 # Eval and sample while True: @@ -682,8 +681,7 @@ def generate( if draft_model is not None: self.input_ids[self.n_tokens : self.n_tokens + len(tokens)] = tokens - draft_tokens = draft_model(self.input_ids[:self.n_tokens + len(tokens)])[:candidates_to_predict] - candidates_to_predict = len(draft_tokens) + draft_tokens = draft_model(self.input_ids[:self.n_tokens + len(tokens)]) tokens.extend( draft_tokens.astype(int)[ : self._n_ctx - self.n_tokens - len(tokens) From df93d1d622eaa3ab283406d14ba2cc1e91f520ba Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Wed, 31 Jan 2024 13:43:08 -0500 Subject: [PATCH 14/16] Add draft_model_num_pred_tokens server setting --- llama_cpp/server/model.py | 4 +++- llama_cpp/server/settings.py | 4 ++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/llama_cpp/server/model.py b/llama_cpp/server/model.py index 32e841686e..925ab99b73 100644 --- a/llama_cpp/server/model.py +++ b/llama_cpp/server/model.py @@ -95,7 +95,9 @@ def load_llama_from_model_settings(settings: ModelSettings) -> llama_cpp.Llama: draft_model = None if settings.draft_model is not None: - draft_model = llama_speculative.LlamaPromptLookupDecoding() + draft_model = llama_speculative.LlamaPromptLookupDecoding( + num_pred_tokens=settings.draft_model_num_pred_tokens + ) kv_overrides: Optional[Dict[str, Union[bool, int, float]]] = None if settings.kv_overrides is not None: diff --git a/llama_cpp/server/settings.py b/llama_cpp/server/settings.py index 6904b50f26..60f3eeca23 100644 --- a/llama_cpp/server/settings.py +++ b/llama_cpp/server/settings.py @@ -148,6 +148,10 @@ class ModelSettings(BaseSettings): default=None, description="Method to use for speculative decoding. One of (prompt-lookup-decoding).", ) + draft_model_num_pred_tokens: int = Field( + default=10, + description="Number of tokens to predict using the draft model.", + ) # Misc verbose: bool = Field( default=True, description="Whether to print debug information." From 995d40c0f09ddf9633aae0bdb177adac2d0113ef Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Wed, 31 Jan 2024 13:54:25 -0500 Subject: [PATCH 15/16] Cleanup --- llama_cpp/llama.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 2fb8984ec1..f00fd4fc51 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -636,7 +636,6 @@ def generate( grammar.reset() sample_idx = self.n_tokens + len(tokens) - 1 - draft_model = self.draft_model tokens = list(tokens) # Eval and sample @@ -673,15 +672,14 @@ def generate( if tokens_or_none is not None: tokens.extend(tokens_or_none) - if sample_idx < self.n_tokens: - if token != self._input_ids[sample_idx]: - self.n_tokens = sample_idx - self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1) - break + if sample_idx < self.n_tokens and token != self._input_ids[sample_idx]: + self.n_tokens = sample_idx + self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1) + break - if draft_model is not None: + if self.draft_model is not None: self.input_ids[self.n_tokens : self.n_tokens + len(tokens)] = tokens - draft_tokens = draft_model(self.input_ids[:self.n_tokens + len(tokens)]) + draft_tokens = self.draft_model(self.input_ids[:self.n_tokens + len(tokens)]) tokens.extend( draft_tokens.astype(int)[ : self._n_ctx - self.n_tokens - len(tokens) From 291eadc892b5f046a1cf0757b1de76c66b4640a4 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Wed, 31 Jan 2024 14:07:44 -0500 Subject: [PATCH 16/16] Update README --- README.md | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/README.md b/README.md index 0a77bbdaa3..4131bb3d44 100644 --- a/README.md +++ b/README.md @@ -378,6 +378,24 @@ Then you'll need to use a custom chat handler to load the clip model and process ) ``` +### Speculative Decoding + +`llama-cpp-python` supports speculative decoding which allows the model to generate completions based on a draft model. + +The fastest way to use speculative decoding is through the `LlamaPromptLookupDecoding` class. + +Just pass this as a draft model to the `Llama` class during initialization. + +```python +from llama_cpp import Llama +from llama_cpp.llama_speculative import LlamaPromptLookupDecoding + +llama = Llama( + model_path="path/to/model.gguf", + draft_model=LlamaPromptLookupDecoding(num_pred_tokens=10) # num_pred_tokens is the number of tokens to predict 10 is the default and generally good for gpu, 2 performs better for cpu-only machines. +) +``` + ### Adjusting the Context Window The context window of the Llama models determines the maximum number of tokens that can be processed at once. By default, this is set to 512 tokens, but can be adjusted based on your requirements.