diff --git a/python/setup.py b/python/setup.py index 23b4a35..ae31868 100644 --- a/python/setup.py +++ b/python/setup.py @@ -4,9 +4,10 @@ setup( name="tabby", packages=find_packages(exclude=["tabby_tests"]), install_requires=[ + "datasets", "dagster", "dagster-cloud", - "dagster-pandas" + "dagster-pandas", ], extras_require={"dev": ["dagster-webserver", "pytest"]}, ) diff --git a/python/tabby/assets.py b/python/tabby/assets.py index 05a2a6c..aa0e18f 100644 --- a/python/tabby/assets.py +++ b/python/tabby/assets.py @@ -24,7 +24,7 @@ DatasetDataFrame = create_dagster_pandas_dataframe_type( @asset(dagster_type=DatasetDataFrame) def dataset(): - """Get source code information from TABBY_ROOT""" + """Read source code dataset from TABBY_ROOT""" ds = [] for path in glob.glob(constants.TABBY_DATASET_FILEPATTERN): @@ -50,34 +50,16 @@ def dataset(): } return Output(df, metadata=metadata) -EventDataFrame = create_dagster_pandas_dataframe_type( - name="EventDataFrame", - columns=[ - PandasColumn.integer_column("ts"), - PandasColumn.exists("event"), - ], -) +@asset +def train_dataset(dataset): + """Filter source code dataset for training / evaluation""" + from datasets import Dataset -@asset(dagster_type=EventDataFrame) -def events(): - """Get events information from TABBY_ROOT""" - - ds = [] - for path in glob.glob(constants.TABBY_EVENTS_FILEPATTERN): - with open(path, "r") as f: - for line in f.readlines(): - ds.append(json.loads(line)) - - df = pd.DataFrame(ds) + df = dataset + df = df[df["max_line_length"] < 300] + df = df[df["avg_line_length"] < 150] metadata = { "num_records": len(df), - "preview": MetadataValue.md( - df.head()[ - [ - "ts", - "event" - ] - ].to_markdown() - ), + "num_filtered_records": len(dataset) - len(df) } - return Output(df, metadata=metadata) + return Output(Dataset.from_pandas(df), metadata=metadata) \ No newline at end of file