tabby/crates/ctranslate2-bindings/ctranslate2/python/tests/test_opennmt_py.py

117 lines
4.0 KiB
Python
Raw Permalink Normal View History

import os
import pytest
import test_utils
import ctranslate2
@test_utils.skip_on_windows
def test_opennmt_py_model_conversion(tmp_dir):
    """Convert an OpenNMT-py transliteration checkpoint and verify one output."""
    checkpoint_path = os.path.join(
        test_utils.get_data_dir(),
        "models",
        "transliteration-aren-all",
        "opennmt_py",
        "aren_7000.pt",
    )
    ct2_dir = str(tmp_dir.join("ctranslate2_model"))
    # Run the conversion, then load the converted model for translation.
    ctranslate2.converters.OpenNMTPyConverter(checkpoint_path).convert(ct2_dir)
    translator = ctranslate2.Translator(ct2_dir)
    results = translator.translate_batch([["آ", "ت", "ز", "م", "و", "ن"]])
    assert results[0].hypotheses[0] == ["a", "t", "z", "u", "m", "o", "n"]
@test_utils.skip_on_windows
def test_opennmt_py_relative_transformer(tmp_dir):
    """Convert a Transformer with relative position encodings and check a batch."""
    checkpoint_path = os.path.join(
        test_utils.get_data_dir(),
        "models",
        "transliteration-aren-all",
        "opennmt_py",
        "aren_relative_6000.pt",
    )
    ct2_dir = str(tmp_dir.join("ctranslate2_model"))
    ctranslate2.converters.OpenNMTPyConverter(checkpoint_path).convert(ct2_dir)
    translator = ctranslate2.Translator(ct2_dir)
    # Translate two examples in a single batch.
    batch = [["آ", "ت", "ز", "م", "و", "ن"], ["آ", "ر", "ث", "ر"]]
    results = translator.translate_batch(batch)
    expected = [
        ["a", "t", "z", "o", "m", "o", "n"],
        ["a", "r", "t", "h", "e", "r"],
    ]
    for result, hypothesis in zip(results, expected):
        assert result.hypotheses[0] == hypothesis
@test_utils.skip_on_windows
@pytest.mark.parametrize(
    "filename", ["aren_features_concat_10000.pt", "aren_features_sum_10000.pt"]
)
def test_opennmt_py_source_features(tmp_dir, filename):
    """Convert models with source factors and verify feature handling.

    Checks that the converter emits one vocabulary file per source factor,
    that translation inputs without features are rejected, and that both the
    batch and the file translation APIs produce the expected hypotheses.
    """
    checkpoint_path = os.path.join(
        test_utils.get_data_dir(),
        "models",
        "transliteration-aren-all",
        "opennmt_py",
        filename,
    )
    ct2_dir = str(tmp_dir.join("ctranslate2_model"))
    ctranslate2.converters.OpenNMTPyConverter(checkpoint_path).convert(ct2_dir)

    # One vocabulary file must exist for each source factor.
    for vocab_name in ("source_1_vocabulary.txt", "source_2_vocabulary.txt"):
        assert os.path.isfile(os.path.join(ct2_dir, vocab_name))

    source = [
        ["آ", "ت", "ز", "م", "و", "ن"],
        ["آ", "ت", "ش", "ي", "س", "و", "ن"],
    ]
    source_features = [
        ["0", "1", "2", "3", "4", "5"],
        ["0", "1", "2", "3", "4", "5", "6"],
    ]
    expected_target = [
        ["a", "t", "z", "m", "o", "n"],
        ["a", "c", "h", "i", "s", "o", "n"],
    ]
    # Join each token with its factor, the input format the model expects.
    source_w_features = [
        [f"{token}{feature}" for token, feature in zip(tokens, features)]
        for tokens, features in zip(source, source_features)
    ]

    translator = ctranslate2.Translator(ct2_dir)

    # Inputs missing the factor annotations must be rejected.
    with pytest.raises(ValueError, match="features"):
        translator.translate_batch(source)
    for result, hypothesis in zip(
        translator.translate_batch(source_w_features), expected_target
    ):
        assert result.hypotheses[0] == hypothesis

    # Exercise the same behavior through the file-based API.
    input_path = str(tmp_dir.join("input.txt"))
    output_path = str(tmp_dir.join("output.txt"))
    test_utils.write_tokens(source, input_path)
    with pytest.raises(ValueError, match="features"):
        translator.translate_file(input_path, output_path)
    test_utils.write_tokens(source_w_features, input_path)
    translator.translate_file(input_path, output_path)
    with open(output_path) as output_file:
        for line, hypothesis in zip(output_file, expected_target):
            assert line.strip().split() == hypothesis
@test_utils.skip_on_windows
def test_opennmt_py_transformer_lm(tmp_dir):
    """Convert a Transformer language model and check greedy generation."""
    checkpoint_path = os.path.join(
        test_utils.get_data_dir(), "models", "pi_lm_step_5000.pt"
    )
    # This checkpoint is optional in the test data set.
    if not os.path.exists(checkpoint_path):
        pytest.skip("Checkpoint file is not available")
    ct2_dir = str(tmp_dir.join("ctranslate2_model"))
    ctranslate2.converters.OpenNMTPyConverter(checkpoint_path).convert(ct2_dir)
    generator = ctranslate2.Generator(ct2_dir)
    results = generator.generate_batch([["<s>", "3", ".", "1", "4"]], max_length=12)
    generated = "".join(results[0].sequences[0])
    assert generated == "3.1415926535"