feat: improve logging for build dataset jobs (#41)

* feat: move preprocess to build_dataset

* Improve logging for jobs in update_dataset

* improve logging
add-more-languages
Meng Zhang 2023-04-04 14:15:51 +08:00 committed by GitHub
parent 79585cc2a4
commit 82bcd9b1df
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 35 additions and 51 deletions

View File

@ -7,12 +7,12 @@ env:
- GIT_REPOSITORIES_DIR: "$GIT_REPOSITORIES_DIR"
- DATASET_DIR: "$DATASET_DIR"
steps:
- name: Update repositories
- name: update repositories
dir: $APP_DIR
command: python -m tabby.tools.repository.updater --data_dir=$GIT_REPOSITORIES_DIR --config_file=$CONFIG_FILE
- name: Generate dataset
- name: generate dataset
dir: $APP_DIR
command: python -m tabby.tools.preprocess.build_dataset --project_dir=$GIT_REPOSITORIES_DIR --output_dir=$DATASET_DIR
command: python -m tabby.tools.build_dataset --project_dir=$GIT_REPOSITORIES_DIR --output_dir=$DATASET_DIR
depends:
- Update repositories
- update repositories

View File

@ -2,10 +2,11 @@ import glob
import json
import os
import pandas as pd
from datasets import Dataset
from transformers import HfArgumentParser
from . import filters, metrics
from . import metrics
from .args import PreprocessProjectArgs
@ -61,6 +62,17 @@ def dataset_iter(files):
return gen
def count_by_language(dataset):
    """Tally the number of source files per language.

    Builds a single-column DataFrame from the dataset's "language" column
    and aggregates it into a frame with one "count" column, indexed by
    language (groupby sorts the index alphabetically).
    """
    column = "language"
    per_language = pd.DataFrame(dataset[column], columns=[column])
    counts = per_language.groupby([column]).size()
    return counts.to_frame("count")
if __name__ == "__main__":
valid_extensions = read_valid_extensions()
@ -80,5 +92,8 @@ if __name__ == "__main__":
)
ds = Dataset.from_generator(dataset_iter(files))
ds = ds.filter(filters.basic_filters(args))
ds.save_to_disk(args.output_dir)
print("\n## Summary")
print("Number of source files", len(ds))
print("Number of source files by languages", count_by_language(ds).to_json())

View File

@ -0,0 +1,14 @@
from dataclasses import dataclass, field
from typing import Optional
@dataclass
class PreprocessProjectArgs:
    """Arguments for the build_dataset preprocessing step.

    NOTE(review): presumably parsed from the command line via
    transformers.HfArgumentParser (imported in build_dataset) — fields
    without defaults would then be required flags; confirm against caller.
    """

    # add arguments in the following format
    # Root directory containing the project repositories to scan.
    project_dir: Optional[str] = field(
        metadata={"help": "Project directory."},
    )
    # Directory the generated dataset is written to.
    output_dir: Optional[str] = field(
        metadata={"help": "Output save path directory."},
    )

View File

@ -1,30 +0,0 @@
from dataclasses import dataclass, field
from typing import Optional
@dataclass(kw_only=True)
class FilterArgs:
    """Thresholds used by basic_filters to drop low-quality source files.

    kw_only=True (Python 3.10+) makes every field keyword-only, so a
    subclass can add non-defaulted fields after these defaulted ones
    without a field-ordering error.
    """

    # Files whose longest line exceeds this are filtered out.
    line_max: Optional[int] = field(
        default=1000,
        metadata={"help": "Max line length allowed"},
    )
    # Files whose average line length exceeds this are filtered out.
    line_mean: Optional[int] = field(
        default=100,
        metadata={"help": "Mean line length allowed"},
    )
    # Files whose alphanumeric-character fraction falls below this are filtered out.
    alpha_frac: Optional[float] = field(
        default=0.25,
        metadata={"help": "Minimum fraction of alphanumeric characters allowed."},
    )
@dataclass
class PreprocessProjectArgs(FilterArgs):
    """Arguments for the preprocess step: filter thresholds plus I/O paths.

    Inherits the defaulted filter thresholds from FilterArgs; the two path
    fields below have no defaults.
    """

    # add arguments in the following format
    # Root directory containing the project repositories to scan.
    project_dir: Optional[str] = field(
        metadata={"help": "Project directory."},
    )
    # Directory the generated dataset is written to.
    output_dir: Optional[str] = field(
        metadata={"help": "Output save path directory."},
    )

View File

@ -1,15 +0,0 @@
from .args import FilterArgs
def basic_filters(args: FilterArgs):
    """Build a predicate that drops files failing basic quality checks.

    The returned function accepts an example with "max_line_length",
    "avg_line_length", and "alphanum_fraction" keys and returns True only
    when all three fall within the thresholds carried by *args*.
    """

    def fn(example):
        """Filter files based on line length and % alphanumeric characters"""
        passes = (
            example["max_line_length"] <= args.line_max
            and example["avg_line_length"] <= args.line_mean
            and example["alphanum_fraction"] >= args.alpha_frac
        )
        return passes

    return fn