# tabby/preprocess/preprocess_project.py
# (80 lines, 1.9 KiB, Python)
import glob
import json
import os
import filters
import metrics
from args import PreprocessProjectArgs
from datasets import Dataset
from transformers import HfArgumentParser
def parse_args():
    """Parse command-line arguments into a PreprocessProjectArgs namespace."""
    arg_parser = HfArgumentParser(PreprocessProjectArgs)
    return arg_parser.parse_args()
def read_languages_to_file_extensions():
    """Load the language -> file-extensions table shipped next to this module."""
    here = os.path.dirname(os.path.abspath(__file__))
    json_path = os.path.join(here, "programming-languages-to-file-extensions.json")
    with open(json_path) as f:
        return json.load(f)
def read_valid_extensions():
    """Return the set of every file extension known to the mapping file.

    Returns:
        set: the union of all extension lists across all languages.
    """
    content = read_languages_to_file_extensions()
    # Only the extension lists matter here; the language keys are unused,
    # so iterate values() directly instead of items().
    return {ext for exts in content.values() for ext in exts}
def read_extension_to_language_mappings():
    """Invert the language -> extensions table into extension -> language.

    If the same extension is listed under several languages, the one
    appearing last in the source table wins (dict assignment overwrite).
    """
    content = read_languages_to_file_extensions()
    return {
        ext: language
        for language, exts in content.items()
        for ext in exts
    }
def dataset_iter(files):
    """Build a generator factory yielding one record per source file.

    Each record contains the language detected from the file extension,
    the raw file content, and the per-file metrics from metrics.compute().

    Args:
        files: iterable of file paths whose extensions are present in the
            extension -> language mapping (pre-filtered by the caller).

    Returns:
        A zero-argument generator function, suitable for
        Dataset.from_generator.
    """
    # Load the extension -> language table once here, not on every
    # invocation of the generator (from_generator may call gen repeatedly).
    mappings = read_extension_to_language_mappings()

    def gen():
        for path in files:
            _, extname = os.path.splitext(path)
            # Pin the encoding: the platform default would make the dataset
            # content machine-dependent. Source files are assumed UTF-8.
            with open(path, encoding="utf-8") as f:
                content = f.read()
            yield dict(
                language=mappings[extname],
                content=content,
                **metrics.compute(content),
            )

    return gen
if __name__ == "__main__":
valid_extensions = read_valid_extensions()
def is_valid_file(x):
if not os.path.isfile(x):
return False
_, extname = os.path.splitext(x)
if not extname in valid_extensions:
return False
return True
args = parse_args()
files = list(
filter(is_valid_file, glob.glob(args.project_dir + "/**/*", recursive=True))
)
ds = Dataset.from_generator(dataset_iter(files))
ds = ds.filter(filters.basic_filters(args))
ds.save_to_disk(args.output_dir)