Add basic filters
parent
e24155fb3a
commit
d32f933aa9
|
|
@ -2,8 +2,24 @@ from dataclasses import dataclass, field
|
|||
from typing import Optional
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class FilterArgs:
|
||||
line_max: Optional[int] = field(
|
||||
default=1000,
|
||||
metadata={"help": "Max line length allowed"},
|
||||
)
|
||||
line_mean: Optional[int] = field(
|
||||
default=100,
|
||||
metadata={"help": "Mean line length allowed"},
|
||||
)
|
||||
alpha_frac: Optional[float] = field(
|
||||
default=0.25,
|
||||
metadata={"help": "Minimum fraction of alphanumeric characters allowed."},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PreprocessProjectArgs:
|
||||
class PreprocessProjectArgs(FilterArgs):
|
||||
# add arguments in the following format
|
||||
project_dir: Optional[str] = field(
|
||||
metadata={"help": "Project directory."},
|
||||
|
|
|
|||
|
|
@ -0,0 +1,15 @@
|
|||
from args import FilterArgs
|
||||
|
||||
|
||||
def basic_filters(args: FilterArgs):
|
||||
def fn(example):
|
||||
"""Filter files based on line length and % alphanumeric characters"""
|
||||
if example["max_line_length"] > args.line_max:
|
||||
return False
|
||||
elif example["avg_line_length"] > args.line_mean:
|
||||
return False
|
||||
elif example["alphanum_fraction"] < args.alpha_frac:
|
||||
return False
|
||||
return True
|
||||
|
||||
return fn
|
||||
|
|
@ -2,6 +2,7 @@ import glob
|
|||
import json
|
||||
import os
|
||||
|
||||
import filters
|
||||
import metrics
|
||||
from args import PreprocessProjectArgs
|
||||
from datasets import Dataset
|
||||
|
|
@ -74,4 +75,5 @@ if __name__ == "__main__":
|
|||
)
|
||||
|
||||
ds = Dataset.from_generator(dataset_iter(files))
|
||||
ds = ds.filter(filters.basic_filters(args))
|
||||
ds.save_to_disk(args.output_dir)
|
||||
|
|
|
|||
Loading…
Reference in New Issue