feat: change language field to string to make simplier (#100)
parent
a9b00b1450
commit
83cecc9279
|
|
@ -43,7 +43,10 @@ class PythonModelService:
|
|||
)
|
||||
|
||||
def generate(self, request: CompletionRequest) -> List[Choice]:
|
||||
preset = LanguagePresets[request.language]
|
||||
preset = LanguagePresets.get(data.language, None)
|
||||
if preset is None:
|
||||
return []
|
||||
|
||||
input_ids = self.tokenizer.encode(request.prompt, return_tensors="pt").to(
|
||||
self.device
|
||||
)
|
||||
|
|
|
|||
|
|
@ -40,7 +40,9 @@ class TritonService:
|
|||
np_type = np.uint32
|
||||
model_name = "fastertransformer"
|
||||
|
||||
preset = LanguagePresets[data.language]
|
||||
preset = LanguagePresets.get(data.language, None)
|
||||
if preset is None:
|
||||
return []
|
||||
|
||||
if self.rewriter:
|
||||
prompt = self.rewriter(preset, data.prompt)
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ class Language(str, Enum):
|
|||
|
||||
|
||||
class CompletionRequest(BaseModel):
|
||||
language: Language = Field(
|
||||
language: str = Field(
|
||||
example=Language.PYTHON,
|
||||
default=Language.UNKNOWN,
|
||||
description="Language for completion request",
|
||||
|
|
@ -48,7 +48,7 @@ class Event(BaseModel):
|
|||
|
||||
class CompletionEvent(Event):
|
||||
id: str
|
||||
language: Language
|
||||
language: str
|
||||
prompt: str
|
||||
created: int
|
||||
choices: List[Choice]
|
||||
|
|
|
|||
Loading…
Reference in New Issue