From 78280d44bf9cb02425f5f4cf2777b4aafbe9eca8 Mon Sep 17 00:00:00 2001 From: Meng Zhang Date: Thu, 30 Mar 2023 14:52:04 +0800 Subject: [PATCH] Revert stop words implementation in python #33 --- tabby/server/backend/python.py | 41 ---------------------------------- 1 file changed, 41 deletions(-) diff --git a/tabby/server/backend/python.py b/tabby/server/backend/python.py index 5d8a24d..a521e59 100644 --- a/tabby/server/backend/python.py +++ b/tabby/server/backend/python.py @@ -40,66 +40,25 @@ class PythonModelService: .to(device) .eval() ) - self.stopping_criteria_mappings = {} def generate(self, request: CompletionRequest) -> List[Choice]: # FIXME(meng): read preset from request. preset_name = "python" preset = LanguagePresets[preset_name] - stopping_criteria_list = self.stopping_criteria_for_preset(preset_name) - input_ids = self.tokenizer.encode(request.prompt, return_tensors="pt").to( self.device ) res = self.model.generate( input_ids, max_length=preset.max_length, - stopping_criteria=stopping_criteria_list, ) output_ids = res[0][len(input_ids[0]) :] text = trim_with_stopwords(self.tokenizer.decode(output_ids), preset.stop_words) return [Choice(index=0, text=text)] - def stopping_criteria_for_preset(self, name: str) -> StoppingCriteriaList: - return StoppingCriteriaList( - [ - StopWordsIdsCriteria( - [self.tokenizer.encode(x) for x in LanguagePresets[name].stop_words] - ) - ] - ) - def __call__(self, request: CompletionRequest) -> CompletionResponse: choices = self.generate(request) return CompletionResponse( id=random_completion_id(), created=int(time.time()), choices=choices ) - - -class StopWordsIdsCriteria(StoppingCriteria): - def __init__(self, stop_words_ids: List[str]): - self.stop_words_ids = stop_words_ids - - def __call__( - self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs - ) -> bool: - if len(input_ids) != 1: - raise ValueError("Only 1-length list is handled") - - # FIXME(meng): trie based lookup. - tokens = input_ids[0] - for stop_word in self.stop_words_ids: - if len(tokens) < len(stop_word): - continue - - matched = True - for i in range(len(stop_word)): - if tokens[i - len(stop_word)] != stop_word[i]: - matched = False - break - - if matched: - return True - - return False