• 文档 >
  • 利用 Torch 编译的 RAG 在 AWS Graviton 上增强 LLM 服务
快捷方式

利用 Torch 编译的 RAG 在 AWS Graviton 上增强 LLM 服务

之前,我们已经演示了如何使用 TorchServe 部署 Llama。仅部署 LLM 会有一些限制,例如缺乏最新信息和有限的特定领域知识。检索增强生成 (RAG) 是一种技术,可通过提供最新、相关信息的上下文来提高 LLM 的准确性和可靠性。这篇博文说明了如何在基于微服务的架构中实现 RAG 与 LLM 并存,这增强了可扩展性并加快了开发速度。此外,通过利用基于 CPU 的 RAG 与 AWS Graviton,客户可以有效地利用计算资源,最终节省成本。

问题

考虑一下用户查询 TorchServe 端点(提供 Llama 3 服务,Llama3-8b-instruct)的简单设计,如图 1 所示。部署此端点的说明可以在此链接中找到。该模型在 NVIDIA GPU(A10Gx4)上部署,未进行量化,可在 AWS 上作为g5.12xlarge 实例使用。

LLM Deployment

假设我们想了解 Llama 3 的最新信息,我们将以下查询发送到 TorchServe Llama 端点。

Question: What's new with Llama 3.1?

模型返回的响应谈到了一个名为 Llama 3.1 的数据可视化工具,这并非我们预期的结果。

Answer:  (Updated)
We've been busy bees in the Llama office, and we're excited to share the latest updates with you!
Llama 3.1 brings a bunch of new features and improvements to make your workflow even smoother and more efficient. Here are some of the highlights:
**New Features:**
1. **Customizable Columns**: You can now customize the columns in your Llama tables to fit your specific needs. This includes adding, removing, and rearranging columns, as well as setting default values for certain columns.
2. **Advanced Filtering**: Llama 3.1 introduces advanced filtering capabilities, allowing you to filter your data using a variety of conditions, such as date ranges, text matches, and more.
3. **Conditional Formatting**: You can now apply conditional formatting to your data, making it easier to visualize and analyze your results.
4. **Improved Data Import**: We've streamlined the data import process, making it easier to import data from various sources, including CSV

检索增强生成

大型语言模型 (LLM),例如 Llama,擅长执行许多复杂的文本生成任务。但是,当将 LLM 用于特定领域时,它们确实存在一些限制,例如

  • 过时信息:由于模型是在较早日期训练的(又称知识截止日期),因此可能存在模型未意识到的领域进步。

  • 缺乏特定领域知识:当将 LLM 用于特定领域任务时,LLM 可能会给出不准确的答案,因为特定领域知识可能无法轻松获得。

检索增强生成 (RAG) 是一种用于解决这些限制的技术。RAG 通过使用与查询相关的最新信息来增强 LLM,从而提高 LLM 的准确性。RAG 通过将数据源拆分为指定大小的块、对这些块进行索引以及根据查询检索相关块来实现此目的。获得的信息用作上下文,以增强发送到 LLM 的查询。

LangChain 是一个流行的框架,用于使用 RAG 构建 LLM 应用程序。

虽然 LLM 推理需要昂贵的 ML 加速器,但 RAG 端点可以部署在具有成本效益的 CPU 上,同时仍然满足用例延迟要求。此外,将 RAG 端点卸载到 CPU 上可以实现将 LLM 和业务基础设施解耦的微服务架构,并独立扩展它们。在下面的部分中,我们演示了如何在基于 linux-aarch64 的 AWS Graviton 上部署 RAG。此外,我们还展示了如何使用torch.compile提高 RAG 端点的吞吐量。基本 RAG 工作流包括两个步骤

索引

本示例中提供的上下文是一个网页URL。我们加载 URL 中的内容,并递归包含子页面。文档被拆分为较小的块以进行有效处理。这些块使用嵌入模型进行编码并存储在向量数据库中,从而实现高效的搜索和检索。我们在嵌入模型上使用torch.compile来加快推理速度。您可以阅读有关将torch.compile与 AWS Graviton 结合使用的更多信息此处

from bs4 import BeautifulSoup as Soup
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders.recursive_url_loader import RecursiveUrlLoader

import torch

# Enable AWS Graviton specific torch.compile optimizations
import torch._inductor.config as config
config.cpp.weight_prepack=True
config.freezing=True

class CustomEmbedding(HuggingFaceEmbeddings):
    tokenizer: Any

    def __init__(self, tokenizer: Any, **kwargs: Any):
        """Initialize the sentence_transformer."""
        super().__init__(**kwargs)

        # Load model from HuggingFace Hub
        self.tokenizer = tokenizer
        self.client = AutoModel.from_pretrained('sentence-transformers/all-mpnet-base-v2')
        self.client = torch.compile(self.client)
    class Config:
        arbitrary_types_allowed = True



    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Compute doc embeddings using a HuggingFace transformer model.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        import sentence_transformers

        texts = list(map(lambda x: x.replace("\n", " "), texts))

        # Tokenize sentences
        encoded_input = self.tokenizer(texts, padding=True, truncation=True, return_tensors='pt')

        embeddings = self.client(
           **encoded_input
        )
        embeddings = embeddings.pooler_output.detach().numpy()

        return embeddings.tolist()


# 1. Load the url and its child pages
url="https://hugging-face.cn/blog/llama3"
loader = RecursiveUrlLoader(
    url=url, max_depth=3, extractor=lambda x: Soup(x, "html.parser").text
)
docs = loader.load()

# 2. Split the document into chunks with a specified chunk size
text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
all_splits = text_splitter.split_documents(docs)

# 3. Store the document into a vector store with a specific embedding model
tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-mpnet-base-v2')
model = CustomEmbedding(tokenizer)

vectorstore = FAISS.from_documents(all_splits, model)

检索

对于用户发送的每个查询,我们都会在向量数据库中对查询进行相似性搜索,并获取 N(此处 N=5)个最接近的文档块。

docs = vectorstore.similarity_search(query, k=5)

提示工程

RAG 与 LLM 的典型实现使用 langchain 来链接 RAG 和 LLM 管道,并使用查询对链调用 invoke 方法。

已发布的带有 TorchServe 的 Llama 端点示例期望文本提示作为输入,并使用HuggingFace API 处理查询。为了使 RAG 设计兼容,我们需要从 RAG 端点返回文本提示。

本节介绍了如何设计 Llama 端点期望的提示,包括相关上下文。在后台,LangChain 具有 Llama 的PromptTemplate。通过使用以下调试语句执行以上代码,我们可以确定发送到 Llama 的提示。

import langchain
langchain.debug = True

我们从检索部分返回的文档中提取文本,并按如下方式对 Llama 进行最终提示工程

from langchain.prompts import PromptTemplate
from langchain_core.prompts import format_document
question="What's new with Llama 3?"

doc_prompt = PromptTemplate.from_template("{page_content}")
context = ""
for doc in docs:
    context += f"\n{format_document(doc, doc_prompt)}\n"

prompt = f"Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer."\
         f"\n\n{context}\n\nQuestion: {question}"\
         f"\nHelpful Answer:"

AWS Graviton 特定优化

为了利用 AWS Graviton 上的 RAG 性能优化,我们可以设置以下优化;优化详细信息在此博文中提到。还有一个教程讨论了这些优化

export TORCH_MKLDNN_MATMUL_MIN_DIM=1024
export LRU_CACHE_CAPACITY=1024
export THP_MEM_ALLOC_ENABLE=1
export DNNL_DEFAULT_FPMATH_MODE=BF16

为了准确测量与 PyTorch 急切模式相比,使用 torch.compile 获得的性能提升,我们还设置了

export OMP_NUM_THREADS=1

部署 RAG

尽管 TorchServe 在同一计算实例上提供了多模型端点支持,但我们在 AWS Graviton 上部署了 RAG 端点。由于 RAG 的计算量并不大,因此我们可以使用 CPU 实例进行部署,以提供经济高效的解决方案。

要使用 TorchServe 部署 RAG,我们需要以下内容

  • requirements.txt

langchain
Langchain_community
sentence-transformers
faiss-cpu
bs4

这可以与 install_py_dep_per_model=true一起在 config.properties中使用,以动态安装所需的库

  • rag-config.yaml

我们将用于索引和检索的参数传递到rag-config.yaml 中,该文件用于创建 MAR 文件。通过使这些参数可配置,我们可以使用不同的 yaml 文件为不同的任务创建多个 RAG 端点。

# TorchServe frontend parameters
minWorkers: 1
maxWorkers: 1
responseTimeout: 120
handler:
    url_to_scrape: "https://hugging-face.cn/blog/llama3"
    chunk_size: 1000
    chunk_overlap: 0
    model_path: "model/models--sentence-transformers--all-mpnet-base-v2/snapshots/84f2bcc00d77236f9e89c8a360a00fb1139bf47d"
  • rag_handler.py

我们定义一个处理程序文件,其中包含一个从BaseHandler派生的类。这个类需要定义四个方法:initializepreprocessinferencepostprocess。索引部分定义在initialize方法中。检索部分在inference方法中,提示工程部分在postprocess方法中。我们使用计时函数来确定处理每个方法所花费的时间。

import torch
import transformers
from bs4 import BeautifulSoup as Soup
from hf_custom_embeddings import CustomEmbedding
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders.recursive_url_loader import RecursiveUrlLoader
from langchain_community.vectorstores import FAISS
from langchain_core.prompts import format_document

from ts.torch_handler.base_handler import BaseHandler


class RAGHandler(BaseHandler):
    """
    RAG handler class retrieving documents from a url, encoding & storing in a vector database.
    For a given query, it returns the closest matching documents.
    """

    def __init__(self):
        super(RAGHandler, self).__init__()
        self.vectorstore = None
        self.initialized = False
        self.N = 5

    @torch.inference_mode
    def initialize(self, ctx):
        url = ctx.model_yaml_config["handler"]["url_to_scrape"]
        chunk_size = ctx.model_yaml_config["handler"]["chunk_size"]
        chunk_overlap = ctx.model_yaml_config["handler"]["chunk_overlap"]
        model_path = ctx.model_yaml_config["handler"]["model_path"]

        loader = RecursiveUrlLoader(
            url=url, max_depth=3, extractor=lambda x: Soup(x, "html.parser").text
        )
        docs = loader.load()

        # Split the document into chunks with a specified chunk size
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size, chunk_overlap=chunk_overlap
        )
        all_splits = text_splitter.split_documents(docs)

        # Store the document into a vector store with a specific embedding model
        self.vectorstore = FAISS.from_documents(
            all_splits, CustomEmbedding(model_path=model_path)
        )

    def preprocess(self, requests):
        assert len(requests) == 1, "Expecting batch_size = 1"
        inputs = []
        for request in requests:
            input_text = request.get("data") or request.get("body")
            if isinstance(input_text, (bytes, bytearray)):
                input_text = input_text.decode("utf-8")
            inputs.append(input_text)
        return inputs[0]

    @torch.inference_mode
    def inference(self, data, *args, **kwargs):
        searchDocs = self.vectorstore.similarity_search(data, k=self.N)
        return (searchDocs, data)

    def postprocess(self, data):
        docs, question = data[0], data[1]
        doc_prompt = PromptTemplate.from_template("{page_content}")
        context = ""
        for doc in docs:
            context += f"\n{format_document(doc, doc_prompt)}\n"

        prompt = (
            f"Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer."
            f"\n\n{context}\n\nQuestion: {question}"
            f"\nHelpful Answer:"
        )
        return [prompt]

性能基准测试

我们使用ab工具来衡量RAG端点的性能。

python benchmarks/auto_benchmark.py --input /home/ubuntu/serve/examples/usecases/RAG_based_LLM_serving
benchmark_profile.yaml --skip true

我们重复运行,并结合使用OMP_NUM_THREADS和PyTorch Eager/ torch.compile。

结果

我们在AWS EC2 m7g.4xlarge实例上观察到以下吞吐量。

RAG Throughput

我们观察到使用torch.compile可以将RAG端点的吞吐量提高35%。吞吐量的规模(Eager或Compile)表明,在CPU设备上部署RAG对于与部署在GPU实例上的LLM一起使用是可行的。RAG端点不会成为LLM部署中的瓶颈。

RAG + LLM部署

使用基于RAG的LLM服务的端到端解决方案的系统架构如图2所示。

RAG + LLM Deployment

完整部署的步骤在部署指南中提到。

下面显示了可以将RAG端点与Llama端点链接的代码片段。

import requests

prompt="What's new with Llama 3.1?"

RAG_EP = "http://<RAG Endpoint IP Address>:8080/predictions/rag"
LLAMA_EP = "http://<LLAMA Endpoint IP Address>:8080/predictions/llama3-8b-instruct"
# Get response from RAG
response = requests.post(url=RAG_EP, data=prompt)
# Get response from Llama
response = requests.post(url=LLAMA_EP, data=response.text.encode('utf-8'))
print(f"Question: {prompt}")
print(f"Answer: {response.text}")

示例输出

Question: What's new with Llama 3.1?
Answer:  Llama 3.1 has a large context length of 128K tokens, multilingual capabilities, tool usage capabilities, a very large dense model of 405 billion parameters, and a more permissive license. It also introduces six new open LLM models based on the Llama 3 architecture, and continues to use Grouped-Query Attention (GQA) for efficient representation. The new tokenizer expands the vocabulary size to 128,256, and the 8B version of the model now uses GQA. The license allows using model outputs to improve other LLMs.
Question: What's new with Llama 2?
Answer:  There is no mention of Llama 2 in the provided context. The text only discusses Llama 3.1 and its features. Therefore, it is not possible to determine what is new with Llama 2. I don't know.

结论

在本博文中,我们展示了如何使用TorchServe部署RAG端点,如何使用torch.compile提高吞吐量,以及如何改进Llama端点生成的响应。使用图2中描述的架构,我们可以减少LLM的幻觉。
我们还展示了如何将RAG端点部署在使用AWS Graviton的CPU上,而Llama端点仍然部署在GPU上。这种基于微服务的RAG解决方案可以有效利用计算资源,从而为客户节省潜在的成本。

文档

访问PyTorch的全面开发者文档

查看文档

教程

获取初学者和高级开发人员的深入教程

查看教程

资源

查找开发资源并获得问题的解答

查看资源