ModelType¶
- class torchtune.training.ModelType(value, names=None, *, module=None, qualname=None, type=None, start=1, boundary=None)[源代码]¶
ModelType 由检查点器使用,以区分不同的模型架构。
如果您正在添加遵循与仓库中现有模型不同格式的新模型,您可以添加新的 ModelType 以控制特定于该模型的权重转换逻辑。
- 变量:
GEMMA (str) – Gemma 模型系列。 请参阅
gemma()
GEMMA2 (str) – Gemma 2 模型系列。 请参阅
gemma2()
LLAMA2 (str) – Llama2 模型系列。 请参阅
llama2()
LLAMA3 (str) – Llama3 模型系列。 请参阅
llama3()
LLAMA3_2 (str) – Llama3.2 模型系列。 请参阅
llama3_2()
LLAMA3_VISION (str) – LLama3 vision 模型系列。 请参阅
llama3_2_vision_decoder()
MISTRAL (str) – Mistral 模型系列。 请参阅
mistral()
PHI3_MINI (str) – Phi-3 模型系列。 请参阅
phi3()
REWARD (str) – 具有分类头的 Llama2、Llama3 或 Mistral 模型,该分类头投影到单个类别以进行奖励建模。 请参阅
mistral_reward_7b()
或llama2_reward_7b()
QWEN2 (str) – Qwen2 模型系列。 请参阅
qwen2()
CLIP_TEXT (str) – CLIP 文本编码器。 请参阅
clip_text_encoder_large()
示例
>>> # Usage in a checkpointer class >>> def load_checkpoint(self, ...): >>> ... >>> if self._model_type == MY_NEW_MODEL: >>> state_dict = my_custom_state_dict_mapping(state_dict)