tabby/crates/ctranslate2-bindings/ctranslate2/python/tests/test_opennmt_py.py

117 lines
4.0 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import os
import pytest
import test_utils
import ctranslate2
@test_utils.skip_on_windows
def test_opennmt_py_model_conversion(tmp_dir):
model_path = os.path.join(
test_utils.get_data_dir(),
"models",
"transliteration-aren-all",
"opennmt_py",
"aren_7000.pt",
)
converter = ctranslate2.converters.OpenNMTPyConverter(model_path)
output_dir = str(tmp_dir.join("ctranslate2_model"))
converter.convert(output_dir)
translator = ctranslate2.Translator(output_dir)
output = translator.translate_batch([["آ", "ت", "ز", "م", "و", "ن"]])
assert output[0].hypotheses[0] == ["a", "t", "z", "u", "m", "o", "n"]
@test_utils.skip_on_windows
def test_opennmt_py_relative_transformer(tmp_dir):
model_path = os.path.join(
test_utils.get_data_dir(),
"models",
"transliteration-aren-all",
"opennmt_py",
"aren_relative_6000.pt",
)
converter = ctranslate2.converters.OpenNMTPyConverter(model_path)
output_dir = str(tmp_dir.join("ctranslate2_model"))
converter.convert(output_dir)
translator = ctranslate2.Translator(output_dir)
output = translator.translate_batch(
[["آ", "ت", "ز", "م", "و", "ن"], ["آ", "ر", "ث", "ر"]]
)
assert output[0].hypotheses[0] == ["a", "t", "z", "o", "m", "o", "n"]
assert output[1].hypotheses[0] == ["a", "r", "t", "h", "e", "r"]
@test_utils.skip_on_windows
@pytest.mark.parametrize(
"filename", ["aren_features_concat_10000.pt", "aren_features_sum_10000.pt"]
)
def test_opennmt_py_source_features(tmp_dir, filename):
model_path = os.path.join(
test_utils.get_data_dir(),
"models",
"transliteration-aren-all",
"opennmt_py",
filename,
)
converter = ctranslate2.converters.OpenNMTPyConverter(model_path)
output_dir = str(tmp_dir.join("ctranslate2_model"))
converter.convert(output_dir)
assert os.path.isfile(os.path.join(output_dir, "source_1_vocabulary.txt"))
assert os.path.isfile(os.path.join(output_dir, "source_2_vocabulary.txt"))
source = [
["آ", "ت", "ز", "م", "و", "ن"],
["آ", "ت", "ش", "ي", "س", "و", "ن"],
]
source_features = [
["0", "1", "2", "3", "4", "5"],
["0", "1", "2", "3", "4", "5", "6"],
]
expected_target = [
["a", "t", "z", "m", "o", "n"],
["a", "c", "h", "i", "s", "o", "n"],
]
source_w_features = []
for tokens, features in zip(source, source_features):
source_w_features.append(["%s%s" % pair for pair in zip(tokens, features)])
translator = ctranslate2.Translator(output_dir)
with pytest.raises(ValueError, match="features"):
translator.translate_batch(source)
outputs = translator.translate_batch(source_w_features)
for output, expected_hypothesis in zip(outputs, expected_target):
assert output.hypotheses[0] == expected_hypothesis
input_path = str(tmp_dir.join("input.txt"))
output_path = str(tmp_dir.join("output.txt"))
test_utils.write_tokens(source, input_path)
with pytest.raises(ValueError, match="features"):
translator.translate_file(input_path, output_path)
test_utils.write_tokens(source_w_features, input_path)
translator.translate_file(input_path, output_path)
with open(output_path) as output_file:
for line, expected_hypothesis in zip(output_file, expected_target):
assert line.strip().split() == expected_hypothesis
@test_utils.skip_on_windows
def test_opennmt_py_transformer_lm(tmp_dir):
model_path = os.path.join(test_utils.get_data_dir(), "models", "pi_lm_step_5000.pt")
if not os.path.exists(model_path):
pytest.skip("Checkpoint file is not available")
converter = ctranslate2.converters.OpenNMTPyConverter(model_path)
output_dir = str(tmp_dir.join("ctranslate2_model"))
converter.convert(output_dir)
generator = ctranslate2.Generator(output_dir)
results = generator.generate_batch([["<s>", "3", ".", "1", "4"]], max_length=12)
assert "".join(results[0].sequences[0]) == "3.1415926535"