style: stopwords -> stop_words

add-more-languages
Meng Zhang 2023-04-02 11:26:43 +08:00
parent 78280d44bf
commit 82103e7280
3 changed files with 7 additions and 5 deletions

View File

@ -11,7 +11,7 @@ from transformers import (
from ..models import Choice, CompletionRequest, CompletionResponse
from .language_presets import LanguagePresets
from .utils import random_completion_id, trim_with_stopwords
from .utils import random_completion_id, trim_with_stop_words
class PythonModelService:
@ -54,7 +54,9 @@ class PythonModelService:
max_length=preset.max_length,
)
output_ids = res[0][len(input_ids[0]) :]
text = trim_with_stopwords(self.tokenizer.decode(output_ids), preset.stop_words)
text = trim_with_stop_words(
self.tokenizer.decode(output_ids), preset.stop_words
)
return [Choice(index=0, text=text)]
def __call__(self, request: CompletionRequest) -> CompletionResponse:

View File

@ -8,7 +8,7 @@ from tritonclient.utils import InferenceServerException, np_to_triton_dtype
from ..models import Choice, CompletionRequest, CompletionResponse
from .language_presets import LanguagePresets
from .utils import random_completion_id, trim_with_stopwords
from .utils import random_completion_id, trim_with_stop_words
class TritonService:
@ -69,7 +69,7 @@ class TritonService:
self.tokenizer.decode(out[prompt_len : prompt_len + g])
for g, out in zip(gen_len, output_data)
]
trimmed = [trim_with_stopwords(d, preset.stop_words) for d in decoded]
trimmed = [trim_with_stop_words(d, preset.stop_words) for d in decoded]
return [Choice(index=i, text=text) for i, text in enumerate(trimmed)]
def __call__(self, data: CompletionRequest) -> CompletionResponse:

View File

@ -8,7 +8,7 @@ def random_completion_id():
)
def trim_with_stopwords(output: str, stopwords: list) -> str:
def trim_with_stop_words(output: str, stopwords: list) -> str:
for w in sorted(stopwords, key=len, reverse=True):
if output.endswith(w):
output = output[: -len(w)]