# Patch for experimental/eval/main.py (blob d085335 -> dd15f61).
#
# Summary of what the hunks below change:
#   * run_eval: formatting only. The `items` list comprehension is split over
#     several lines and its trailing `;` is dropped, the seed literal
#     0xbadbeef is respelled 0xBADBEEF (same integer value), and a few
#     blank lines are replaced or removed.
#   * __main__ block: string literals switch from single to double quotes,
#     the long add_argument call is wrapped, and one extra blank line is
#     added before `if __name__ == "__main__":`.
#   * The only change visible to users is the new `epilog` passed to
#     argparse.ArgumentParser, which adds an example invocation to --help.
#
# NOTE(review): the untouched context lines show a bare `except:` around the
# server health check, and `valid_item(item)` is tested again inside the loop
# although `items` was already filtered with it; this looks redundant, but
# the loop body is not fully visible here - confirm before changing.
#
# NOTE(review): this patch was re-flowed from a whitespace-collapsed copy.
# Hunk line counts match the @@ headers, but indentation, trailing spaces on
# removed blank lines, and the position of the blank context line in the last
# hunk are reconstructed - TODO confirm against the target file before
# applying.
diff --git a/experimental/eval/main.py b/experimental/eval/main.py
index d085335..dd15f61 100644
--- a/experimental/eval/main.py
+++ b/experimental/eval/main.py
@@ -36,12 +36,13 @@ def run_eval(args):
     except:
         print(f"Tabby Server is not ready, please check if '{api}' is correct.")
         return
-    
-    items = [x for x in processing.items_from_filepattern(args.filepattern) if valid_item(x)];
+
+    items = [
+        x for x in processing.items_from_filepattern(args.filepattern) if valid_item(x)
+    ]
     if len(items) > args.max_records:
-        random.seed(0xbadbeef)
+        random.seed(0xBADBEEF)
         items = random.sample(items, args.max_records)
-    
 
     for item in items:
         if not valid_item(item):
@@ -56,10 +57,10 @@ def run_eval(args):
         prediction = resp.choices[0].text
 
         block_score = scorer(label, prediction)
-        
+
         label_lines = label.splitlines()
         prediction_lines = prediction.splitlines()
-        
+
         if len(label_lines) > 0 and len(prediction_lines) > 0:
             line_score = scorer(label_lines[0], prediction_lines[0])
 
@@ -71,13 +72,19 @@ def run_eval(args):
             line_score=line_score,
         )
 
+
 if __name__ == "__main__":
     logging.basicConfig(stream=sys.stderr, level=logging.INFO)
 
-    parser = argparse.ArgumentParser(description='SxS eval for tabby')
-    parser.add_argument('filepattern', type=str, help='File pattern to dataset.')
-    parser.add_argument('max_records', type=int, help='Max number of records to be evaluated.')
+    parser = argparse.ArgumentParser(
+        description="SxS eval for tabby",
+        epilog="Example usage: python main.py ./tabby/dataset/data.jsonl 5 > output.jsonl",
+    )
+    parser.add_argument("filepattern", type=str, help="File pattern to dataset.")
+    parser.add_argument(
+        "max_records", type=int, help="Max number of records to be evaluated."
+    )
     args = parser.parse_args()
     logging.info("args %s", args)
     df = pd.DataFrame(run_eval(args))
-    print(df.to_json(orient='records', lines=True))
+    print(df.to_json(orient="records", lines=True))