feat: change language field to string to make simplier (#100)

add-tracing
Meng Zhang 2023-04-13 16:25:13 +08:00 committed by GitHub
parent a9b00b1450
commit 83cecc9279
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 9 additions and 4 deletions

View File

@ -43,7 +43,10 @@ class PythonModelService:
) )
def generate(self, request: CompletionRequest) -> List[Choice]: def generate(self, request: CompletionRequest) -> List[Choice]:
preset = LanguagePresets[request.language] preset = LanguagePresets.get(data.language, None)
if preset is None:
return []
input_ids = self.tokenizer.encode(request.prompt, return_tensors="pt").to( input_ids = self.tokenizer.encode(request.prompt, return_tensors="pt").to(
self.device self.device
) )

View File

@ -40,7 +40,9 @@ class TritonService:
np_type = np.uint32 np_type = np.uint32
model_name = "fastertransformer" model_name = "fastertransformer"
preset = LanguagePresets[data.language] preset = LanguagePresets.get(data.language, None)
if preset is None:
return []
if self.rewriter: if self.rewriter:
prompt = self.rewriter(preset, data.prompt) prompt = self.rewriter(preset, data.prompt)

View File

@ -18,7 +18,7 @@ class Language(str, Enum):
class CompletionRequest(BaseModel): class CompletionRequest(BaseModel):
language: Language = Field( language: str = Field(
example=Language.PYTHON, example=Language.PYTHON,
default=Language.UNKNOWN, default=Language.UNKNOWN,
description="Language for completion request", description="Language for completion request",
@ -48,7 +48,7 @@ class Event(BaseModel):
class CompletionEvent(Event): class CompletionEvent(Event):
id: str id: str
language: Language language: str
prompt: str prompt: str
created: int created: int
choices: List[Choice] choices: List[Choice]