From df67b13639dd21a290dea522158ac99ddbef97da Mon Sep 17 00:00:00 2001
From: Meng Zhang
Date: Tue, 13 Jun 2023 12:48:01 -0700
Subject: [PATCH] feat: cleanup trainer with new data format

---
 .../{trainer/train_lora.py => trainer.py}    | 105 ++++++++++++++++--
 python/tabby/trainer/dataset.py              |  87 ---------------
 2 files changed, 96 insertions(+), 96 deletions(-)
 rename python/tabby/{trainer/train_lora.py => trainer.py} (53%)
 delete mode 100644 python/tabby/trainer/dataset.py

diff --git a/python/tabby/trainer/train_lora.py b/python/tabby/trainer.py
similarity index 53%
rename from python/tabby/trainer/train_lora.py
rename to python/tabby/trainer.py
index 923c9c2..210373b 100644
--- a/python/tabby/trainer/train_lora.py
+++ b/python/tabby/trainer.py
@@ -1,14 +1,97 @@
 import os
+import glob
 from dataclasses import dataclass, field
 from typing import List
 
 import peft
 import torch
-import torch.nn as nn
-import transformers
-from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
+from transformers import (
+    AutoModelForCausalLM,
+    AutoTokenizer,
+    HfArgumentParser,
+    Trainer,
+    TrainingArguments,
+)
+from datasets import Dataset, load_dataset
 
-from .dataset import load_dataset
+
+class ConstantLengthDataset:
+    """
+    Iterable dataset that returns constant length chunks of tokens from a stream of text files.
+    Args:
+        tokenizer (Tokenizer): The processor used for processing the data.
+        dataset (dataset.Dataset): Dataset with text files.
+        infinite (bool): If True the iterator is reset after dataset reaches end else stops.
+        seq_length (int): Length of token sequences to return.
+        num_of_sequences (int): Number of token sequences to keep in buffer.
+        chars_per_token (int): Number of characters per token used to estimate number of tokens in text buffer.
+    """
+
+    def __init__(
+        self,
+        tokenizer,
+        dataset,
+        infinite=False,
+        seq_length=1024,
+        num_of_sequences=1024,
+        chars_per_token=3.6,
+        content_field="content",
+    ):
+        self.tokenizer = tokenizer
+        self.concat_token_id = tokenizer.eos_token_id
+        self.dataset = dataset
+        self.seq_length = seq_length
+        self.infinite = infinite
+        self.current_size = 0
+        self.max_buffer_size = seq_length * chars_per_token * num_of_sequences
+        self.content_field = content_field
+
+    def __call__(self):
+        def gen():
+            for x in self:
+                yield x
+
+        return gen()
+
+    def __iter__(self):
+        for buffer in self._read_dataset_into_buffer():
+            yield from self._tokenize(buffer)
+
+    def _tokenize(self, buffer):
+        tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"]
+
+        all_token_ids = []
+        for tokenized_input in tokenized_inputs:
+            all_token_ids.extend(tokenized_input + [self.concat_token_id])
+
+        for i in range(0, len(all_token_ids), self.seq_length):
+            input_ids = all_token_ids[i : i + self.seq_length]
+
+            if len(input_ids) < self.seq_length:
+                input_ids = all_token_ids[-self.seq_length :]
+
+            if len(input_ids) == self.seq_length:
+                self.current_size += 1
+                yield dict(input_ids=input_ids, labels=input_ids)
+
+    def _read_dataset_into_buffer(self):
+        iterator = iter(self.dataset)
+        more_examples = True
+        while more_examples:
+            buffer, buffer_len = [], 0
+            while True:
+                if buffer_len >= self.max_buffer_size:
+                    break
+                try:
+                    buffer.append(next(iterator)[self.content_field])
+                    buffer_len += len(buffer[-1])
+                except StopIteration:
+                    if self.infinite:
+                        iterator = iter(self.dataset)
+                    else:
+                        more_examples = False
+                        break
+            yield buffer
 
 
 @dataclass
@@ -40,6 +123,7 @@ class TrainLoraArguments:
         ],
     )
     resume_from_checkpoint: str = None  # either training checkpoint or final adapter
+    half: bool = True
 
 
 def parse_args() -> TrainLoraArguments:
@@ -51,7 +135,7 @@ def train(args: TrainLoraArguments):
     gradient_accumulation_steps = args.batch_size // args.micro_batch_size
 
     model = AutoModelForCausalLM.from_pretrained(
-        args.base_model, torch_dtype=torch.float16
+        args.base_model, torch_dtype=torch.float16 if args.half else torch.float32
     )
 
     tokenizer = AutoTokenizer.from_pretrained(args.base_model)
@@ -66,7 +150,10 @@ def train(args: TrainLoraArguments):
     )
     model = peft.get_peft_model(model, config)
 
-    data = load_dataset(tokenizer, args.data_path, seq_length=args.cutoff_len)
+    data_files = glob.glob(os.path.join(args.data_path, "*.jsonl"))
+    print("Collected data files...", data_files)
+    dataset = load_dataset("json", data_files=data_files)["train"]
+    data = Dataset.from_generator(ConstantLengthDataset(tokenizer, dataset))
 
     resume_from_checkpoint = args.resume_from_checkpoint
     if resume_from_checkpoint:
@@ -95,17 +182,17 @@ def train(args: TrainLoraArguments):
     train_data = train_val["train"].shuffle()
     val_data = train_val["test"].shuffle()
 
-    trainer = transformers.Trainer(
+    trainer = Trainer(
         model=model,
         train_dataset=train_data,
         eval_dataset=val_data,
-        args=transformers.TrainingArguments(
+        args=TrainingArguments(
            per_device_train_batch_size=args.micro_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            warmup_steps=100,
            num_train_epochs=args.num_epochs,
            learning_rate=args.learning_rate,
-            fp16=True,
+            fp16=args.half,
            logging_steps=10,
            evaluation_strategy="steps",
            save_strategy="steps",
diff --git a/python/tabby/trainer/dataset.py b/python/tabby/trainer/dataset.py
deleted file mode 100644
index b0fc850..0000000
--- a/python/tabby/trainer/dataset.py
+++ /dev/null
@@ -1,87 +0,0 @@
-import torch
-from datasets import Dataset, load_from_disk
-
-
-class ConstantLengthDataset:
-    """
-    Iterable dataset that returns constant length chunks of tokens from stream of text files.
-    Args:
-        tokenizer (Tokenizer): The processor used for proccessing the data.
-        dataset (dataset.Dataset): Dataset with text files.
-        infinite (bool): If True the iterator is reset after dataset reaches end else stops.
-        seq_length (int): Length of token sequences to return.
-        num_of_sequences (int): Number of token sequences to keep in buffer.
-        chars_per_token (int): Number of characters per token used to estimate number of tokens in text buffer.
-    """
-
-    def __init__(
-        self,
-        tokenizer,
-        dataset,
-        infinite=False,
-        seq_length=1024,
-        num_of_sequences=1024,
-        chars_per_token=3.6,
-        content_field="content",
-    ):
-        self.tokenizer = tokenizer
-        self.concat_token_id = tokenizer.eos_token_id
-        self.dataset = dataset
-        self.seq_length = seq_length
-        self.infinite = infinite
-        self.current_size = 0
-        self.max_buffer_size = seq_length * chars_per_token * num_of_sequences
-        self.content_field = content_field
-
-    def __call__(self):
-        def gen():
-            for x in self:
-                yield x
-
-        return gen()
-
-    def __iter__(self):
-        for buffer in self._read_dataset_into_buffer():
-            yield from self._tokenize(buffer)
-
-    def _tokenize(self, buffer):
-        tokenized_inputs = self.tokenizer(buffer, truncation=False)["input_ids"]
-
-        all_token_ids = []
-        for tokenized_input in tokenized_inputs:
-            all_token_ids.extend(tokenized_input + [self.concat_token_id])
-
-        for i in range(0, len(all_token_ids), self.seq_length):
-            input_ids = all_token_ids[i : i + self.seq_length]
-
-            if len(input_ids) < self.seq_length:
-                input_ids = all_token_ids[-self.seq_length :]
-
-            if len(input_ids) == self.seq_length:
-                self.current_size += 1
-                yield dict(input_ids=input_ids, labels=input_ids)
-
-    def _read_dataset_into_buffer(self):
-        iterator = iter(self.dataset)
-        more_examples = True
-        while more_examples:
-            buffer, buffer_len = [], 0
-            while True:
-                if buffer_len >= self.max_buffer_size:
-                    break
-                try:
-                    buffer.append(next(iterator)[self.content_field])
-                    buffer_len += len(buffer[-1])
-                except StopIteration:
-                    if self.infinite:
-                        iterator = iter(self.dataset)
-                    else:
-                        more_examples = False
-                        break
-            yield buffer
-
-
-def load_dataset(tokenizer, filepath, **kwargs):
-    ds = load_from_disk(filepath)
-    ds = Dataset.from_generator(ConstantLengthDataset(tokenizer, ds, **kwargs))
-    return ds
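
Below is a minimal usage sketch of the data path this patch introduces: raw *.jsonl rows are loaded with load_dataset("json", ...), wrapped in ConstantLengthDataset so that documents are tokenized, joined with eos_token_id separators, and re-chunked into fixed seq_length windows, then materialized with Dataset.from_generator the same way train() does. The tokenizer name and data file path are illustrative placeholders, not values defined by this patch.

    # usage_sketch.py -- illustrative only; assumes the layout from this patch
    # (ConstantLengthDataset living in python/tabby/trainer.py) and a JSONL file
    # whose rows carry a "content" field. "gpt2" and ./data/sample.jsonl are
    # placeholder choices, not defaults of the trainer.
    from datasets import Dataset, load_dataset
    from transformers import AutoTokenizer

    from tabby.trainer import ConstantLengthDataset

    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    # Load raw text rows, mirroring the glob + load_dataset("json") logic in train().
    raw = load_dataset("json", data_files=["./data/sample.jsonl"])["train"]

    # Pack documents into constant-length examples; each yielded row holds
    # seq_length token ids with labels == input_ids for causal LM training.
    packed = Dataset.from_generator(
        ConstantLengthDataset(tokenizer, raw, seq_length=1024)
    )

    print(len(packed), packed[0].keys())  # number of packed chunks, dict_keys(['input_ids', 'labels'])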