style: stopwords -> stop_words
parent
78280d44bf
commit
82103e7280
|
|
@ -11,7 +11,7 @@ from transformers import (
|
|||
|
||||
from ..models import Choice, CompletionRequest, CompletionResponse
|
||||
from .language_presets import LanguagePresets
|
||||
from .utils import random_completion_id, trim_with_stopwords
|
||||
from .utils import random_completion_id, trim_with_stop_words
|
||||
|
||||
|
||||
class PythonModelService:
|
||||
|
|
@ -54,7 +54,9 @@ class PythonModelService:
|
|||
max_length=preset.max_length,
|
||||
)
|
||||
output_ids = res[0][len(input_ids[0]) :]
|
||||
text = trim_with_stopwords(self.tokenizer.decode(output_ids), preset.stop_words)
|
||||
text = trim_with_stop_words(
|
||||
self.tokenizer.decode(output_ids), preset.stop_words
|
||||
)
|
||||
return [Choice(index=0, text=text)]
|
||||
|
||||
def __call__(self, request: CompletionRequest) -> CompletionResponse:
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from tritonclient.utils import InferenceServerException, np_to_triton_dtype
|
|||
|
||||
from ..models import Choice, CompletionRequest, CompletionResponse
|
||||
from .language_presets import LanguagePresets
|
||||
from .utils import random_completion_id, trim_with_stopwords
|
||||
from .utils import random_completion_id, trim_with_stop_words
|
||||
|
||||
|
||||
class TritonService:
|
||||
|
|
@ -69,7 +69,7 @@ class TritonService:
|
|||
self.tokenizer.decode(out[prompt_len : prompt_len + g])
|
||||
for g, out in zip(gen_len, output_data)
|
||||
]
|
||||
trimmed = [trim_with_stopwords(d, preset.stop_words) for d in decoded]
|
||||
trimmed = [trim_with_stop_words(d, preset.stop_words) for d in decoded]
|
||||
return [Choice(index=i, text=text) for i, text in enumerate(trimmed)]
|
||||
|
||||
def __call__(self, data: CompletionRequest) -> CompletionResponse:
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ def random_completion_id():
|
|||
)
|
||||
|
||||
|
||||
def trim_with_stopwords(output: str, stopwords: list) -> str:
|
||||
def trim_with_stop_words(output: str, stopwords: list) -> str:
|
||||
for w in sorted(stopwords, key=len, reverse=True):
|
||||
if output.endswith(w):
|
||||
output = output[: -len(w)]
|
||||
|
|
|
|||
Loading…
Reference in New Issue