diff --git a/preprocess/args.py b/preprocess/args.py index 99c4af3..f55b113 100644 --- a/preprocess/args.py +++ b/preprocess/args.py @@ -2,8 +2,24 @@ from dataclasses import dataclass, field from typing import Optional +@dataclass(kw_only=True) +class FilterArgs: + line_max: Optional[int] = field( + default=1000, + metadata={"help": "Max line length allowed"}, + ) + line_mean: Optional[int] = field( + default=100, + metadata={"help": "Mean line length allowed"}, + ) + alpha_frac: Optional[float] = field( + default=0.25, + metadata={"help": "Minimum fraction of alphanumeric characters allowed."}, + ) + + @dataclass -class PreprocessProjectArgs: +class PreprocessProjectArgs(FilterArgs): # add arguments in the following format project_dir: Optional[str] = field( metadata={"help": "Project directory."}, diff --git a/preprocess/filters.py b/preprocess/filters.py new file mode 100644 index 0000000..5506239 --- /dev/null +++ b/preprocess/filters.py @@ -0,0 +1,15 @@ +from args import FilterArgs + + +def basic_filters(args: FilterArgs): + def fn(example): + """Filter files based on line length and % alphanumeric characters""" + if example["max_line_length"] > args.line_max: + return False + elif example["avg_line_length"] > args.line_mean: + return False + elif example["alphanum_fraction"] < args.alpha_frac: + return False + return True + + return fn diff --git a/preprocess/preprocess_project.py b/preprocess/preprocess_project.py index 992462a..d9c6580 100644 --- a/preprocess/preprocess_project.py +++ b/preprocess/preprocess_project.py @@ -2,6 +2,7 @@ import glob import json import os +import filters import metrics from args import PreprocessProjectArgs from datasets import Dataset @@ -74,4 +75,5 @@ if __name__ == "__main__": ) ds = Dataset.from_generator(dataset_iter(files)) + ds = ds.filter(filters.basic_filters(args)) ds.save_to_disk(args.output_dir)