feat: improve logging for build dataset jobs (#41)

* feat: move preprocess to build_dataset

* Improve logging for jobs in update_dataset

* improve logging
add-more-languages
Meng Zhang 2023-04-04 14:15:51 +08:00 committed by GitHub
parent 79585cc2a4
commit 82bcd9b1df
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 35 additions and 51 deletions

View File

@ -7,12 +7,12 @@ env:
- GIT_REPOSITORIES_DIR: "$GIT_REPOSITORIES_DIR" - GIT_REPOSITORIES_DIR: "$GIT_REPOSITORIES_DIR"
- DATASET_DIR: "$DATASET_DIR" - DATASET_DIR: "$DATASET_DIR"
steps: steps:
- name: Update repositories - name: update repositories
dir: $APP_DIR dir: $APP_DIR
command: python -m tabby.tools.repository.updater --data_dir=$GIT_REPOSITORIES_DIR --config_file=$CONFIG_FILE command: python -m tabby.tools.repository.updater --data_dir=$GIT_REPOSITORIES_DIR --config_file=$CONFIG_FILE
- name: Generate dataset - name: generate dataset
dir: $APP_DIR dir: $APP_DIR
command: python -m tabby.tools.preprocess.build_dataset --project_dir=$GIT_REPOSITORIES_DIR --output_dir=$DATASET_DIR command: python -m tabby.tools.build_dataset --project_dir=$GIT_REPOSITORIES_DIR --output_dir=$DATASET_DIR
depends: depends:
- Update repositories - update repositories

View File

@ -2,10 +2,11 @@ import glob
import json import json
import os import os
import pandas as pd
from datasets import Dataset from datasets import Dataset
from transformers import HfArgumentParser from transformers import HfArgumentParser
from . import filters, metrics from . import metrics
from .args import PreprocessProjectArgs from .args import PreprocessProjectArgs
@ -61,6 +62,17 @@ def dataset_iter(files):
return gen return gen
def count_by_language(dataset, key="language"):
    """Tally the number of examples per distinct value of ``key``.

    Args:
        dataset: mapping-like collection (e.g. a ``datasets.Dataset``) where
            ``dataset[key]`` yields one label per example.
        key: column to group by. Defaults to ``"language"`` for backward
            compatibility; parameterized so other categorical columns can be
            summarized the same way.

    Returns:
        A ``pd.DataFrame`` indexed by the distinct values of ``key`` with a
        single ``"count"`` column (used below for the post-build summary log).
    """
    df = (
        pd.DataFrame(dataset[key], columns=[key])
        .groupby([key])
        .size()
        .to_frame("count")
    )
    return df
if __name__ == "__main__": if __name__ == "__main__":
valid_extensions = read_valid_extensions() valid_extensions = read_valid_extensions()
@ -80,5 +92,8 @@ if __name__ == "__main__":
) )
ds = Dataset.from_generator(dataset_iter(files)) ds = Dataset.from_generator(dataset_iter(files))
ds = ds.filter(filters.basic_filters(args))
ds.save_to_disk(args.output_dir) ds.save_to_disk(args.output_dir)
print("\n## Summary")
print("Number of source files", len(ds))
print("Number of source files by languages", count_by_language(ds).to_json())

View File

@ -0,0 +1,14 @@
from dataclasses import dataclass, field
from typing import Optional
@dataclass
class PreprocessProjectArgs:
    """CLI arguments for the dataset build entry point.

    Both fields have no default, so they are required on the command line.
    NOTE(review): presumably parsed via ``HfArgumentParser`` (imported by
    build_dataset) — confirm against the caller.
    """

    # Root directory containing the cloned repositories to scan.
    project_dir: Optional[str] = field(
        metadata={"help": "Project directory."},
    )
    # Directory where the generated dataset is written (save_to_disk target).
    output_dir: Optional[str] = field(
        metadata={"help": "Output save path directory."},
    )

View File

@ -1,30 +0,0 @@
from dataclasses import dataclass, field
from typing import Optional
@dataclass(kw_only=True)
class FilterArgs:
    """Thresholds consumed by ``filters.basic_filters`` to drop low-quality files.

    ``kw_only=True`` (Python 3.10+) makes these defaulted fields keyword-only,
    which lets a subclass declare required (default-less) fields after them
    without triggering the dataclass field-ordering error.
    """

    # Reject files whose longest line exceeds this many characters.
    line_max: Optional[int] = field(
        default=1000,
        metadata={"help": "Max line length allowed"},
    )
    # Reject files whose average line length exceeds this value.
    line_mean: Optional[int] = field(
        default=100,
        metadata={"help": "Mean line length allowed"},
    )
    # Reject files whose alphanumeric-character fraction falls below this.
    alpha_frac: Optional[float] = field(
        default=0.25,
        metadata={"help": "Minimum fraction of alphanumeric characters allowed."},
    )
@dataclass
class PreprocessProjectArgs(FilterArgs):
    """CLI arguments for preprocessing: required paths plus the inherited
    filter thresholds from ``FilterArgs``.

    Declaring required fields after the parent's defaulted ones is only legal
    because ``FilterArgs`` uses ``kw_only=True``.
    """

    # Root directory containing the project source files to preprocess.
    project_dir: Optional[str] = field(
        metadata={"help": "Project directory."},
    )
    # Directory where the generated dataset is written.
    output_dir: Optional[str] = field(
        metadata={"help": "Output save path directory."},
    )

View File

@ -1,15 +0,0 @@
from .args import FilterArgs
def basic_filters(args: FilterArgs):
    """Build a predicate for screening out low-quality source files.

    The returned callable takes one example dict (with precomputed
    ``max_line_length``, ``avg_line_length`` and ``alphanum_fraction``
    fields) and returns True to keep the example, False to drop it —
    suitable for ``Dataset.filter``.
    """

    def fn(example):
        """Filter files based on line length and % alphanumeric characters"""
        within_limits = (
            example["max_line_length"] <= args.line_max
            and example["avg_line_length"] <= args.line_mean
            and example["alphanum_fraction"] >= args.alpha_frac
        )
        return within_limits

    return fn