华为mindspore-如何训练一个gpt一样的文本生成模型
CausalLanguageModelingTrainer Task For Trainer.
Args:
model_name (str): The model name of Task-Trainer. Default: None
Examples:
>>> from mindformers import CausalLanguageModelingTrainer
>>> gen_trainer = CausalLanguageModelingTrainer(model_name="gpt2")
>>> gen_trainer.train()
>>> res = gen_trainer.predict(input_data = "hello world [MASK]")
Raises:
NotImplementedError: If train method or evaluate method or predict method not implemented.
初始化
def __init__(self, model_name: str = None):
super(CausalLanguageModelingTrainer, self).__init__("text_generation", model_name)
很简单的代码,用的是父类BaseTrainer,并传入两个参数:"text_generation", model_name
训练
def train(self,
config: Optional[Union[dict, MindFormerConfig, ConfigArguments, TrainingArguments]] = None,
network: Optional[Union[Cell, BaseModel]] = None,
dataset: Optional[Union[BaseDataset, GeneratorDataset]] = None,
wrapper: Optional[TrainOneStepCell] = None,
optimizer: Optional[Optimizer] = None,
callbacks: Optional[Union[Callback, List[Callback]]] = None,
**kwargs):
r"""Train task for CausalLanguageModeling Trainer.
This function is used to train or fine-tune the network.
"""
self.training_process(
config=config,
network=network,
callbacks=callbacks,
dataset=dataset,
wrapper=wrapper,
optimizer=optimizer,
**kwargs)
调用的basetrianer(mindformers/trainer/base_trainer.py · MindSpore/mindformers - Gitee.com) 的training_process方法。该方法用于训练或微调MindFormers中的模型。它需要几个参数,包括配置、网络、数据集、优化器、包装器和回调。
training_process方法首先设置配置参数,然后构建数据集,构建网络,并设置模型包装器。然后,它构建优化器并创建用于在训练期间进行评估的计算度量。该函数初始化模型并开始训练,同时定期进行日志记录。如果需要,可以从检查点恢复培训过程。最后,当训练完成时,函数会记录日志。