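# Multi-stage build for Hugging Face text-generation-inference (TGI): the Rust
# binaries (router, launcher, benchmark) and the CUDA Python kernels are built
# in dedicated stages, then assembled into a slim CUDA runtime image.
# Illustrative build command (the tag and build-arg usage are assumptions):
#   docker build -t text-generation-inference --build-arg GIT_SHA=$(git rev-parse HEAD) .
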
# Rust builder
FROM lukemathwalker/cargo-chef:latest-rust-1.75 AS chef
WORKDIR /usr/src

ARG CARGO_REGISTRIES_CRATES_IO_PROTOCOL=sparse

FROM chef AS planner
COPY Cargo.toml Cargo.toml
COPY rust-toolchain.toml rust-toolchain.toml
COPY proto proto
COPY benchmark benchmark
COPY router router
COPY launcher launcher
RUN cargo chef prepare --recipe-path recipe.json
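# `cargo chef prepare` records the dependency graph in recipe.json; the builder
# stage below runs `cargo chef cook` against it, so dependency compilation is
# cached as a Docker layer and only re-runs when a manifest changes.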

FROM chef AS builder

ARG GIT_SHA
ARG DOCKER_LABEL

RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
    curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
    unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
    unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \
    rm -f $PROTOC_ZIP

COPY --from=planner /usr/src/recipe.json recipe.json
RUN cargo chef cook --release --recipe-path recipe.json
29
30COPY Cargo.toml Cargo.toml
31COPY rust-toolchain.toml rust-toolchain.toml
32COPY proto proto
33COPY benchmark benchmark
34COPY router router
35COPY launcher launcher
36RUN cargo build --release

# Python builder
# Adapted from: https://github.com/pytorch/pytorch/blob/master/Dockerfile
FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 AS pytorch-install

ARG PYTORCH_VERSION=2.1.1
ARG PYTHON_VERSION=3.10
# Keep in sync with `server/pyproject.toml`
ARG CUDA_VERSION=12.1
ARG MAMBA_VERSION=23.3.1-1
ARG CUDA_CHANNEL=nvidia
ARG INSTALL_CHANNEL=pytorch
# Automatically set by buildx
ARG TARGETPLATFORM

ENV PATH=/opt/conda/bin:$PATH

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
        build-essential \
        ca-certificates \
        ccache \
        curl \
        git && \
        rm -rf /var/lib/apt/lists/*

# Install conda
# translating Docker's TARGETPLATFORM into mamba arches
RUN case ${TARGETPLATFORM} in \
        "linux/arm64") MAMBA_ARCH=aarch64 ;; \
        *) MAMBA_ARCH=x86_64 ;; \
    esac && \
    curl -fsSL -o ~/mambaforge.sh "https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-${MAMBA_ARCH}.sh"
RUN chmod +x ~/mambaforge.sh && \
    bash ~/mambaforge.sh -b -p /opt/conda && \
    rm ~/mambaforge.sh

# Install pytorch
# On arm64 we exit with an error code
RUN case ${TARGETPLATFORM} in \
        "linux/arm64") exit 1 ;; \
        *) /opt/conda/bin/conda update -y conda && \
           /opt/conda/bin/conda install -c "${INSTALL_CHANNEL}" -c "${CUDA_CHANNEL}" -y "python=${PYTHON_VERSION}" "pytorch=$PYTORCH_VERSION" "pytorch-cuda=$(echo $CUDA_VERSION | cut -d'.' -f 1-2)" ;; \
    esac && \
    /opt/conda/bin/conda clean -ya
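# Optional sanity check (illustrative, not part of the original build); a GPU
# check like torch.cuda.is_available() would fail here because no GPU is
# attached at build time, so only versions are printed:
# RUN python -c "import torch; print(torch.__version__, torch.version.cuda)"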

# CUDA kernels builder image
FROM pytorch-install AS kernel-builder

ARG MAX_JOBS=8

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
        ninja-build \
        && rm -rf /var/lib/apt/lists/*

# Build Flash Attention CUDA kernels
FROM kernel-builder AS flash-att-builder

WORKDIR /usr/src

COPY server/Makefile-flash-att Makefile

# Build specific version of flash attention
RUN make build-flash-attention

# Build Flash Attention v2 CUDA kernels
FROM kernel-builder AS flash-att-v2-builder

WORKDIR /usr/src

COPY server/Makefile-flash-att-v2 Makefile

# Build specific version of flash attention v2
RUN make build-flash-attention-v2-cuda

# Build Transformers exllama kernels
FROM kernel-builder AS exllama-kernels-builder
WORKDIR /usr/src
COPY server/exllama_kernels/ .

RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build
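# TORCH_CUDA_ARCH_LIST pins the compiled kernels to compute capability 8.0
# (e.g. A100) and 8.6 (e.g. A10, RTX 30 series); +PTX embeds PTX so newer
# architectures can still JIT-compile a compatible kernel at load time.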

# Build Transformers exllamav2 kernels
FROM kernel-builder AS exllamav2-kernels-builder
WORKDIR /usr/src
COPY server/exllamav2_kernels/ .

RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" python setup.py build

# Build Transformers awq kernels
FROM kernel-builder AS awq-kernels-builder
WORKDIR /usr/src
COPY server/Makefile-awq Makefile
# Build specific version of awq
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-awq

# Build eetq kernels
FROM kernel-builder AS eetq-kernels-builder
WORKDIR /usr/src
COPY server/Makefile-eetq Makefile
# Build specific version of eetq
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-eetq

# Build custom CUDA kernels
FROM kernel-builder AS custom-kernels-builder
WORKDIR /usr/src
COPY server/custom_kernels/ .
RUN python setup.py build

# Build vllm CUDA kernels
FROM kernel-builder AS vllm-builder

WORKDIR /usr/src

COPY server/Makefile-vllm Makefile

# Build specific version of vllm
RUN make build-vllm-cuda

# Build mamba kernels
FROM kernel-builder AS mamba-builder
WORKDIR /usr/src
COPY server/Makefile-selective-scan Makefile
RUN make build-all
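# Makefile-selective-scan builds both the mamba selective-scan extension and
# causal-conv1d; both artifact trees are copied into the runtime image below.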

# Build megablocks
FROM kernel-builder AS megablocks-builder

RUN pip install git+https://github.com/OlivierDehaene/megablocks@181709df192de9a941fdf3a641cdc65a0462996e
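# Unlike the other kernel builders, megablocks is installed directly into the
# conda env; the runtime image below copies that whole env, so it picks up
# PyTorch and megablocks in a single layer.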

# Text Generation Inference base image
FROM nvidia/cuda:12.1.0-base-ubuntu22.04 AS base
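# Note the slim `base` CUDA image here, versus `devel` in the builder stages:
# all compilation has already happened, so the runtime only needs the minimal
# CUDA runtime pieces plus the prebuilt artifacts copied in below.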

# Conda env
ENV PATH=/opt/conda/bin:$PATH \
    CONDA_PREFIX=/opt/conda

# Text Generation Inference base env
ENV HUGGINGFACE_HUB_CACHE=/data \
    HF_HUB_ENABLE_HF_TRANSFER=1 \
    PORT=80

WORKDIR /usr/src

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
        libssl-dev \
        ca-certificates \
        make \
        curl \
        && rm -rf /var/lib/apt/lists/*

# Copy conda with PyTorch and Megablocks installed
COPY --from=megablocks-builder /opt/conda /opt/conda

# Copy build artifacts from flash attention builder
COPY --from=flash-att-builder /usr/src/flash-attention/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/layer_norm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages

# Copy build artifacts from flash attention v2 builder
COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages

# Copy build artifacts from custom kernels builder
COPY --from=custom-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
# Copy build artifacts from exllama kernels builder
COPY --from=exllama-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
# Copy build artifacts from exllamav2 kernels builder
COPY --from=exllamav2-kernels-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
# Copy build artifacts from awq kernels builder
COPY --from=awq-kernels-builder /usr/src/llm-awq/awq/kernels/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages
# Copy build artifacts from eetq kernels builder
COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages

# Copy build artifacts from vllm builder
COPY --from=vllm-builder /usr/src/vllm/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages

# Copy build artifacts from mamba builder
COPY --from=mamba-builder /usr/src/mamba/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages
COPY --from=mamba-builder /usr/src/causal-conv1d/build/lib.linux-x86_64-cpython-310/ /opt/conda/lib/python3.10/site-packages

# Install flash-attention dependencies
RUN pip install einops --no-cache-dir

# Install server
COPY proto proto
COPY server server
COPY server/Makefile server/Makefile
RUN cd server && \
    make gen-server && \
    pip install -r requirements_cuda.txt && \
    pip install ".[bnb, accelerate, quantize, peft, outlines]" --no-cache-dir
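# `make gen-server` above generates the Python gRPC stubs from the proto
# definitions before the server package itself is installed with its extras.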

# Install benchmarker
COPY --from=builder /usr/src/target/release/text-generation-benchmark /usr/local/bin/text-generation-benchmark
# Install router
COPY --from=builder /usr/src/target/release/text-generation-router /usr/local/bin/text-generation-router
# Install launcher
COPY --from=builder /usr/src/target/release/text-generation-launcher /usr/local/bin/text-generation-launcher

RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
        build-essential \
        g++ \
        && rm -rf /var/lib/apt/lists/*

# AWS Sagemaker compatible image
FROM base AS sagemaker

COPY sagemaker-entrypoint.sh entrypoint.sh
RUN chmod +x entrypoint.sh

ENTRYPOINT ["./entrypoint.sh"]

# Final image
FROM base

ENTRYPOINT ["text-generation-launcher"]
CMD ["--json-output"]
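
# Illustrative run command (flags and names are assumptions; /data matches
# HUGGINGFACE_HUB_CACHE and the container port matches PORT=80 above):
#   docker run --gpus all --shm-size 1g -p 8080:80 -v $PWD/data:/data \
#       text-generation-inference --model-id <model-id>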