快捷方式

提示模板

提示模板是结构化文本模板,用于格式化用户提示,以优化模型在特定任务上的性能。它们可以服务于多种目的:

  1. 模型特定的模板,每当模型接收到提示时都需要使用,例如指令微调的 Llama2 和 Mistral 模型中的 [INST] 标签。这些模型在使用这些标签进行预训练,在推理中使用它们有助于确保最佳性能。

  2. 任务特定的模板,用于让模型适应训练后预期处理的特定任务。示例包括语法纠正 (GrammarErrorCorrectionTemplate)、摘要生成 (SummarizeTemplate)、问答 (QuestionAnswerTemplate) 等。

  3. 社区标准化模板,例如 ChatMLTemplate

例如,如果我想微调一个模型来执行语法纠正任务,我可以使用 GrammarErrorCorrectionTemplate 将文本“Correct this to standard English: {prompt} — Corrected: {response}” 添加到我所有的数据样本中。

from torchtune.data import GrammarErrorCorrectionTemplate, Message

sample = {
    "incorrect": "This are a cat",
    "correct": "This is a cat.",
}
msgs = [
    Message(role="user", content=sample["incorrect"]),
    Message(role="assistant", content=sample["correct"]),
]

gec_template = GrammarErrorCorrectionTemplate()
templated_msgs = gec_template(msgs)
for msg in templated_msgs:
    print(msg.text_content)
# Correct this to standard English: This are a cat
# ---
# Corrected:
# This is a cat.

添加的文本与模型分词器添加的特殊 token 不同。关于提示模板和特殊 token 之间区别的更详细讨论,请参阅 Tokenizing prompt templates & special tokens

使用提示模板

提示模板被传递给分词器,并将自动应用于你正在微调的数据集。你可以通过两种方式传递它:

  • 一个指向提示模板类的字符串点路径,例如:“torchtune.models.mistral.MistralChatTemplate” 或 “path.to.my.CustomPromptTemplate”

  • 一个字典,将角色映射到字符串元组,表示在消息内容之前和之后添加的文本

通过点路径字符串定义

# In code
from torchtune.models.mistral import mistral_tokenizer

m_tokenizer = mistral_tokenizer(
    path="/tmp/Mistral-7B-v0.1/tokenizer.model"
    prompt_template="torchtune.models.mistral.MistralChatTemplate"
)
# In config
tokenizer:
  _component_: torchtune.models.mistral.mistral_tokenizer
  path: /tmp/Mistral-7B-v0.1/tokenizer.model
  prompt_template: torchtune.models.mistral.MistralChatTemplate

通过字典定义

例如,要实现以下提示模板:

System: {content}\\n
User: {content}\\n
Assistant: {content}\\n
Tool: {content}\\n

你需要为每个角色传入一个元组,其中 PREPEND_TAG 是添加到文本内容之前的字符串,APPEND_TAG 是添加到之后的字符串。

template = {role: (PREPEND_TAG, APPEND_TAG)}

因此,模板可以定义如下:

template = {
    "system": ("System: ", "\\n"),
    "user": ("User: ", "\\n"),
    "assistant": ("Assistant: ", "\\n"),
    "ipython": ("Tool: ", "\\n"),
}

现在我们可以将其作为字典传递给分词器:

# In code
from torchtune.models.mistral import mistral_tokenizer

template = {
    "system": ("System: ", "\\n"),
    "user": ("User: ", "\\n"),
    "assistant": ("Assistant: ", "\\n"),
    "ipython": ("Tool: ", "\\n"),
}
m_tokenizer = mistral_tokenizer(
    path="/tmp/Mistral-7B-v0.1/tokenizer.model"
    prompt_template=template,
)
# In config
tokenizer:
  _component_: torchtune.models.mistral.mistral_tokenizer
  path: /tmp/Mistral-7B-v0.1/tokenizer.model
  prompt_template:
    system:
      - "System: "
      - "\\n"
    user:
      - "User: "
      - "\\n"
    assistant:
      - "Assistant: "
      - "\\n"
    ipython:
      - "Tool: "
      - "\\n"

如果你不想为某个角色添加前置/后置标签,你可以在需要的地方传入空字符串 “” 。

使用 PromptTemplate

模板字典也可以传递给 PromptTemplate,这样你就可以将其作为一个独立的自定义提示模板类使用。

from torchtune.data import PromptTemplate

def my_custom_template() -> PromptTemplate:
    return PromptTemplate(
        template={
            "user": ("User: ", "\\n"),
            "assistant": ("Assistant: ", "\\n"),
        },
    )

template = my_custom_template()
msgs = [
    Message(role="user", content="Hello world!"),
    Message(role="assistant", content="Is AI overhyped?"),
]
templated_msgs = template(msgs)
for msg in templated_msgs:
    print(msg.role, msg.text_content)
# user, User: Hello world!
#
# assistant, Assistant: Is AI overhyped?
#

自定义提示模板

对于不完全符合 PREPEND_TAG content APPEND_TAG 模式的更高级配置,你可以创建一个继承自 PromptTemplateInterface 并实现 __call__ 方法的新类。

from torchtune.data import Message

class PromptTemplateInterface(Protocol):
    def __call__(
        self,
        messages: List[Message],
        inference: bool = False,
    ) -> List[Message]:
        """
        Format each role's message(s) according to the prompt template

        Args:
            messages (List[Message]): a single conversation, structured as a list
                of :class:`~torchtune.data.Message` objects
            inference (bool): Whether the template is being used for inference or not.

        Returns:
            The formatted list of messages
        """
        pass

# Contrived example - make all assistant prompts say "Eureka!"
class EurekaTemplate(PromptTemplateInterface):
    def __call__(
        self,
        messages: List[Message],
        inference: bool = False,
    ) -> List[Message]:
        formatted_dialogue = []
        for message in messages:
            if message.role == "assistant":
                content = "Eureka!"
            else:
                content = message.content
            formatted_dialogue.append(
                Message(
                    role=message.role,
                    content=content,
                    masked=message.masked,
                    ipython=message.ipython,
                    eot=message.eot,
                ),
            )
        return formatted_dialogue

template = EurekaTemplate()
msgs = [
    Message(role="user", content="Hello world!"),
    Message(role="assistant", content="Is AI overhyped?"),
]
templated_msgs = template(msgs)
for msg in templated_msgs:
    print(msg.role, msg.text_content)
# user, Hello world!
# assistant, Eureka!

更多示例,你可以查看 MistralChatTemplateLlama2ChatTemplate

要在分词器中使用此自定义模板,你可以通过点路径字符串传递它:

# In code
from torchtune.models.mistral import mistral_tokenizer

m_tokenizer = mistral_tokenizer(
    path="/tmp/Mistral-7B-v0.1/tokenizer.model",
    prompt_template="path.to.template.EurekaTemplate",
)
# In config
tokenizer:
  _component_: torchtune.models.mistral.mistral_tokenizer
  path: /tmp/Mistral-7B-v0.1/tokenizer.model
  prompt_template: path.to.template.EurekaTemplate

内置提示模板

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取面向初学者和高级开发者的深度教程

查看教程

资源

查找开发资源并获得问题解答

查看资源