Add basic filters

add-more-languages
Meng Zhang 2023-03-16 18:08:55 +08:00
parent e24155fb3a
commit d32f933aa9
3 changed files with 34 additions and 1 deletions

View File

@ -2,8 +2,24 @@ from dataclasses import dataclass, field
from typing import Optional from typing import Optional
@dataclass(kw_only=True)
class FilterArgs:
line_max: Optional[int] = field(
default=1000,
metadata={"help": "Max line length allowed"},
)
line_mean: Optional[int] = field(
default=100,
metadata={"help": "Mean line length allowed"},
)
alpha_frac: Optional[float] = field(
default=0.25,
metadata={"help": "Minimum fraction of alphanumeric characters allowed."},
)
@dataclass @dataclass
class PreprocessProjectArgs: class PreprocessProjectArgs(FilterArgs):
# add arguments in the following format # add arguments in the following format
project_dir: Optional[str] = field( project_dir: Optional[str] = field(
metadata={"help": "Project directory."}, metadata={"help": "Project directory."},

15
preprocess/filters.py Normal file
View File

@ -0,0 +1,15 @@
from args import FilterArgs
def basic_filters(args: FilterArgs):
def fn(example):
"""Filter files based on line length and % alphanumeric characters"""
if example["max_line_length"] > args.line_max:
return False
elif example["avg_line_length"] > args.line_mean:
return False
elif example["alphanum_fraction"] < args.alpha_frac:
return False
return True
return fn

View File

@ -2,6 +2,7 @@ import glob
import json import json
import os import os
import filters
import metrics import metrics
from args import PreprocessProjectArgs from args import PreprocessProjectArgs
from datasets import Dataset from datasets import Dataset
@ -74,4 +75,5 @@ if __name__ == "__main__":
) )
ds = Dataset.from_generator(dataset_iter(files)) ds = Dataset.from_generator(dataset_iter(files))
ds = ds.filter(filters.basic_filters(args))
ds.save_to_disk(args.output_dir) ds.save_to_disk(args.output_dir)