feat: change language field to string to make simplier (#100)
parent
a9b00b1450
commit
83cecc9279
|
|
@ -43,7 +43,10 @@ class PythonModelService:
|
||||||
)
|
)
|
||||||
|
|
||||||
def generate(self, request: CompletionRequest) -> List[Choice]:
|
def generate(self, request: CompletionRequest) -> List[Choice]:
|
||||||
preset = LanguagePresets[request.language]
|
preset = LanguagePresets.get(data.language, None)
|
||||||
|
if preset is None:
|
||||||
|
return []
|
||||||
|
|
||||||
input_ids = self.tokenizer.encode(request.prompt, return_tensors="pt").to(
|
input_ids = self.tokenizer.encode(request.prompt, return_tensors="pt").to(
|
||||||
self.device
|
self.device
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -40,7 +40,9 @@ class TritonService:
|
||||||
np_type = np.uint32
|
np_type = np.uint32
|
||||||
model_name = "fastertransformer"
|
model_name = "fastertransformer"
|
||||||
|
|
||||||
preset = LanguagePresets[data.language]
|
preset = LanguagePresets.get(data.language, None)
|
||||||
|
if preset is None:
|
||||||
|
return []
|
||||||
|
|
||||||
if self.rewriter:
|
if self.rewriter:
|
||||||
prompt = self.rewriter(preset, data.prompt)
|
prompt = self.rewriter(preset, data.prompt)
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,7 @@ class Language(str, Enum):
|
||||||
|
|
||||||
|
|
||||||
class CompletionRequest(BaseModel):
|
class CompletionRequest(BaseModel):
|
||||||
language: Language = Field(
|
language: str = Field(
|
||||||
example=Language.PYTHON,
|
example=Language.PYTHON,
|
||||||
default=Language.UNKNOWN,
|
default=Language.UNKNOWN,
|
||||||
description="Language for completion request",
|
description="Language for completion request",
|
||||||
|
|
@ -48,7 +48,7 @@ class Event(BaseModel):
|
||||||
|
|
||||||
class CompletionEvent(Event):
|
class CompletionEvent(Event):
|
||||||
id: str
|
id: str
|
||||||
language: Language
|
language: str
|
||||||
prompt: str
|
prompt: str
|
||||||
created: int
|
created: int
|
||||||
choices: List[Choice]
|
choices: List[Choice]
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue