style: stopwords -> stop_words
parent
78280d44bf
commit
82103e7280
|
|
@ -11,7 +11,7 @@ from transformers import (
|
||||||
|
|
||||||
from ..models import Choice, CompletionRequest, CompletionResponse
|
from ..models import Choice, CompletionRequest, CompletionResponse
|
||||||
from .language_presets import LanguagePresets
|
from .language_presets import LanguagePresets
|
||||||
from .utils import random_completion_id, trim_with_stopwords
|
from .utils import random_completion_id, trim_with_stop_words
|
||||||
|
|
||||||
|
|
||||||
class PythonModelService:
|
class PythonModelService:
|
||||||
|
|
@ -54,7 +54,9 @@ class PythonModelService:
|
||||||
max_length=preset.max_length,
|
max_length=preset.max_length,
|
||||||
)
|
)
|
||||||
output_ids = res[0][len(input_ids[0]) :]
|
output_ids = res[0][len(input_ids[0]) :]
|
||||||
text = trim_with_stopwords(self.tokenizer.decode(output_ids), preset.stop_words)
|
text = trim_with_stop_words(
|
||||||
|
self.tokenizer.decode(output_ids), preset.stop_words
|
||||||
|
)
|
||||||
return [Choice(index=0, text=text)]
|
return [Choice(index=0, text=text)]
|
||||||
|
|
||||||
def __call__(self, request: CompletionRequest) -> CompletionResponse:
|
def __call__(self, request: CompletionRequest) -> CompletionResponse:
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ from tritonclient.utils import InferenceServerException, np_to_triton_dtype
|
||||||
|
|
||||||
from ..models import Choice, CompletionRequest, CompletionResponse
|
from ..models import Choice, CompletionRequest, CompletionResponse
|
||||||
from .language_presets import LanguagePresets
|
from .language_presets import LanguagePresets
|
||||||
from .utils import random_completion_id, trim_with_stopwords
|
from .utils import random_completion_id, trim_with_stop_words
|
||||||
|
|
||||||
|
|
||||||
class TritonService:
|
class TritonService:
|
||||||
|
|
@ -69,7 +69,7 @@ class TritonService:
|
||||||
self.tokenizer.decode(out[prompt_len : prompt_len + g])
|
self.tokenizer.decode(out[prompt_len : prompt_len + g])
|
||||||
for g, out in zip(gen_len, output_data)
|
for g, out in zip(gen_len, output_data)
|
||||||
]
|
]
|
||||||
trimmed = [trim_with_stopwords(d, preset.stop_words) for d in decoded]
|
trimmed = [trim_with_stop_words(d, preset.stop_words) for d in decoded]
|
||||||
return [Choice(index=i, text=text) for i, text in enumerate(trimmed)]
|
return [Choice(index=i, text=text) for i, text in enumerate(trimmed)]
|
||||||
|
|
||||||
def __call__(self, data: CompletionRequest) -> CompletionResponse:
|
def __call__(self, data: CompletionRequest) -> CompletionResponse:
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ def random_completion_id():
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def trim_with_stopwords(output: str, stopwords: list) -> str:
|
def trim_with_stop_words(output: str, stopwords: list) -> str:
|
||||||
for w in sorted(stopwords, key=len, reverse=True):
|
for w in sorted(stopwords, key=len, reverse=True):
|
||||||
if output.endswith(w):
|
if output.endswith(w):
|
||||||
output = output[: -len(w)]
|
output = output[: -len(w)]
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue