ModelType¶
- class torchtune.training.ModelType(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)[source]¶
ModelType 由检查点工具用于区分不同的模型架构。
如果你要添加一个与仓库中现有模型格式不同的新模型,可以添加一个新的 ModelType 来控制该模型独有的权重转换逻辑。
- 变量:
LLAMA3_2 (str) – Llama3.2 系列模型。参见
llama3_2()
LLAMA3_VISION (str) – LLama3 视觉系列模型。参见
llama3_2_vision_decoder()
PHI4 (str) – Phi-4 系列模型。参见
phi4()
REWARD (str) – 带有分类头,用于奖励建模的 Llama2、Llama3 或 Mistral 模型,分类头投影到一个类别。参见
mistral_reward_7b()
或llama2_reward_7b()
CLIP_TEXT (str) – CLIP 文本编码器。参见
clip_text_encoder_large()
T5_ENCODER (str) – T5 文本编码器。参见
t5_v1_1_xxl_encoder()
示例
>>> # Usage in a checkpointer class >>> def load_checkpoint(self, ...): >>> ... >>> if self._model_type == MY_NEW_MODEL: >>> state_dict = my_custom_state_dict_mapping(state_dict)