parent
bfcdfd5b7e
commit
78280d44bf
|
|
@ -40,66 +40,25 @@ class PythonModelService:
|
||||||
.to(device)
|
.to(device)
|
||||||
.eval()
|
.eval()
|
||||||
)
|
)
|
||||||
self.stopping_criteria_mappings = {}
|
|
||||||
|
|
||||||
def generate(self, request: CompletionRequest) -> List[Choice]:
|
def generate(self, request: CompletionRequest) -> List[Choice]:
|
||||||
# FIXME(meng): read preset from request.
|
# FIXME(meng): read preset from request.
|
||||||
preset_name = "python"
|
preset_name = "python"
|
||||||
preset = LanguagePresets[preset_name]
|
preset = LanguagePresets[preset_name]
|
||||||
|
|
||||||
stopping_criteria_list = self.stopping_criteria_for_preset(preset_name)
|
|
||||||
|
|
||||||
input_ids = self.tokenizer.encode(request.prompt, return_tensors="pt").to(
|
input_ids = self.tokenizer.encode(request.prompt, return_tensors="pt").to(
|
||||||
self.device
|
self.device
|
||||||
)
|
)
|
||||||
res = self.model.generate(
|
res = self.model.generate(
|
||||||
input_ids,
|
input_ids,
|
||||||
max_length=preset.max_length,
|
max_length=preset.max_length,
|
||||||
stopping_criteria=stopping_criteria_list,
|
|
||||||
)
|
)
|
||||||
output_ids = res[0][len(input_ids[0]) :]
|
output_ids = res[0][len(input_ids[0]) :]
|
||||||
text = trim_with_stopwords(self.tokenizer.decode(output_ids), preset.stop_words)
|
text = trim_with_stopwords(self.tokenizer.decode(output_ids), preset.stop_words)
|
||||||
return [Choice(index=0, text=text)]
|
return [Choice(index=0, text=text)]
|
||||||
|
|
||||||
def stopping_criteria_for_preset(self, name: str) -> StoppingCriteriaList:
|
|
||||||
return StoppingCriteriaList(
|
|
||||||
[
|
|
||||||
StopWordsIdsCriteria(
|
|
||||||
[self.tokenizer.encode(x) for x in LanguagePresets[name].stop_words]
|
|
||||||
)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
def __call__(self, request: CompletionRequest) -> CompletionResponse:
|
def __call__(self, request: CompletionRequest) -> CompletionResponse:
|
||||||
choices = self.generate(request)
|
choices = self.generate(request)
|
||||||
return CompletionResponse(
|
return CompletionResponse(
|
||||||
id=random_completion_id(), created=int(time.time()), choices=choices
|
id=random_completion_id(), created=int(time.time()), choices=choices
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class StopWordsIdsCriteria(StoppingCriteria):
|
|
||||||
def __init__(self, stop_words_ids: List[str]):
|
|
||||||
self.stop_words_ids = stop_words_ids
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
|
|
||||||
) -> bool:
|
|
||||||
if len(input_ids) != 1:
|
|
||||||
raise ValueError("Only 1-length list is handled")
|
|
||||||
|
|
||||||
# FIXME(meng): trie based lookup.
|
|
||||||
tokens = input_ids[0]
|
|
||||||
for stop_word in self.stop_words_ids:
|
|
||||||
if len(tokens) < len(stop_word):
|
|
||||||
continue
|
|
||||||
|
|
||||||
matched = True
|
|
||||||
for i in range(len(stop_word)):
|
|
||||||
if tokens[i - len(stop_word)] != stop_word[i]:
|
|
||||||
matched = False
|
|
||||||
break
|
|
||||||
|
|
||||||
if matched:
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue