# Source: tabby/crates/ctranslate2-bindings/ctranslate2/python/tests/test_fairseq.py
# Listing metadata: 99 lines, 3.1 KiB, Python

import os
import pytest
import test_utils
import ctranslate2
@test_utils.skip_on_windows
def test_fairseq_model_conversion(tmp_dir):
    """Convert the Fairseq transliteration model, then translate and score with it.

    The converted model is written under ``tmp_dir`` and loaded into a
    ``ctranslate2.Translator``. The test asserts the translation hypothesis
    and the per-token scoring output for a single example.
    """
    fairseq_dir = os.path.join(
        test_utils.get_data_dir(), "models", "transliteration-aren-all", "fairseq"
    )
    checkpoint_path = os.path.join(fairseq_dir, "model.pt")
    model_dir = str(tmp_dir.join("ctranslate2_model"))

    # The second converter argument is the same directory that holds the checkpoint.
    ctranslate2.converters.FairseqConverter(checkpoint_path, fairseq_dir).convert(
        model_dir
    )
    translator = ctranslate2.Translator(model_dir)

    source = ["آ", "ت", "ز", "م", "و", "ن"]
    reference = ["a", "t", "z", "m", "o", "n"]

    translation = translator.translate_batch([source])
    assert translation[0].hypotheses[0] == reference

    scoring = translator.score_batch([source], [reference])

    # One log probability per scored token, including the trailing "</s>".
    expected_log_probs = [
        -0.036438941955566406,
        -0.20800018310546875,
        -0.06881046295166016,
        -0.5400338172912598,
        -0.12033462524414062,
        -0.1102609634399414,
        -0.044970035552978516,
    ]
    assert scoring[0].tokens == reference + ["</s>"]
    assert scoring[0].log_probs == pytest.approx(expected_log_probs, rel=1e-5)
@test_utils.skip_on_windows
def test_fairseq_user_start_token(tmp_dir):
    """Exercise a model converted with ``no_default_special_tokens=True``.

    In this test the caller appends ``</s>`` to the source and passes ``</s>``
    as the first target token. The test asserts the error raised when no
    target prefix is given, then the translation and scoring outputs when the
    tokens are supplied explicitly.
    """
    fairseq_dir = os.path.join(
        test_utils.get_data_dir(), "models", "transliteration-aren-all", "fairseq"
    )
    checkpoint_path = os.path.join(fairseq_dir, "model.pt")
    model_dir = str(tmp_dir.join("ctranslate2_model"))

    converter = ctranslate2.converters.FairseqConverter(
        checkpoint_path,
        fairseq_dir,
        no_default_special_tokens=True,
    )
    converter.convert(model_dir)
    translator = ctranslate2.Translator(model_dir)

    # The source sequence carries its own end token.
    tokens = ["آ", "ت", "ز", "م", "و", "ن", "</s>"]
    reference = ["a", "t", "z", "m", "o", "n"]

    # Translating without a target prefix must fail with a "start token" error.
    with pytest.raises(ValueError, match="start token"):
        translator.translate_batch([tokens])

    prefixed = translator.translate_batch([tokens], target_prefix=[["</s>"]])
    assert prefixed[0].hypotheses[0] == reference

    # A source made only of "</s>" yields an empty hypothesis.
    empty_source = translator.translate_batch([["</s>"]], target_prefix=[["</s>"]])
    assert empty_source[0].hypotheses[0] == []

    # A prefix token named "oovtoken" comes back as "<unk>" in the hypothesis.
    oov_prefix = translator.translate_batch(
        [["</s>"]],
        target_prefix=[["</s>", "oovtoken"]],
        return_scores=True,
        return_attention=True,
    )
    assert oov_prefix[0].hypotheses[0] == ["<unk>"]
    assert oov_prefix[0].scores[0] == 0
    assert oov_prefix[0].attention[0] == [[0]]

    scoring = translator.score_batch([tokens], [["</s>", *reference]])
    expected_log_probs = [
        -0.036438941955566406,
        -0.20800018310546875,
        -0.06881046295166016,
        -0.5400338172912598,
        -0.12033462524414062,
        -0.1102609634399414,
        -0.044970035552978516,
    ]
    assert scoring[0].tokens == [*reference, "</s>"]
    assert scoring[0].log_probs == pytest.approx(expected_log_probs, rel=1e-5)

    # In this mode, an empty target is not enough tokens to run the score so the output is empty.
    empty_target = translator.score_batch([tokens], [[]])
    assert not empty_target[0].tokens
    assert not empty_target[0].log_probs