Update
parent
0d8d7097be
commit
fbcab616d7
|
|
@ -25,10 +25,12 @@ class TritonService:
|
||||||
)
|
)
|
||||||
|
|
||||||
def generate(self, data: CompletionsRequest) -> List[Choice]:
|
def generate(self, data: CompletionsRequest) -> List[Choice]:
|
||||||
|
# FIXME(meng): Make following vars configurable
|
||||||
n = 1
|
n = 1
|
||||||
np_type = np.uint32
|
np_type = np.uint32
|
||||||
max_tokens = 128
|
max_tokens = 128
|
||||||
model_name = "fastertransformer"
|
model_name = "fastertransformer"
|
||||||
|
stop_words = ["\n\n"]
|
||||||
|
|
||||||
prompt = data.prompt
|
prompt = data.prompt
|
||||||
input_start_ids = np.expand_dims(self.tokenizer.encode(prompt), 0)
|
input_start_ids = np.expand_dims(self.tokenizer.encode(prompt), 0)
|
||||||
|
|
@ -39,7 +41,6 @@ class TritonService:
|
||||||
prompt_tokens: int = input_len[0][0]
|
prompt_tokens: int = input_len[0][0]
|
||||||
output_len = np.ones_like(input_len).astype(np_type) * max_tokens
|
output_len = np.ones_like(input_len).astype(np_type) * max_tokens
|
||||||
|
|
||||||
stop_words = ["\n\n"]
|
|
||||||
stop_word_list = np.repeat(
|
stop_word_list = np.repeat(
|
||||||
to_word_list_format([stop_words], self.tokenizer),
|
to_word_list_format([stop_words], self.tokenizer),
|
||||||
input_start_ids.shape[0],
|
input_start_ids.shape[0],
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue