tabby/crates/tabby-scheduler/src/index.rs

185 lines
8.5 KiB
Rust
Raw Normal View History

use std::fs;
use anyhow::Result;
use tabby_common::{
config::Config,
index::{IndexExt, CODE_TOKENIZER},
path::index_dir,
SourceFile,
};
use tantivy::{
directory::MmapDirectory,
doc,
schema::{Schema, TextFieldIndexing, TextOptions, STORED, STRING},
2023-06-12 19:21:27 +00:00
Index,
};
// Magic numbers
static MAX_LINE_LENGTH_THRESHOLD: usize = 300;
static AVG_LINE_LENGTH_THRESHOLD: f32 = 150f32;
2023-06-12 19:21:27 +00:00
pub fn index_repositories(_config: &Config) -> Result<()> {
let mut builder = Schema::builder();
let code_indexing_options = TextFieldIndexing::default()
.set_tokenizer(CODE_TOKENIZER)
.set_index_option(tantivy::schema::IndexRecordOption::WithFreqsAndPositions);
let code_options = TextOptions::default()
.set_indexing_options(code_indexing_options)
.set_stored();
let field_git_url = builder.add_text_field("git_url", STRING | STORED);
let field_filepath = builder.add_text_field("filepath", STRING | STORED);
let field_language = builder.add_text_field("language", STRING | STORED);
let field_name = builder.add_text_field("name", STRING | STORED);
let field_kind = builder.add_text_field("kind", STRING | STORED);
let field_body = builder.add_text_field("body", code_options);
2023-06-12 19:21:27 +00:00
let schema = builder.build();
fs::create_dir_all(index_dir())?;
let directory = MmapDirectory::open(index_dir())?;
2023-06-12 19:21:27 +00:00
let index = Index::open_or_create(directory, schema)?;
index.register_tokenizer();
let mut writer = index.writer(10_000_000)?;
writer.delete_all_documents()?;
2023-06-12 19:21:27 +00:00
for file in SourceFile::all()? {
if file.max_line_length > MAX_LINE_LENGTH_THRESHOLD {
continue;
}
if file.avg_line_length > AVG_LINE_LENGTH_THRESHOLD {
continue;
}
for doc in from_source_file(file) {
writer.add_document(doc!(
field_git_url => doc.git_url,
field_filepath => doc.filepath,
field_language => doc.language,
field_name => doc.name,
field_body => doc.body,
field_kind => doc.kind,
))?;
}
}
writer.commit()?;
Ok(())
}
/// Atomic repository document in index.
struct IndexedDocument {
git_url: String,
filepath: String,
language: String,
name: String,
body: String,
kind: String,
}
fn from_source_file(file: SourceFile) -> impl Iterator<Item = IndexedDocument> {
file.tags.into_iter().map(move |tag| {
let name = file.content.get(tag.name_range).unwrap().to_owned();
let body = file.content.get(tag.range).unwrap().to_owned();
let language = reduce_language_if_needed(&file.language).to_owned();
IndexedDocument {
git_url: file.git_url.clone(),
filepath: file.filepath.clone(),
language,
name,
body,
kind: tag.syntax_type_name,
}
})
}
fn reduce_language_if_needed(language: &str) -> &str {
if ["javascript", "jsx", "typescript", "tsx"].contains(&language) {
"javascript-typescript"
} else {
language
}
}
#[cfg(test)]
mod tests {
use serde_json::{from_value, json};
use super::*;
fn test_source_file() -> SourceFile {
from_value(json!(
{
"git_url": "https://fake.com/tabbyml.git",
"filepath": "python/tabby/trainer.py",
"content": "import os\nimport glob\nfrom dataclasses import dataclass, field\nfrom typing import List\n\nimport peft\nimport torch\nfrom transformers import (\n AutoModelForCausalLM,\n AutoTokenizer,\n HfArgumentParser,\n Trainer,\n TrainingArguments,\n)\nfrom datasets import Dataset, load_dataset\n\n\nclass ConstantLengthDataset:\n \"\"\"\n Iterable dataset that returns constant length chunks of tokens from stream of text files.\n Args:\n tokenizer (Tokenizer): The processor used for proccessing the data.\n dataset (dataset.Dataset): Dataset with text files.\n infinite (bool): If True the iterator is reset after dataset reaches end else stops.\n seq_length (int): Length of token sequences to return.\n num_of_sequences (int): Number of token sequences to keep in buffer.\n chars_per_token (int): Number of characters per token used to estimate number of tokens in text buffer.\n \"\"\"\n\n def __init__(\n self,\n tokenizer,\n dataset,\n infinite=False,\n seq_length=1024,\n num_of_sequences=1024,\n chars_per_token=3.6,\n content_field=\"content\",\n ):\n self.tokenizer = tokenizer\n self.concat_token_id = tokenizer.eos_token_id\n self.dataset = dataset\n self.seq_length = seq_length\n self.infinite = infinite\n self.current_size = 0\n self.max_buffer_size = seq_length * chars_per_token * num_of_sequences\n self.content_field = content_field\n\n def __call__(self):\n def gen():\n for x in self:\n yield x\n\n return gen()\n\n def __iter__(self):\n for buffer in self._read_dataset_into_buffer():\n yield from self._tokenize(buffer)\n\n def _tokenize(self, buffer):\n tokenized_inputs = self.tokenizer(buffer, truncation=False)[\"input_ids\"]\n\n all_token_ids = []\n for tokenized_input in tokenized_inputs:\n all_token_ids.extend(tokenized_input + [self.concat_token_id])\n\n for i in range(0, len(all_token_ids), self.seq_length):\n input_ids = all_token_ids[i : i + self.seq_length]\n\n if len(input_ids) < self.seq_length:\n input_ids = all_token_ids[-self.seq_length :]\n\n if len(input_ids) == self.seq_length:\n self.current_size += 1\n yield dict(input_ids=input_ids, labels=input_ids)\n\n def _read_dataset_into_buffer(self):\n iterator = iter(self.dataset)\n more_examples = True\n while more_examples:\n buffer, buffer_len = [], 0\n while True:\n if buffer_len >= self.max_buffer_size:\n break\n try:\n buffer.append(next(iterator)[self.content_field])\n buffer_len += len(buffer[-1])\n except StopIteration:\n if self.infinite:\n iterator = iter(self.dataset)\n else:\n more_examples = False\n break\n yield buffer\n\n\n",
"language": "python",
"max_line_length": 115,
"avg_line_length": 32.388393,
"alphanum_fraction": 0.6066319,
"tags": [
{
"range": {
"start": 290,
"end": 3094
},
"name_range": {
"start": 296,
"end": 317
},
"line_range": {
"start": 290,
"end": 318
},
"is_definition": true,
"syntax_type_name": "class"
},
{
"range": {
"start": 953,
"end": 1507
},
"name_range": {
"start": 957,
"end": 965
},
"line_range": {
"start": 953,
"end": 966
},
"is_definition": true,
"syntax_type_name": "function"
},
]
})).unwrap()
}
#[test]
fn it_create_documents() {
let source_file: SourceFile = test_source_file();
let docs: Vec<_> = from_source_file(source_file).collect();
assert_eq!(docs.len(), 2);
assert_eq!(docs[0].name, "ConstantLengthDataset");
assert_eq!(docs[0].kind, "class");
assert!(
docs[0].body.starts_with("class ConstantLengthDataset"),
"body: {:?}",
docs[0].body
);
assert_eq!(docs[1].name, "__init__");
assert_eq!(docs[1].kind, "function");
assert!(
docs[1].body.starts_with("def __init__"),
"body: {:?}",
docs[1].body
);
}
}