add-dagster-data-pipeline
Meng Zhang 2023-10-17 16:54:15 -07:00
parent 1b52a83dcc
commit 4d1d8965e0
2 changed files with 12 additions and 29 deletions

View File

@ -4,9 +4,10 @@ setup(
name="tabby",
packages=find_packages(exclude=["tabby_tests"]),
install_requires=[
"datasets",
"dagster",
"dagster-cloud",
"dagster-pandas"
"dagster-pandas",
],
extras_require={"dev": ["dagster-webserver", "pytest"]},
)

View File

@ -24,7 +24,7 @@ DatasetDataFrame = create_dagster_pandas_dataframe_type(
@asset(dagster_type=DatasetDataFrame)
def dataset():
"""Get source code information from TABBY_ROOT"""
"""Read source code dataset from TABBY_ROOT"""
ds = []
for path in glob.glob(constants.TABBY_DATASET_FILEPATTERN):
@ -50,34 +50,16 @@ def dataset():
}
return Output(df, metadata=metadata)
EventDataFrame = create_dagster_pandas_dataframe_type(
name="EventDataFrame",
columns=[
PandasColumn.integer_column("ts"),
PandasColumn.exists("event"),
],
)
@asset
def train_dataset(dataset):
"""Filter source code dataset for training / evaluation"""
from datasets import Dataset
@asset(dagster_type=EventDataFrame)
def events():
"""Get events information from TABBY_ROOT"""
ds = []
for path in glob.glob(constants.TABBY_EVENTS_FILEPATTERN):
with open(path, "r") as f:
for line in f.readlines():
ds.append(json.loads(line))
df = pd.DataFrame(ds)
df = dataset
df = df[df["max_line_length"] < 300]
df = df[df["avg_line_length"] < 150]
metadata = {
"num_records": len(df),
"preview": MetadataValue.md(
df.head()[
[
"ts",
"event"
]
].to_markdown()
),
"num_filtered_records": len(dataset) - len(df)
}
return Output(df, metadata=metadata)
return Output(Dataset.from_pandas(df), metadata=metadata)