Add LoRA Fine-tuning for private code repository (#22)
* Add bitsandbytes * Fix cudart in Dockerfile * Add ConstantLengthDataset in trainer * Add train_lora * Remove bnb * Remove useless imports
parent
e992a0144b
commit
92eb2d54f5
|
|
@ -11,6 +11,10 @@ repos:
|
|||
rev: 5.12.0
|
||||
hooks:
|
||||
- id: isort
|
||||
- repo: https://github.com/PyCQA/autoflake
|
||||
rev: v2.0.2
|
||||
hooks:
|
||||
- id: autoflake
|
||||
- repo: https://github.com/psf/black
|
||||
rev: 22.10.0
|
||||
hooks:
|
||||
|
|
|
|||
|
|
@ -28,4 +28,8 @@ RUN poetry export --without-hashes -o requirements.txt
|
|||
|
||||
RUN --mount=type=cache,target=/root/.cache pip install -i $PYPI_INDEX_URL --extra-index-url https://pypi.org/simple --no-dependencies -r requirements.txt
|
||||
|
||||
## FIX bitsandbytes
|
||||
ENV LD_LIBRARY_PATH "$LD_LIBRARY_PATH:/opt/conda/lib"
|
||||
RUN ln -s /opt/conda/lib/libcudart.so.11.7.99 /opt/conda/lib/libcudart.so
|
||||
|
||||
COPY tabby ./tabby
|
||||
|
|
|
|||
|
|
@ -1,4 +1,10 @@
|
|||
UP_FLAGS := up --remove-orphans --remove-orphans --build
|
||||
UP_FLAGS := up --remove-orphans
|
||||
|
||||
build:
|
||||
docker-compose -f docker-compose.yml -f docker-compose.dev.yml build
|
||||
|
||||
up:
|
||||
docker-compose -f docker-compose.yml $(UP_FLAGS)
|
||||
|
||||
dev:
|
||||
docker-compose -f docker-compose.yml -f docker-compose.dev.yml $(UP_FLAGS)
|
||||
|
|
@ -7,4 +13,4 @@ dev-triton:
|
|||
docker-compose -f docker-compose.yml -f docker-compose.triton.yml -f docker-compose.dev.yml $(UP_FLAGS)
|
||||
|
||||
clean:
|
||||
docker-compose -f docker-compose.yml -f docker-compose.triton.yml -f docker-compose.dev.yml down
|
||||
docker-compose -f docker-compose.yml -f docker-compose.triton.yml -f docker-compose.dev.yml down --remove-orphans
|
||||
|
|
|
|||
|
|
@ -1,5 +1,34 @@
|
|||
# This file is automatically @generated by Poetry 1.4.0 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "accelerate"
|
||||
version = "0.18.0"
|
||||
description = "Accelerate"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.7.0"
|
||||
files = [
|
||||
{file = "accelerate-0.18.0-py3-none-any.whl", hash = "sha256:41a84ac94407d7dcf030caf0cdadc70496594aec27ea680207bdb15b95f8a602"},
|
||||
{file = "accelerate-0.18.0.tar.gz", hash = "sha256:1dd36fd972de4a6d0cffe5e4d6d30622fd853765f773b5582cf0796deefe1016"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
numpy = ">=1.17"
|
||||
packaging = ">=20.0"
|
||||
psutil = "*"
|
||||
pyyaml = "*"
|
||||
torch = ">=1.4.0"
|
||||
|
||||
[package.extras]
|
||||
dev = ["black (>=23.1,<24.0)", "datasets", "deepspeed", "evaluate", "hf-doc-builder (>=0.3.0)", "parameterized", "pytest", "pytest-subtests", "pytest-xdist", "rich", "ruff (>=0.0.241)", "scikit-learn", "scipy", "tqdm", "transformers"]
|
||||
quality = ["black (>=23.1,<24.0)", "hf-doc-builder (>=0.3.0)", "ruff (>=0.0.241)"]
|
||||
rich = ["rich"]
|
||||
sagemaker = ["sagemaker"]
|
||||
test-dev = ["datasets", "deepspeed", "evaluate", "scikit-learn", "scipy", "tqdm", "transformers"]
|
||||
test-prod = ["parameterized", "pytest", "pytest-subtests", "pytest-xdist"]
|
||||
test-trackers = ["comet-ml", "tensorboard", "wandb"]
|
||||
testing = ["datasets", "deepspeed", "evaluate", "parameterized", "pytest", "pytest-subtests", "pytest-xdist", "scikit-learn", "scipy", "tqdm", "transformers"]
|
||||
|
||||
[[package]]
|
||||
name = "aiohttp"
|
||||
version = "3.8.4"
|
||||
|
|
@ -520,7 +549,7 @@ colorama = {version = "*", markers = "platform_system == \"Windows\""}
|
|||
name = "cmake"
|
||||
version = "3.26.0"
|
||||
description = "CMake is an open-source, cross-platform family of tools designed to build, test and package software"
|
||||
category = "dev"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
|
|
@ -1293,7 +1322,7 @@ format-nongpl = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339-
|
|||
name = "lit"
|
||||
version = "15.0.7"
|
||||
description = "A Software Testing Tool"
|
||||
category = "dev"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
|
|
@ -1420,7 +1449,7 @@ files = [
|
|||
name = "mpmath"
|
||||
version = "1.3.0"
|
||||
description = "Python library for arbitrary-precision floating-point arithmetic"
|
||||
category = "dev"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
|
|
@ -1549,7 +1578,7 @@ dill = ">=0.3.6"
|
|||
name = "networkx"
|
||||
version = "3.0"
|
||||
description = "Python package for creating and manipulating graphs and networks"
|
||||
category = "dev"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
|
|
@ -1621,7 +1650,7 @@ files = [
|
|||
name = "nvidia-cublas-cu11"
|
||||
version = "11.10.3.66"
|
||||
description = "CUBLAS native runtime libraries"
|
||||
category = "dev"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3"
|
||||
files = [
|
||||
|
|
@ -1637,7 +1666,7 @@ wheel = "*"
|
|||
name = "nvidia-cuda-cupti-cu11"
|
||||
version = "11.7.101"
|
||||
description = "CUDA profiling tools runtime libs."
|
||||
category = "dev"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3"
|
||||
files = [
|
||||
|
|
@ -1653,7 +1682,7 @@ wheel = "*"
|
|||
name = "nvidia-cuda-nvrtc-cu11"
|
||||
version = "11.7.99"
|
||||
description = "NVRTC native runtime libraries"
|
||||
category = "dev"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3"
|
||||
files = [
|
||||
|
|
@ -1670,7 +1699,7 @@ wheel = "*"
|
|||
name = "nvidia-cuda-runtime-cu11"
|
||||
version = "11.7.99"
|
||||
description = "CUDA Runtime native Libraries"
|
||||
category = "dev"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3"
|
||||
files = [
|
||||
|
|
@ -1686,7 +1715,7 @@ wheel = "*"
|
|||
name = "nvidia-cudnn-cu11"
|
||||
version = "8.5.0.96"
|
||||
description = "cuDNN runtime libraries"
|
||||
category = "dev"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3"
|
||||
files = [
|
||||
|
|
@ -1702,7 +1731,7 @@ wheel = "*"
|
|||
name = "nvidia-cufft-cu11"
|
||||
version = "10.9.0.58"
|
||||
description = "CUFFT native runtime libraries"
|
||||
category = "dev"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3"
|
||||
files = [
|
||||
|
|
@ -1714,7 +1743,7 @@ files = [
|
|||
name = "nvidia-curand-cu11"
|
||||
version = "10.2.10.91"
|
||||
description = "CURAND native runtime libraries"
|
||||
category = "dev"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3"
|
||||
files = [
|
||||
|
|
@ -1730,7 +1759,7 @@ wheel = "*"
|
|||
name = "nvidia-cusolver-cu11"
|
||||
version = "11.4.0.1"
|
||||
description = "CUDA solver native runtime libraries"
|
||||
category = "dev"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3"
|
||||
files = [
|
||||
|
|
@ -1747,7 +1776,7 @@ wheel = "*"
|
|||
name = "nvidia-cusparse-cu11"
|
||||
version = "11.7.4.91"
|
||||
description = "CUSPARSE native runtime libraries"
|
||||
category = "dev"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3"
|
||||
files = [
|
||||
|
|
@ -1763,7 +1792,7 @@ wheel = "*"
|
|||
name = "nvidia-nccl-cu11"
|
||||
version = "2.14.3"
|
||||
description = "NVIDIA Collective Communication Library (NCCL) Runtime"
|
||||
category = "dev"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3"
|
||||
files = [
|
||||
|
|
@ -1774,7 +1803,7 @@ files = [
|
|||
name = "nvidia-nvtx-cu11"
|
||||
version = "11.7.91"
|
||||
description = "NVIDIA Tools Extension"
|
||||
category = "dev"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3"
|
||||
files = [
|
||||
|
|
@ -1846,6 +1875,31 @@ pytz = ">=2020.1"
|
|||
[package.extras]
|
||||
test = ["hypothesis (>=5.5.3)", "pytest (>=6.0)", "pytest-xdist (>=1.31)"]
|
||||
|
||||
[[package]]
|
||||
name = "peft"
|
||||
version = "0.2.0"
|
||||
description = ""
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.7.0"
|
||||
files = []
|
||||
develop = false
|
||||
|
||||
[package.dependencies]
|
||||
accelerate = "*"
|
||||
numpy = ">=1.17"
|
||||
packaging = ">=20.0"
|
||||
psutil = "*"
|
||||
pyyaml = "*"
|
||||
torch = ">=1.13.0"
|
||||
transformers = "*"
|
||||
|
||||
[package.source]
|
||||
type = "git"
|
||||
url = "https://github.com/huggingface/peft.git"
|
||||
reference = "v0.2.0"
|
||||
resolved_reference = "a478ab9bce252722817b5a2f10a1c0089ce9980c"
|
||||
|
||||
[[package]]
|
||||
name = "pillow"
|
||||
version = "9.4.0"
|
||||
|
|
@ -2004,6 +2058,33 @@ files = [
|
|||
{file = "protobuf-3.20.3.tar.gz", hash = "sha256:2e3427429c9cffebf259491be0af70189607f365c2f41c7c3764af6f337105f2"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "psutil"
|
||||
version = "5.9.4"
|
||||
description = "Cross-platform lib for process and system monitoring in Python."
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
|
||||
files = [
|
||||
{file = "psutil-5.9.4-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:c1ca331af862803a42677c120aff8a814a804e09832f166f226bfd22b56feee8"},
|
||||
{file = "psutil-5.9.4-cp27-cp27m-manylinux2010_i686.whl", hash = "sha256:68908971daf802203f3d37e78d3f8831b6d1014864d7a85937941bb35f09aefe"},
|
||||
{file = "psutil-5.9.4-cp27-cp27m-manylinux2010_x86_64.whl", hash = "sha256:3ff89f9b835100a825b14c2808a106b6fdcc4b15483141482a12c725e7f78549"},
|
||||
{file = "psutil-5.9.4-cp27-cp27m-win32.whl", hash = "sha256:852dd5d9f8a47169fe62fd4a971aa07859476c2ba22c2254d4a1baa4e10b95ad"},
|
||||
{file = "psutil-5.9.4-cp27-cp27m-win_amd64.whl", hash = "sha256:9120cd39dca5c5e1c54b59a41d205023d436799b1c8c4d3ff71af18535728e94"},
|
||||
{file = "psutil-5.9.4-cp27-cp27mu-manylinux2010_i686.whl", hash = "sha256:6b92c532979bafc2df23ddc785ed116fced1f492ad90a6830cf24f4d1ea27d24"},
|
||||
{file = "psutil-5.9.4-cp27-cp27mu-manylinux2010_x86_64.whl", hash = "sha256:efeae04f9516907be44904cc7ce08defb6b665128992a56957abc9b61dca94b7"},
|
||||
{file = "psutil-5.9.4-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:54d5b184728298f2ca8567bf83c422b706200bcbbfafdc06718264f9393cfeb7"},
|
||||
{file = "psutil-5.9.4-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:16653106f3b59386ffe10e0bad3bb6299e169d5327d3f187614b1cb8f24cf2e1"},
|
||||
{file = "psutil-5.9.4-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:54c0d3d8e0078b7666984e11b12b88af2db11d11249a8ac8920dd5ef68a66e08"},
|
||||
{file = "psutil-5.9.4-cp36-abi3-win32.whl", hash = "sha256:149555f59a69b33f056ba1c4eb22bb7bf24332ce631c44a319cec09f876aaeff"},
|
||||
{file = "psutil-5.9.4-cp36-abi3-win_amd64.whl", hash = "sha256:fd8522436a6ada7b4aad6638662966de0d61d241cb821239b2ae7013d41a43d4"},
|
||||
{file = "psutil-5.9.4-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:6001c809253a29599bc0dfd5179d9f8a5779f9dffea1da0f13c53ee568115e1e"},
|
||||
{file = "psutil-5.9.4.tar.gz", hash = "sha256:3d7f9739eb435d4b1338944abe23f49584bde5395f27487d2ee25ad9a8774a62"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
test = ["enum34", "ipaddress", "mock", "pywin32", "wmi"]
|
||||
|
||||
[[package]]
|
||||
name = "pyarrow"
|
||||
version = "11.0.0"
|
||||
|
|
@ -2635,7 +2716,7 @@ snowflake = ["snowflake-snowpark-python"]
|
|||
name = "sympy"
|
||||
version = "1.11.1"
|
||||
description = "Computer algebra system (CAS) in Python"
|
||||
category = "dev"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
|
|
@ -2729,7 +2810,7 @@ files = [
|
|||
name = "torch"
|
||||
version = "2.0.0"
|
||||
description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration"
|
||||
category = "dev"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.8.0"
|
||||
files = [
|
||||
|
|
@ -2891,7 +2972,7 @@ vision = ["Pillow"]
|
|||
name = "triton"
|
||||
version = "2.0.0"
|
||||
description = "A language and compiler for custom Deep Learning operations"
|
||||
category = "dev"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
|
|
@ -3104,7 +3185,7 @@ watchmedo = ["PyYAML (>=3.10)"]
|
|||
name = "wheel"
|
||||
version = "0.40.0"
|
||||
description = "A built-package format for Python"
|
||||
category = "dev"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
|
|
@ -3412,4 +3493,4 @@ testing = ["coverage (>=5.0.3)", "zope.event", "zope.testing"]
|
|||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.10"
|
||||
content-hash = "c4e4670acf2dd356d6e625d590ec109eee7081c065fc54ce97e9f69bdcb21844"
|
||||
content-hash = "895f90bc2f1bbe6847af480e6ae049b31c463ab134af7f5895d618d4662da551"
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ loguru = "^0.6.0"
|
|||
gitup = "^0.5.1"
|
||||
toml = "^0.10.2"
|
||||
gitpython = "^3.1.31"
|
||||
peft = {git = "https://github.com/huggingface/peft.git", rev = "v0.2.0"}
|
||||
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ class ServiceInfo:
|
|||
def is_health(self) -> bool:
|
||||
try:
|
||||
return requests.get(self.url).status_code == 200
|
||||
except ConnectionError as e:
|
||||
except ConnectionError:
|
||||
return False
|
||||
|
||||
@property
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
import logging
|
||||
import os
|
||||
|
||||
import uvicorn
|
||||
|
|
|
|||
|
|
@ -1,7 +1,5 @@
|
|||
import json
|
||||
import os
|
||||
import shutil
|
||||
from typing import List
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
import json
|
||||
import random
|
||||
import string
|
||||
import time
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ import configparser
|
|||
import multiprocessing
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
|
@ -91,7 +90,6 @@ def split_and_convert(args):
|
|||
|
||||
if os.path.exists(saved_dir) == False:
|
||||
os.makedirs(saved_dir)
|
||||
ckpt_name = args.in_file
|
||||
|
||||
t_gpu_num = args.trained_gpu_num
|
||||
i_gpu_num = args.infer_gpu_num
|
||||
|
|
|
|||
|
|
@ -0,0 +1,87 @@
|
|||
import torch
|
||||
from datasets import Dataset, load_from_disk
|
||||
|
||||
|
||||
class ConstantLengthDataset:
    """Repack a stream of text examples into fixed-length token chunks.

    Examples are drawn from ``dataset``, joined with the tokenizer's EOS token
    as a separator, and re-cut into chunks of exactly ``seq_length`` tokens.

    Args:
        tokenizer (Tokenizer): The processor used for processing the data.
        dataset (dataset.Dataset): Dataset with text files.
        infinite (bool): If True the iterator is reset after dataset reaches end else stops.
        seq_length (int): Length of token sequences to return.
        num_of_sequences (int): Number of token sequences to keep in buffer.
        chars_per_token (int): Number of characters per token used to estimate number of tokens in text buffer.
    """

    def __init__(
        self,
        tokenizer,
        dataset,
        infinite=False,
        seq_length=1024,
        num_of_sequences=1024,
        chars_per_token=3.6,
        content_field="content",
    ):
        self.tokenizer = tokenizer
        # EOS marks the boundary between concatenated examples.
        self.concat_token_id = tokenizer.eos_token_id
        self.dataset = dataset
        self.seq_length = seq_length
        self.infinite = infinite
        # Count of full-length chunks produced so far.
        self.current_size = 0
        # Character budget per text buffer, estimated from chars-per-token.
        self.max_buffer_size = seq_length * chars_per_token * num_of_sequences
        self.content_field = content_field

    def __call__(self):
        """Return a fresh generator over the chunks (generator-factory API)."""
        return (example for example in self)

    def __iter__(self):
        """Yield ``{"input_ids": ..., "labels": ...}`` dicts of exactly seq_length tokens."""
        for text_buffer in self._read_dataset_into_buffer():
            yield from self._tokenize(text_buffer)

    def _tokenize(self, buffer):
        """Tokenize a buffer of texts and slice the EOS-joined stream into chunks."""
        token_lists = self.tokenizer(buffer, truncation=False)["input_ids"]

        # Flatten every example into a single EOS-separated id stream.
        stream = []
        for ids in token_lists:
            stream.extend(ids)
            stream.append(self.concat_token_id)

        for start in range(0, len(stream), self.seq_length):
            chunk = stream[start : start + self.seq_length]

            # A short tail is replaced by the final seq_length ids of the stream
            # (overlapping the previous chunk) so no trailing data is dropped.
            if len(chunk) < self.seq_length:
                chunk = stream[-self.seq_length :]

            # Streams shorter than seq_length still can't fill a chunk: skip them.
            if len(chunk) == self.seq_length:
                self.current_size += 1
                yield dict(input_ids=chunk, labels=chunk)

    def _read_dataset_into_buffer(self):
        """Yield lists of raw texts whose combined length reaches max_buffer_size."""
        example_iter = iter(self.dataset)
        exhausted = False
        while not exhausted:
            texts, total_chars = [], 0
            while total_chars < self.max_buffer_size:
                try:
                    text = next(example_iter)[self.content_field]
                except StopIteration:
                    if not self.infinite:
                        exhausted = True
                        break
                    # Infinite mode: restart from the beginning of the dataset.
                    example_iter = iter(self.dataset)
                    continue
                texts.append(text)
                total_chars += len(text)
            yield texts
|
||||
|
||||
|
||||
def load_dataset(tokenizer, filepath, **kwargs):
    """Load a dataset saved at *filepath* and repack it into constant-length chunks.

    Extra keyword arguments are forwarded to ``ConstantLengthDataset``
    (e.g. ``seq_length``, ``infinite``).
    """
    on_disk = load_from_disk(filepath)
    chunker = ConstantLengthDataset(tokenizer, on_disk, **kwargs)
    return Dataset.from_generator(chunker)
|
||||
|
|
@ -0,0 +1,137 @@
|
|||
import os
from dataclasses import dataclass, field
from typing import List, Optional

import peft
import torch
import torch.nn as nn
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser

from .dataset import load_dataset
|
||||
|
||||
|
||||
@dataclass
class TrainLoraArguments:
    """Hyperparameters / CLI arguments for LoRA fine-tuning."""

    data_path: str = field(metadata={"help": "Dataset dir for training / eval "})
    output_dir: str = field(metadata={"help": "Output dir for checkpoint"})
    base_model: str = field(
        default="TabbyML/J-350M", metadata={"help": "Base model for fine-tuning"}
    )

    # Optimization: effective batch = micro_batch_size x gradient accumulation.
    batch_size: int = 128
    micro_batch_size: int = 4
    num_epochs: int = 3
    learning_rate: float = 3e-4
    cutoff_len: int = 256  # token sequence length used when chunking the dataset

    # Evaluations
    val_set_size: int = 2000
    eval_steps: int = 200

    # Lora Hyperparams
    lora_r: int = 8
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    # FIX: the default was `(["q_proj", "v_proj"],)` — a stray trailing comma
    # wrapped the list in a 1-tuple, so peft would have received a tuple
    # containing a list instead of the module-name list. default_factory also
    # gives each instance its own list (dataclasses forbid shared mutable defaults).
    lora_target_modules: List[str] = field(
        default_factory=lambda: ["q_proj", "v_proj"]
    )
    # Either a full training checkpoint or a final adapter dir; None disables resume.
    resume_from_checkpoint: Optional[str] = None
|
||||
|
||||
|
||||
def parse_args() -> TrainLoraArguments:
    """Parse command-line flags into a ``TrainLoraArguments`` instance.

    FIX: the plain ``ArgumentParser.parse_args()`` returns an
    ``argparse.Namespace``, not the annotated dataclass.
    ``HfArgumentParser.parse_args_into_dataclasses()`` returns the populated
    dataclass instances, matching the declared return type.
    """
    parser = HfArgumentParser(TrainLoraArguments)
    (args,) = parser.parse_args_into_dataclasses()
    return args
|
||||
|
||||
|
||||
def train(args: TrainLoraArguments):
    """Fine-tune ``args.base_model`` with LoRA adapters on the data at ``args.data_path``.

    Loads the base model in fp16, wraps it with a PEFT LoRA config, optionally
    resumes from a previous checkpoint, trains with ``transformers.Trainer``,
    and saves the resulting weights to ``args.output_dir``.
    """
    # Reach the effective batch size by accumulating gradients over micro batches.
    gradient_accumulation_steps = args.batch_size // args.micro_batch_size

    model = AutoModelForCausalLM.from_pretrained(
        args.base_model, torch_dtype=torch.float16
    )

    tokenizer = AutoTokenizer.from_pretrained(args.base_model)

    config = peft.LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        target_modules=args.lora_target_modules,
        lora_dropout=args.lora_dropout,
        bias="none",
        task_type=peft.TaskType.CAUSAL_LM,
    )
    model = peft.get_peft_model(model, config)

    # Repack the on-disk dataset into constant-length chunks of cutoff_len tokens.
    data = load_dataset(tokenizer, args.data_path, seq_length=args.cutoff_len)

    resume_from_checkpoint = args.resume_from_checkpoint
    if resume_from_checkpoint:
        # Check the available weights and load them.
        checkpoint_name = os.path.join(
            resume_from_checkpoint, "pytorch_model.bin"
        )  # Full checkpoint
        if not os.path.exists(checkpoint_name):
            checkpoint_name = os.path.join(
                resume_from_checkpoint, "adapter_model.bin"
            )  # only LoRA model - LoRA config above has to fit
            resume_from_checkpoint = False  # So the trainer won't try loading its state
        # The two files above have a different name depending on how they were saved, but are actually the same.
        if os.path.exists(checkpoint_name):
            print(f"Restarting from {checkpoint_name}")
            adapters_weights = torch.load(checkpoint_name)
            # NOTE(review): rebinding `model` to the return value assumes
            # peft.set_peft_model_state_dict returns the model — confirm against
            # the pinned peft v0.2.0; some versions return a load result instead.
            model = peft.set_peft_model_state_dict(model, adapters_weights)
        else:
            print(f"Checkpoint {checkpoint_name} not found")

    model.print_trainable_parameters()  # Be more transparent about the % of trainable params.

    # Held-out split for periodic evaluation; fixed seed keeps the split stable.
    train_val = data.train_test_split(
        test_size=args.val_set_size, shuffle=True, seed=42
    )
    train_data = train_val["train"].shuffle()
    val_data = train_val["test"].shuffle()

    trainer = transformers.Trainer(
        model=model,
        train_dataset=train_data,
        eval_dataset=val_data,
        args=transformers.TrainingArguments(
            per_device_train_batch_size=args.micro_batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            warmup_steps=100,
            num_train_epochs=args.num_epochs,
            learning_rate=args.learning_rate,
            fp16=True,
            logging_steps=10,
            evaluation_strategy="steps",
            save_strategy="steps",
            # Evaluate and checkpoint on the same cadence so best-model tracking works.
            eval_steps=args.eval_steps,
            save_steps=args.eval_steps,
            output_dir=args.output_dir,
            save_total_limit=3,
            load_best_model_at_end=True,
        ),
    )
    # NOTE(review): presumably disables the attention KV cache during training — confirm.
    model.config.use_cache = False

    # Patch state_dict so saves go through peft.get_peft_model_state_dict
    # (adapter weights) instead of the full model state.
    old_state_dict = model.state_dict
    model.state_dict = (
        lambda self, *_, **__: peft.get_peft_model_state_dict(self, old_state_dict())
    ).__get__(model, type(model))

    model = torch.compile(model)

    trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    # Saves via the patched state_dict above.
    model.save_pretrained(args.output_dir)

    print("\n If there's a warning about missing keys above, please disregard :)")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # CLI entry point: parse flags, then run LoRA fine-tuning.
    train(parse_args())
|
||||
Loading…
Reference in New Issue