test: unit test for indexing job (#508)
* test: unit test for indexing job
* update
* reduce test fixture length

wsxiaoys-patch-1
parent
e0b2a775d8
commit
55f68d4224
|
|
@ -3187,6 +3187,7 @@ dependencies = [
|
||||||
"lazy_static",
|
"lazy_static",
|
||||||
"serde",
|
"serde",
|
||||||
"serde-jsonlines",
|
"serde-jsonlines",
|
||||||
|
"serde_json",
|
||||||
"tabby-common",
|
"tabby-common",
|
||||||
"tantivy",
|
"tantivy",
|
||||||
"temp_testdir",
|
"temp_testdir",
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,7 @@ use serde::{Deserialize, Serialize};
|
||||||
use serde_jsonlines::JsonLinesReader;
|
use serde_jsonlines::JsonLinesReader;
|
||||||
|
|
||||||
#[derive(Serialize, Deserialize)]
|
#[derive(Serialize, Deserialize)]
|
||||||
pub struct Document {
|
pub struct SourceFile {
|
||||||
pub git_url: String,
|
pub git_url: String,
|
||||||
pub filepath: String,
|
pub filepath: String,
|
||||||
pub content: String,
|
pub content: String,
|
||||||
|
|
@ -25,13 +25,13 @@ pub struct Document {
|
||||||
pub tags: Vec<Tag>,
|
pub tags: Vec<Tag>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Document {
|
impl SourceFile {
|
||||||
pub fn all() -> Result<impl Iterator<Item = Self>, Error> {
|
pub fn all() -> Result<impl Iterator<Item = Self>, Error> {
|
||||||
let iter = dataset_dir().read_dir()?.flat_map(|path| {
|
let iter = dataset_dir().read_dir()?.flat_map(|path| {
|
||||||
let path = path.unwrap().path();
|
let path = path.unwrap().path();
|
||||||
let fp = BufReader::new(File::open(path).unwrap());
|
let fp = BufReader::new(File::open(path).unwrap());
|
||||||
let reader = JsonLinesReader::new(fp);
|
let reader = JsonLinesReader::new(fp);
|
||||||
reader.read_all::<Document>().map(|x| x.unwrap())
|
reader.read_all::<SourceFile>().map(|x| x.unwrap())
|
||||||
});
|
});
|
||||||
Ok(iter)
|
Ok(iter)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -31,3 +31,4 @@ temp_testdir = "0.2"
|
||||||
tabby-common = { path = "../tabby-common", features = [ "testutils" ] }
|
tabby-common = { path = "../tabby-common", features = [ "testutils" ] }
|
||||||
tracing-test = "0.1"
|
tracing-test = "0.1"
|
||||||
tokio = { workspace = true, features = ["rt"] }
|
tokio = { workspace = true, features = ["rt"] }
|
||||||
|
serde_json = { workspace = true }
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,7 @@ use serde_jsonlines::WriteExt;
|
||||||
use tabby_common::{
|
use tabby_common::{
|
||||||
config::{Config, Repository},
|
config::{Config, Repository},
|
||||||
path::dataset_dir,
|
path::dataset_dir,
|
||||||
Document,
|
SourceFile,
|
||||||
};
|
};
|
||||||
use tracing::{error, info};
|
use tracing::{error, info};
|
||||||
use tree_sitter_tags::{TagsConfiguration, TagsContext};
|
use tree_sitter_tags::{TagsConfiguration, TagsContext};
|
||||||
|
|
@ -41,7 +41,7 @@ impl RepositoryExt for Repository {
|
||||||
.to_owned();
|
.to_owned();
|
||||||
if let Ok(file_content) = read_to_string(entry.path()) {
|
if let Ok(file_content) = read_to_string(entry.path()) {
|
||||||
info!("Building {:?}", relative_path);
|
info!("Building {:?}", relative_path);
|
||||||
let doc = Document {
|
let source_file = SourceFile {
|
||||||
git_url: self.git_url.clone(),
|
git_url: self.git_url.clone(),
|
||||||
filepath: relative_path.display().to_string(),
|
filepath: relative_path.display().to_string(),
|
||||||
max_line_length: metrics::max_line_length(&file_content),
|
max_line_length: metrics::max_line_length(&file_content),
|
||||||
|
|
@ -51,7 +51,7 @@ impl RepositoryExt for Repository {
|
||||||
language,
|
language,
|
||||||
content: file_content,
|
content: file_content,
|
||||||
};
|
};
|
||||||
writer.write_json_lines([doc])?;
|
writer.write_json_lines([source_file])?;
|
||||||
} else {
|
} else {
|
||||||
error!("Cannot read {:?}", relative_path);
|
error!("Cannot read {:?}", relative_path);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,7 @@
|
||||||
use std::{collections::HashMap, fs};
|
use std::fs;
|
||||||
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use lazy_static::lazy_static;
|
use tabby_common::{config::Config, path::index_dir, SourceFile};
|
||||||
use tabby_common::{config::Config, path::index_dir, Document};
|
|
||||||
use tantivy::{
|
use tantivy::{
|
||||||
directory::MmapDirectory,
|
directory::MmapDirectory,
|
||||||
doc,
|
doc,
|
||||||
|
|
@ -28,33 +27,15 @@ pub fn index_repositories(_config: &Config) -> Result<()> {
|
||||||
let mut writer = index.writer(10_000_000)?;
|
let mut writer = index.writer(10_000_000)?;
|
||||||
writer.delete_all_documents()?;
|
writer.delete_all_documents()?;
|
||||||
|
|
||||||
for doc in Document::all()? {
|
for file in SourceFile::all()? {
|
||||||
for tag in doc.tags {
|
for doc in from_source_file(file) {
|
||||||
let name = doc.content.get(tag.name_range).unwrap();
|
|
||||||
if name.len() < 5 {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
let body = doc.content.get(tag.range).unwrap();
|
|
||||||
let count_body_lines = body.lines().count();
|
|
||||||
if !(3..=10).contains(&count_body_lines) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(blacklist) = LANGUAGE_NAME_BLACKLIST.get(doc.language.as_str()) {
|
|
||||||
if blacklist.contains(&name) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
let language = reduce_language_if_needed(&doc.language);
|
|
||||||
writer.add_document(doc!(
|
writer.add_document(doc!(
|
||||||
field_git_url => doc.git_url.clone(),
|
field_git_url => doc.git_url,
|
||||||
field_filepath => doc.filepath.clone(),
|
field_filepath => doc.filepath,
|
||||||
field_language => language,
|
field_language => doc.language,
|
||||||
field_name => name,
|
field_name => doc.name,
|
||||||
field_body => body,
|
field_body => doc.body,
|
||||||
field_kind => tag.syntax_type_name,
|
field_kind => doc.kind,
|
||||||
))?;
|
))?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -64,6 +45,33 @@ pub fn index_repositories(_config: &Config) -> Result<()> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Atomic repository document in index.
///
/// One value is produced per tag of a source file (see `from_source_file`)
/// and each field is written to the index field of the same name.
#[derive(Debug, Clone, PartialEq, Eq)]
struct IndexedDocument {
    /// Git URL copied from the source file the tag belongs to.
    git_url: String,
    /// File path copied from the source file the tag belongs to.
    filepath: String,
    /// Language of the source file after `reduce_language_if_needed`.
    language: String,
    /// Text of the file content covered by the tag's `name_range`.
    name: String,
    /// Text of the file content covered by the tag's `range`.
    body: String,
    /// The tag's `syntax_type_name` (the test fixture uses "class" and "function").
    kind: String,
}
|
||||||
|
|
||||||
|
fn from_source_file(file: SourceFile) -> impl Iterator<Item = IndexedDocument> {
|
||||||
|
file.tags.into_iter().map(move |tag| {
|
||||||
|
let name = file.content.get(tag.name_range).unwrap().to_owned();
|
||||||
|
let body = file.content.get(tag.range).unwrap().to_owned();
|
||||||
|
|
||||||
|
let language = reduce_language_if_needed(&file.language).to_owned();
|
||||||
|
IndexedDocument {
|
||||||
|
git_url: file.git_url.clone(),
|
||||||
|
filepath: file.filepath.clone(),
|
||||||
|
language,
|
||||||
|
name,
|
||||||
|
body,
|
||||||
|
kind: tag.syntax_type_name,
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
fn reduce_language_if_needed(language: &str) -> &str {
|
fn reduce_language_if_needed(language: &str) -> &str {
|
||||||
if ["javascript", "jsx", "typescript", "tsx"].contains(&language) {
|
if ["javascript", "jsx", "typescript", "tsx"].contains(&language) {
|
||||||
"javascript-typescript"
|
"javascript-typescript"
|
||||||
|
|
@ -72,7 +80,69 @@ fn reduce_language_if_needed(language: &str) -> &str {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
lazy_static! {
|
#[cfg(test)]
|
||||||
static ref LANGUAGE_NAME_BLACKLIST: HashMap<&'static str, Vec<&'static str>> =
|
mod tests {
|
||||||
HashMap::from([("python", vec!["__init__"])]);
|
use serde_json::{from_value, json};
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
fn test_source_file() -> SourceFile {
|
||||||
|
from_value(json!(
|
||||||
|
{
|
||||||
|
"git_url": "https://fake.com/tabbyml.git",
|
||||||
|
"filepath": "python/tabby/trainer.py",
|
||||||
|
"content": "import os\nimport glob\nfrom dataclasses import dataclass, field\nfrom typing import List\n\nimport peft\nimport torch\nfrom transformers import (\n AutoModelForCausalLM,\n AutoTokenizer,\n HfArgumentParser,\n Trainer,\n TrainingArguments,\n)\nfrom datasets import Dataset, load_dataset\n\n\nclass ConstantLengthDataset:\n \"\"\"\n Iterable dataset that returns constant length chunks of tokens from stream of text files.\n Args:\n tokenizer (Tokenizer): The processor used for proccessing the data.\n dataset (dataset.Dataset): Dataset with text files.\n infinite (bool): If True the iterator is reset after dataset reaches end else stops.\n seq_length (int): Length of token sequences to return.\n num_of_sequences (int): Number of token sequences to keep in buffer.\n chars_per_token (int): Number of characters per token used to estimate number of tokens in text buffer.\n \"\"\"\n\n def __init__(\n self,\n tokenizer,\n dataset,\n infinite=False,\n seq_length=1024,\n num_of_sequences=1024,\n chars_per_token=3.6,\n content_field=\"content\",\n ):\n self.tokenizer = tokenizer\n self.concat_token_id = tokenizer.eos_token_id\n self.dataset = dataset\n self.seq_length = seq_length\n self.infinite = infinite\n self.current_size = 0\n self.max_buffer_size = seq_length * chars_per_token * num_of_sequences\n self.content_field = content_field\n\n def __call__(self):\n def gen():\n for x in self:\n yield x\n\n return gen()\n\n def __iter__(self):\n for buffer in self._read_dataset_into_buffer():\n yield from self._tokenize(buffer)\n\n def _tokenize(self, buffer):\n tokenized_inputs = self.tokenizer(buffer, truncation=False)[\"input_ids\"]\n\n all_token_ids = []\n for tokenized_input in tokenized_inputs:\n all_token_ids.extend(tokenized_input + [self.concat_token_id])\n\n for i in range(0, len(all_token_ids), self.seq_length):\n input_ids = all_token_ids[i : i + self.seq_length]\n\n if len(input_ids) < self.seq_length:\n input_ids = all_token_ids[-self.seq_length :]\n\n if 
len(input_ids) == self.seq_length:\n self.current_size += 1\n yield dict(input_ids=input_ids, labels=input_ids)\n\n def _read_dataset_into_buffer(self):\n iterator = iter(self.dataset)\n more_examples = True\n while more_examples:\n buffer, buffer_len = [], 0\n while True:\n if buffer_len >= self.max_buffer_size:\n break\n try:\n buffer.append(next(iterator)[self.content_field])\n buffer_len += len(buffer[-1])\n except StopIteration:\n if self.infinite:\n iterator = iter(self.dataset)\n else:\n more_examples = False\n break\n yield buffer\n\n\n",
|
||||||
|
"language": "python",
|
||||||
|
"max_line_length": 115,
|
||||||
|
"avg_line_length": 32.388393,
|
||||||
|
"alphanum_fraction": 0.6066319,
|
||||||
|
"tags": [
|
||||||
|
{
|
||||||
|
"range": {
|
||||||
|
"start": 290,
|
||||||
|
"end": 3094
|
||||||
|
},
|
||||||
|
"name_range": {
|
||||||
|
"start": 296,
|
||||||
|
"end": 317
|
||||||
|
},
|
||||||
|
"line_range": {
|
||||||
|
"start": 290,
|
||||||
|
"end": 318
|
||||||
|
},
|
||||||
|
"is_definition": true,
|
||||||
|
"syntax_type_name": "class"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"range": {
|
||||||
|
"start": 953,
|
||||||
|
"end": 1507
|
||||||
|
},
|
||||||
|
"name_range": {
|
||||||
|
"start": 957,
|
||||||
|
"end": 965
|
||||||
|
},
|
||||||
|
"line_range": {
|
||||||
|
"start": 953,
|
||||||
|
"end": 966
|
||||||
|
},
|
||||||
|
"is_definition": true,
|
||||||
|
"syntax_type_name": "function"
|
||||||
|
},
|
||||||
|
]
|
||||||
|
})).unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
fn it_create_documents() {
    // The fixture carries two tags: a class followed by a function.
    let documents = from_source_file(test_source_file()).collect::<Vec<_>>();
    assert_eq!(documents.len(), 2);

    let class_doc = &documents[0];
    assert_eq!(class_doc.name, "ConstantLengthDataset");
    assert_eq!(class_doc.kind, "class");

    let function_doc = &documents[1];
    assert_eq!(function_doc.name, "__init__");
    assert_eq!(function_doc.kind, "function");
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue