From 55f68d422444db0514c72060ba1b59783689d75b Mon Sep 17 00:00:00 2001 From: Meng Zhang Date: Thu, 5 Oct 2023 07:27:19 +0800 Subject: [PATCH] test: unit test for indexing job (#508) * test: unit test for indexing job * update * reduce test fixture length --- Cargo.lock | 1 + crates/tabby-common/src/lib.rs | 6 +- crates/tabby-scheduler/Cargo.toml | 1 + crates/tabby-scheduler/src/dataset.rs | 6 +- crates/tabby-scheduler/src/index.rs | 134 ++++++++++++++++++++------ 5 files changed, 110 insertions(+), 38 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e1da931..994e53e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3187,6 +3187,7 @@ dependencies = [ "lazy_static", "serde", "serde-jsonlines", + "serde_json", "tabby-common", "tantivy", "temp_testdir", diff --git a/crates/tabby-common/src/lib.rs b/crates/tabby-common/src/lib.rs index 44483f2..b338145 100644 --- a/crates/tabby-common/src/lib.rs +++ b/crates/tabby-common/src/lib.rs @@ -14,7 +14,7 @@ use serde::{Deserialize, Serialize}; use serde_jsonlines::JsonLinesReader; #[derive(Serialize, Deserialize)] -pub struct Document { +pub struct SourceFile { pub git_url: String, pub filepath: String, pub content: String, @@ -25,13 +25,13 @@ pub struct Document { pub tags: Vec, } -impl Document { +impl SourceFile { pub fn all() -> Result, Error> { let iter = dataset_dir().read_dir()?.flat_map(|path| { let path = path.unwrap().path(); let fp = BufReader::new(File::open(path).unwrap()); let reader = JsonLinesReader::new(fp); - reader.read_all::().map(|x| x.unwrap()) + reader.read_all::().map(|x| x.unwrap()) }); Ok(iter) } diff --git a/crates/tabby-scheduler/Cargo.toml b/crates/tabby-scheduler/Cargo.toml index 4d1d12e..25988ea 100644 --- a/crates/tabby-scheduler/Cargo.toml +++ b/crates/tabby-scheduler/Cargo.toml @@ -31,3 +31,4 @@ temp_testdir = "0.2" tabby-common = { path = "../tabby-common", features = [ "testutils" ] } tracing-test = "0.1" tokio = { workspace = true, features = ["rt"] } +serde_json = { workspace = true } diff --git a/crates/tabby-scheduler/src/dataset.rs b/crates/tabby-scheduler/src/dataset.rs index e467f16..4459ca3 100644 --- a/crates/tabby-scheduler/src/dataset.rs +++ b/crates/tabby-scheduler/src/dataset.rs @@ -12,7 +12,7 @@ use serde_jsonlines::WriteExt; use tabby_common::{ config::{Config, Repository}, path::dataset_dir, - Document, + SourceFile, }; use tracing::{error, info}; use tree_sitter_tags::{TagsConfiguration, TagsContext}; @@ -41,7 +41,7 @@ impl RepositoryExt for Repository { .to_owned(); if let Ok(file_content) = read_to_string(entry.path()) { info!("Building {:?}", relative_path); - let doc = Document { + let source_file = SourceFile { git_url: self.git_url.clone(), filepath: relative_path.display().to_string(), max_line_length: metrics::max_line_length(&file_content), @@ -51,7 +51,7 @@ impl RepositoryExt for Repository { language, content: file_content, }; - writer.write_json_lines([doc])?; + writer.write_json_lines([source_file])?; } else { error!("Cannot read {:?}", relative_path); } diff --git a/crates/tabby-scheduler/src/index.rs b/crates/tabby-scheduler/src/index.rs index 418782f..8477760 100644 --- a/crates/tabby-scheduler/src/index.rs +++ b/crates/tabby-scheduler/src/index.rs @@ -1,8 +1,7 @@ -use std::{collections::HashMap, fs}; +use std::fs; use anyhow::Result; -use lazy_static::lazy_static; -use tabby_common::{config::Config, path::index_dir, Document}; +use tabby_common::{config::Config, path::index_dir, SourceFile}; use tantivy::{ directory::MmapDirectory, doc, @@ -28,33 +27,15 @@ pub fn index_repositories(_config: &Config) -> Result<()> { let mut writer = index.writer(10_000_000)?; writer.delete_all_documents()?; - for doc in Document::all()? { - for tag in doc.tags { - let name = doc.content.get(tag.name_range).unwrap(); - if name.len() < 5 { - continue; - } - - let body = doc.content.get(tag.range).unwrap(); - let count_body_lines = body.lines().count(); - if !(3..=10).contains(&count_body_lines) { - continue; - } - - if let Some(blacklist) = LANGUAGE_NAME_BLACKLIST.get(doc.language.as_str()) { - if blacklist.contains(&name) { - continue; - } - } - - let language = reduce_language_if_needed(&doc.language); + for file in SourceFile::all()? { + for doc in from_source_file(file) { writer.add_document(doc!( - field_git_url => doc.git_url.clone(), - field_filepath => doc.filepath.clone(), - field_language => language, - field_name => name, - field_body => body, - field_kind => tag.syntax_type_name, + field_git_url => doc.git_url, + field_filepath => doc.filepath, + field_language => doc.language, + field_name => doc.name, + field_body => doc.body, + field_kind => doc.kind, ))?; } } @@ -64,6 +45,33 @@ pub fn index_repositories(_config: &Config) -> Result<()> { Ok(()) } +/// Atomic repository document in index. +struct IndexedDocument { + git_url: String, + filepath: String, + language: String, + name: String, + body: String, + kind: String, +} + +fn from_source_file(file: SourceFile) -> impl Iterator { + file.tags.into_iter().map(move |tag| { + let name = file.content.get(tag.name_range).unwrap().to_owned(); + let body = file.content.get(tag.range).unwrap().to_owned(); + + let language = reduce_language_if_needed(&file.language).to_owned(); + IndexedDocument { + git_url: file.git_url.clone(), + filepath: file.filepath.clone(), + language, + name, + body, + kind: tag.syntax_type_name, + } + }) +} + fn reduce_language_if_needed(language: &str) -> &str { if ["javascript", "jsx", "typescript", "tsx"].contains(&language) { "javascript-typescript" @@ -72,7 +80,69 @@ fn reduce_language_if_needed(language: &str) -> &str { } } -lazy_static! { - static ref LANGUAGE_NAME_BLACKLIST: HashMap<&'static str, Vec<&'static str>> = - HashMap::from([("python", vec!["__init__"])]); +#[cfg(test)] +mod tests { + use serde_json::{from_value, json}; + + use super::*; + + fn test_source_file() -> SourceFile { + from_value(json!( + { + "git_url": "https://fake.com/tabbyml.git", + "filepath": "python/tabby/trainer.py", + "content": "import os\nimport glob\nfrom dataclasses import dataclass, field\nfrom typing import List\n\nimport peft\nimport torch\nfrom transformers import (\n AutoModelForCausalLM,\n AutoTokenizer,\n HfArgumentParser,\n Trainer,\n TrainingArguments,\n)\nfrom datasets import Dataset, load_dataset\n\n\nclass ConstantLengthDataset:\n \"\"\"\n Iterable dataset that returns constant length chunks of tokens from stream of text files.\n Args:\n tokenizer (Tokenizer): The processor used for proccessing the data.\n dataset (dataset.Dataset): Dataset with text files.\n infinite (bool): If True the iterator is reset after dataset reaches end else stops.\n seq_length (int): Length of token sequences to return.\n num_of_sequences (int): Number of token sequences to keep in buffer.\n chars_per_token (int): Number of characters per token used to estimate number of tokens in text buffer.\n \"\"\"\n\n def __init__(\n self,\n tokenizer,\n dataset,\n infinite=False,\n seq_length=1024,\n num_of_sequences=1024,\n chars_per_token=3.6,\n content_field=\"content\",\n ):\n self.tokenizer = tokenizer\n self.concat_token_id = tokenizer.eos_token_id\n self.dataset = dataset\n self.seq_length = seq_length\n self.infinite = infinite\n self.current_size = 0\n self.max_buffer_size = seq_length * chars_per_token * num_of_sequences\n self.content_field = content_field\n\n def __call__(self):\n def gen():\n for x in self:\n yield x\n\n return gen()\n\n def __iter__(self):\n for buffer in self._read_dataset_into_buffer():\n yield from self._tokenize(buffer)\n\n def _tokenize(self, buffer):\n tokenized_inputs = self.tokenizer(buffer, truncation=False)[\"input_ids\"]\n\n all_token_ids = []\n for tokenized_input in tokenized_inputs:\n all_token_ids.extend(tokenized_input + [self.concat_token_id])\n\n for i in range(0, len(all_token_ids), self.seq_length):\n input_ids = all_token_ids[i : i + self.seq_length]\n\n if len(input_ids) < self.seq_length:\n input_ids = all_token_ids[-self.seq_length :]\n\n if len(input_ids) == self.seq_length:\n self.current_size += 1\n yield dict(input_ids=input_ids, labels=input_ids)\n\n def _read_dataset_into_buffer(self):\n iterator = iter(self.dataset)\n more_examples = True\n while more_examples:\n buffer, buffer_len = [], 0\n while True:\n if buffer_len >= self.max_buffer_size:\n break\n try:\n buffer.append(next(iterator)[self.content_field])\n buffer_len += len(buffer[-1])\n except StopIteration:\n if self.infinite:\n iterator = iter(self.dataset)\n else:\n more_examples = False\n break\n yield buffer\n\n\n", + "language": "python", + "max_line_length": 115, + "avg_line_length": 32.388393, + "alphanum_fraction": 0.6066319, + "tags": [ + { + "range": { + "start": 290, + "end": 3094 + }, + "name_range": { + "start": 296, + "end": 317 + }, + "line_range": { + "start": 290, + "end": 318 + }, + "is_definition": true, + "syntax_type_name": "class" + }, + { + "range": { + "start": 953, + "end": 1507 + }, + "name_range": { + "start": 957, + "end": 965 + }, + "line_range": { + "start": 953, + "end": 966 + }, + "is_definition": true, + "syntax_type_name": "function" + }, + ] + })).unwrap() + } + + #[test] + fn it_create_documents() { + let source_file: SourceFile = test_source_file(); + let docs: Vec<_> = from_source_file(source_file).collect(); + assert_eq!(docs.len(), 2); + + assert_eq!(docs[0].name, "ConstantLengthDataset"); + assert_eq!(docs[0].kind, "class"); + + assert_eq!(docs[1].name, "__init__"); + assert_eq!(docs[1].kind, "function"); + } }