parent
bfcdfd5b7e
commit
78280d44bf
|
|
@ -40,66 +40,25 @@ class PythonModelService:
|
|||
.to(device)
|
||||
.eval()
|
||||
)
|
||||
self.stopping_criteria_mappings = {}
|
||||
|
||||
def generate(self, request: CompletionRequest) -> List[Choice]:
|
||||
# FIXME(meng): read preset from request.
|
||||
preset_name = "python"
|
||||
preset = LanguagePresets[preset_name]
|
||||
|
||||
stopping_criteria_list = self.stopping_criteria_for_preset(preset_name)
|
||||
|
||||
input_ids = self.tokenizer.encode(request.prompt, return_tensors="pt").to(
|
||||
self.device
|
||||
)
|
||||
res = self.model.generate(
|
||||
input_ids,
|
||||
max_length=preset.max_length,
|
||||
stopping_criteria=stopping_criteria_list,
|
||||
)
|
||||
output_ids = res[0][len(input_ids[0]) :]
|
||||
text = trim_with_stopwords(self.tokenizer.decode(output_ids), preset.stop_words)
|
||||
return [Choice(index=0, text=text)]
|
||||
|
||||
def stopping_criteria_for_preset(self, name: str) -> StoppingCriteriaList:
|
||||
return StoppingCriteriaList(
|
||||
[
|
||||
StopWordsIdsCriteria(
|
||||
[self.tokenizer.encode(x) for x in LanguagePresets[name].stop_words]
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
def __call__(self, request: CompletionRequest) -> CompletionResponse:
|
||||
choices = self.generate(request)
|
||||
return CompletionResponse(
|
||||
id=random_completion_id(), created=int(time.time()), choices=choices
|
||||
)
|
||||
|
||||
|
||||
class StopWordsIdsCriteria(StoppingCriteria):
|
||||
def __init__(self, stop_words_ids: List[str]):
|
||||
self.stop_words_ids = stop_words_ids
|
||||
|
||||
def __call__(
|
||||
self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
|
||||
) -> bool:
|
||||
if len(input_ids) != 1:
|
||||
raise ValueError("Only 1-length list is handled")
|
||||
|
||||
# FIXME(meng): trie based lookup.
|
||||
tokens = input_ids[0]
|
||||
for stop_word in self.stop_words_ids:
|
||||
if len(tokens) < len(stop_word):
|
||||
continue
|
||||
|
||||
matched = True
|
||||
for i in range(len(stop_word)):
|
||||
if tokens[i - len(stop_word)] != stop_word[i]:
|
||||
matched = False
|
||||
break
|
||||
|
||||
if matched:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
|
|
|||
Loading…
Reference in New Issue