Still can't tell xxxForCausalLM and xxxForConditionalGeneration apart?

Part 1: Basic Introduction

Large language models are taking off, and when working with them you constantly run into the transformers library, where xxxForCausalLM and xxxForConditionalGeneration keep appearing. In this post we walk through some of the basic tasks in the transformers library.

We take three families of models as examples: BERT (autoencoding), GPT (autoregressive), and BART (encoder-decoder).

First, an overall look at the tasks each model supports:

from ..bart.modeling_bart import (
    BartForCausalLM,
    BartForConditionalGeneration,
    BartForQuestionAnswering,
    BartForSequenceClassification,
    BartModel,
)
from ..bert.modeling_bert import (
    BertForMaskedLM,
    BertForMultipleChoice,
    BertForNextSentencePrediction,
    BertForPreTraining,
    BertForQuestionAnswering,
    BertForSequenceClassification,
    BertForTokenClassification,
    BertLMHeadModel,
    BertModel,
)
from ..gpt2.modeling_gpt2 import GPT2ForSequenceClassification, GPT2LMHeadModel, GPT2Model
  • BertModel(BertPreTrainedModel): the most basic BERT; from it you can get a sentence-level vector or a vector for every token.
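
    A minimal sketch of calling it (assuming the "bert-base-uncased" checkpoint; any BERT checkpoint behaves the same):

    from transformers import AutoTokenizer, BertModel

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model = BertModel.from_pretrained("bert-base-uncased")
    out = model(**tokenizer("Bob likes candy", return_tensors="pt"))
    print(out.last_hidden_state.shape)  # (1, seq_len, 768): one vector per token
    print(out.pooler_output.shape)      # (1, 768): pooled sentence-level vector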

  • BertForPreTraining(BertPreTrainedModel): BertModel with a pretraining head on top:

    self.bert = BertModel(config)
    self.cls = BertPreTrainingHeads(config)

    class BertPreTrainingHeads(nn.Module):
        def __init__(self, config):
            super().__init__()
            self.predictions = BertLMPredictionHead(config)
            self.seq_relationship = nn.Linear(config.hidden_size, 2)

        def forward(self, sequence_output, pooled_output):
            prediction_scores = self.predictions(sequence_output)
            seq_relationship_score = self.seq_relationship(pooled_output)
            return prediction_scores, seq_relationship_score

    class BertLMPredictionHead(nn.Module):
        def __init__(self, config):
            super().__init__()
            self.transform = BertPredictionHeadTransform(config)
            # The output weights are the same as the input embeddings, but there is
            # an output-only bias for each token.
            self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
            self.bias = nn.Parameter(torch.zeros(config.vocab_size))
            # Need a link between the two variables so that the bias is correctly resized with resize_token_embeddings
            self.decoder.bias = self.bias

        def forward(self, hidden_states):
            hidden_states = self.transform(hidden_states)
            hidden_states = self.decoder(hidden_states)
            return hidden_states

    These correspond to BERT's two pretraining objectives: masked language modeling (MLM) and next sentence prediction (NSP).
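
    For instance (a sketch assuming "bert-base-uncased"; the two outputs map onto the two objectives):

    from transformers import AutoTokenizer, BertForPreTraining

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model = BertForPreTraining.from_pretrained("bert-base-uncased")
    outputs = model(**tokenizer("Paris is the capital of France.", return_tensors="pt"))
    print(outputs.prediction_logits.shape)        # (1, seq_len, vocab_size): MLM head
    print(outputs.seq_relationship_logits.shape)  # (1, 2): NSP head (is-next / not-next)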

  • BertLMHeadModel(BertPreTrainedModel): a language-modeling head on top of BERT. Note that in transformers this class targets causal LM fine-tuning (config.is_decoder=True), while the MLM objective proper is covered by BertForMaskedLM; both attach the same head shown below.

    self.bert = BertModel(config, add_pooling_layer=False)
    self.cls = BertOnlyMLMHead(config)

    class BertOnlyMLMHead(nn.Module):
        def __init__(self, config):
            super().__init__()
            self.predictions = BertLMPredictionHead(config)

        def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
            prediction_scores = self.predictions(sequence_output)
            return prediction_scores

    # BertLMPredictionHead is the same class shown above under BertForPreTraining.

  • BertForNextSentencePrediction(BertPreTrainedModel): the NSP task

    self.bert = BertModel(config)
    self.cls = BertOnlyNSPHead(config)

    class BertOnlyNSPHead(nn.Module):
        def __init__(self, config):
            super().__init__()
            self.seq_relationship = nn.Linear(config.hidden_size, 2)

        def forward(self, pooled_output):
            seq_relationship_score = self.seq_relationship(pooled_output)
            return seq_relationship_score

  • BertForSequenceClassification(BertPreTrainedModel): classify sentences

    self.bert = BertModel(config)
    classifier_dropout = (
        config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
    )
    self.dropout = nn.Dropout(classifier_dropout)
    self.classifier = nn.Linear(config.hidden_size, config.num_labels)

  • BertForMultipleChoice(BertPreTrainedModel): multiple choice

    self.bert = BertModel(config)
    classifier_dropout = (
        config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
    )
    self.dropout = nn.Dropout(classifier_dropout)
    self.classifier = nn.Linear(config.hidden_size, 1)

  • BertForTokenClassification(BertPreTrainedModel): classify each token, typically for named entity recognition; a usage sketch follows the snippet below.

    self.bert = BertModel(config, add_pooling_layer=False)
    classifier_dropout = (
        config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
    )
    self.dropout = nn.Dropout(classifier_dropout)
    self.classifier = nn.Linear(config.hidden_size, config.num_labels)
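
    A quick sketch of the output shape (assuming "bert-base-uncased" and a hypothetical 5-label scheme; the classification head is freshly initialized, so the scores are untrained):

    from transformers import AutoTokenizer, BertForTokenClassification

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model = BertForTokenClassification.from_pretrained("bert-base-uncased", num_labels=5)
    out = model(**tokenizer("Bob lives in Paris", return_tensors="pt"))
    print(out.logits.shape)  # (1, seq_len, 5): one score per token per label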

  • BertForQuestionAnswering(BertPreTrainedModel): the QA task; many tasks can be cast into this form, i.e., predicting the start and end positions of the answer. A sketch follows the snippet below.

    self.bert = BertModel(config, add_pooling_layer=False)
    self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
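
    For example (a sketch assuming "bert-base-uncased"; with an untuned QA head the extracted span is essentially random):

    from transformers import AutoTokenizer, BertForQuestionAnswering

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model = BertForQuestionAnswering.from_pretrained("bert-base-uncased")
    inputs = tokenizer("Where does Bob live?", "Bob lives in Paris.", return_tensors="pt")
    out = model(**inputs)
    start = out.start_logits.argmax(-1).item()  # predicted start position of the answer
    end = out.end_logits.argmax(-1).item()      # predicted end position of the answer
    print(tokenizer.decode(inputs["input_ids"][0][start:end + 1]))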

  • GPT2Model(GPT2PreTrainedModel): the bare GPT-2 model; returns a vector for each token.

  • GPT2LMHeadModel(GPT2PreTrainedModel): the language-modeling task: for each token, predict what the next token is. A sketch follows the snippet below.

    self.transformer = GPT2Model(config)
    self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
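
    A minimal sketch of "predict the next token" with the lm_head logits (the gpt2 checkpoint, greedy pick):

    from transformers import GPT2Tokenizer, GPT2LMHeadModel

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    model = GPT2LMHeadModel.from_pretrained("gpt2")
    logits = model(**tokenizer("Bob likes", return_tensors="pt")).logits  # (1, seq_len, vocab_size)
    next_id = logits[0, -1].argmax().item()  # greedy choice for the token after "likes"
    print(tokenizer.decode([next_id]))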

  • GPT2DoubleHeadsModel(GPT2PreTrainedModel): besides the language-modeling head, this defines an extra head for a multiple-choice task. For example, a question with two candidate answers, one right and one wrong: a binary classifier decides which answer is correct, and this naturally extends to more options.

    config.num_labels = 1
    self.transformer = GPT2Model(config)
    self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
    self.multiple_choice_head = SequenceSummary(config)

This one is best seen through an example:

import torch
from transformers import GPT2Tokenizer, GPT2DoubleHeadsModel

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2DoubleHeadsModel.from_pretrained('gpt2')

choices = ["Bob likes candy ; what does Bob like ?  Bag <|endoftext|>",
           "Bob likes candy ; what does Bob like ?  Burger <|endoftext|>",
           "Bob likes candy ; what does Bob like ?  Candy <|endoftext|>",
           "Bob likes candy ; what does Bob like ?  Apple <|endoftext|>"]
encoded_choices = [tokenizer.encode(s) for s in choices]
eos_token_location = [tokens.index(tokenizer.eos_token_id) for tokens in encoded_choices]

input_ids = torch.tensor(encoded_choices).unsqueeze(0)
mc_token_ids = torch.tensor([eos_token_location])
print(input_ids.shape)
print(mc_token_ids.shape)

outputs = model(input_ids, mc_token_ids=mc_token_ids)
lm_prediction_scores, mc_prediction_scores = outputs[:2]
print(lm_prediction_scores.shape)
print(mc_prediction_scores)
"""
torch.Size([1, 4, 13])
torch.Size([1, 4])
torch.Size([1, 4, 13, 50257])
tensor([[-6.0075, -6.0649, -6.0657, -6.0585]], grad_fn=<SqueezeBackward1>)
"""

Confused by GPT2DoubleHeadsModel example · Issue #1794 · huggingface/transformers (github.com)

How to use GPT2DoubleHeadsModel? · Issue #3680 · huggingface/transformers (github.com)

  • GPT2ForSequenceClassification(GPT2PreTrainedModel): as the name suggests, for text classification. (Since GPT-2 has no [CLS] token, it classifies on the hidden state of the last non-padding token.)

    self.transformer = GPT2Model(config)
    self.score = nn.Linear(config.n_embd, self.num_labels, bias=False)

  • GPT2ForTokenClassification(GPT2PreTrainedModel): for token classification (named entity recognition)

    self.transformer = GPT2Model(config)
    if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
        classifier_dropout = config.classifier_dropout
    elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
        classifier_dropout = config.hidden_dropout
    else:
        classifier_dropout = 0.1
    self.dropout = nn.Dropout(classifier_dropout)
    self.classifier = nn.Linear(config.hidden_size, config.num_labels)

Another example:

import torch
from transformers import GPT2Tokenizer, GPT2Model

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2Model.from_pretrained('gpt2')
text = ["Bob likes candy ; what does Bob like ?  Bag <|endoftext|>",
        "Bob likes candy ; what does Bob like ?  Bag <|endoftext|>"]
inputs = tokenizer(text, return_tensors="pt")
print(inputs)
print(tokenizer.decode(inputs["input_ids"][0]))
output = model(**inputs)
print(output[0].shape)
"""
{'input_ids': tensor([[18861,  7832, 18550,  2162,   644,   857,  5811,   588,  5633,   220,
         20127,   220, 50256],
        [18861,  7832, 18550,  2162,   644,   857,  5811,   588,  5633,   220,
         20127,   220, 50256]]),
 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
Bob likes candy ; what does Bob like?  Bag <|endoftext|>
torch.Size([2, 13, 768])
"""
  • BartModel(BartPretrainedModel): the bare BART model; returns the decoder's vector for each token (other outputs are available as well).

  • BartForConditionalGeneration(BartPretrainedModel): as the name suggests, conditional text generation.

    self.model = BartModel(config)
    self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
    self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)

For the inputs we generally provide input_ids (encoder input), attention_mask (encoder attention), decoder_input_ids (decoder input), and decoder_attention_mask (decoder attention); of the outputs, the two we usually use are loss (masked_lm_loss) and logits (lm_logits).
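A minimal sketch of such a forward pass (assuming the "facebook/bart-base" checkpoint; when only labels are given, the decoder inputs are derived from them automatically):

from transformers import AutoTokenizer, BartForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
enc = tokenizer("Bob likes candy ; what does Bob like ?", return_tensors="pt")
tgt = tokenizer("Candy", return_tensors="pt")
out = model(input_ids=enc.input_ids,
            attention_mask=enc.attention_mask,
            labels=tgt.input_ids)  # decoder_input_ids default to the labels shifted right
print(out.loss, out.logits.shape)  # masked_lm_loss and lm_logits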

  • BartForSequenceClassification(BartPretrainedModel): text classification

    self.model = BartModel(config)
    self.classification_head = BartClassificationHead(
        config.d_model,
        config.d_model,
        config.num_labels,
        config.classifier_dropout,
    )

    class BartClassificationHead(nn.Module):
        """Head for sentence-level classification tasks."""

        def __init__(
            self,
            input_dim: int,
            inner_dim: int,
            num_classes: int,
            pooler_dropout: float,
        ):
            super().__init__()
            self.dense = nn.Linear(input_dim, inner_dim)
            self.dropout = nn.Dropout(p=pooler_dropout)
            self.out_proj = nn.Linear(inner_dim, num_classes)

        def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
            hidden_states = self.dropout(hidden_states)
            hidden_states = self.dense(hidden_states)
            hidden_states = torch.tanh(hidden_states)
            hidden_states = self.dropout(hidden_states)
            hidden_states = self.out_proj(hidden_states)
            return hidden_states

Concretely, the logits are obtained like this:

hidden_states = outputs[0]  # last hidden state
# locate the <eos> positions
eos_mask = input_ids.eq(self.config.eos_token_id).to(hidden_states.device)
if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
    raise ValueError("All examples must have the same number of <eos> tokens.")
sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[
    :, -1, :
]
logits = self.classification_head(sentence_representation)

And the loss is computed as follows:

loss = None
if labels is not None:
    labels = labels.to(logits.device)
    if self.config.problem_type is None:
        if self.config.num_labels == 1:
            self.config.problem_type = "regression"
        elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
            self.config.problem_type = "single_label_classification"
        else:
            self.config.problem_type = "multi_label_classification"
    if self.config.problem_type == "regression":
        loss_fct = MSELoss()
        if self.config.num_labels == 1:
            loss = loss_fct(logits.squeeze(), labels.squeeze())
        else:
            loss = loss_fct(logits, labels)
    elif self.config.problem_type == "single_label_classification":
        loss_fct = CrossEntropyLoss()
        loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
    elif self.config.problem_type == "multi_label_classification":
        loss_fct = BCEWithLogitsLoss()
        loss = loss_fct(logits, labels)
  • BartForQuestionAnswering(BartPretrainedModel): QA works essentially the same as the BERT version above, except that the vectors fed into the logits computation are the decoder's hidden states.

    config.num_labels = 2
    self.num_labels = config.num_labels
    self.model = BartModel(config)
    self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

    sequence_output = outputs[0]
    logits = self.qa_outputs(sequence_output)
    start_logits, end_logits = logits.split(1, dim=-1)
    start_logits = start_logits.squeeze(-1).contiguous()
    end_logits = end_logits.squeeze(-1).contiguous()

  • BartForCausalLM(BartPretrainedModel): the language-modeling task, using only BART's decoder.

    config = copy.deepcopy(config)
    config.is_decoder = True
    config.is_encoder_decoder = False
    super().__init__(config)
    self.model = BartDecoderWrapper(config)
    self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    outputs = self.model.decoder(
        input_ids=input_ids,
        attention_mask=attention_mask,
        encoder_hidden_states=encoder_hidden_states,
        encoder_attention_mask=encoder_attention_mask,
        head_mask=head_mask,
        cross_attn_head_mask=cross_attn_head_mask,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )
    logits = self.lm_head(outputs[0])

    >>> from transformers import AutoTokenizer, BartForCausalLM
    >>> tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
    >>> model = BartForCausalLM.from_pretrained("facebook/bart-base", add_cross_attention=False)
    >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
    >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
    >>> outputs = model(**inputs)
    >>> logits = outputs.logits
    >>> expected_shape = [1, inputs.input_ids.shape[-1], model.config.vocab_size]
    >>> list(logits.shape) == expected_shape
    True


Part 2: Hands-On Practice

Next, to understand xxxForCausalLM and xxxForConditionalGeneration more deeply, let's put them to work. First install a few dependencies:

pip install transformers==4.28.1
pip install evaluate
pip install datasets

First, xxxForCausalLM. Download the data from here: https://raw.githubusercontent.com/SophonPlus/ChineseNlpCorpus/master/datasets/ChnSentiCorp_htl_all/ChnSentiCorp_htl_all.csv

Straight to the code:

import torch
from tqdm import tqdm
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import (
    default_data_collator,
    get_linear_schedule_with_warmup,
)
from torch.utils.data import DataLoader

data_file = "./ChnSentiCorp_htl_all.csv"  # path to the data file; download it in advance
max_length = 86
train_batch_size = 64
eval_batch_size = 64
num_epochs = 10
lr = 3e-4

# Load the dataset
dataset = load_dataset("csv", data_files=data_file)
dataset = dataset.filter(lambda x: x["review"] is not None)
dataset = dataset["train"].train_test_split(0.2, seed=123)

model_name_or_path = "uer/gpt2-chinese-cluecorpussmall"
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
model = AutoModelForCausalLM.from_pretrained(model_name_or_path)

# example = {'label': 1, 'review': '早餐太差,无论去多少人,那边也不加食品的。酒店应该重视一下这个问题了。房间本身很好。'}
def process(example):
    text = example["review"]
    batch_size = len(text)
    # no [CLS]/[SEP]: this checkpoint uses a Bert-style tokenizer
    inputs = tokenizer(text, add_special_tokens=False, truncation=True, max_length=max_length)
    inputs["labels"] = []
    for i in range(batch_size):
        input_ids = inputs["input_ids"][i]
        if len(input_ids) + 1 <= max_length:
            # append [PAD] as the end-of-text marker, then pad up to max_length
            inputs["input_ids"][i] = input_ids + [tokenizer.pad_token_id] + [0] * (max_length - len(input_ids) - 1)
            # padding positions get label -100 so they are ignored by the loss
            inputs["labels"].append(input_ids + [tokenizer.pad_token_id] + [-100] * (max_length - len(input_ids) - 1))
            inputs["attention_mask"][i] = [1] * len(input_ids) + [0] + [0] * (max_length - len(input_ids) - 1)
        else:
            inputs["input_ids"][i] = input_ids[:max_length - 1] + [tokenizer.pad_token_id]
            inputs["labels"].append(inputs["input_ids"][i])
            inputs["attention_mask"][i] = [1] * max_length
        inputs["token_type_ids"][i] = [0] * max_length
        # assert len(inputs["labels"][i]) == len(inputs["input_ids"][i]) == len(inputs["token_type_ids"][i]) == len(inputs["attention_mask"][i]) == 86
    return inputs

train_dataset = dataset["train"].map(process, batched=True, num_proc=1, remove_columns=dataset["train"].column_names)
test_dataset = dataset["test"].map(process, batched=True, num_proc=1, remove_columns=dataset["test"].column_names)

train_dataloader = DataLoader(
    train_dataset, collate_fn=default_data_collator, shuffle=True, batch_size=train_batch_size, pin_memory=True
)
test_dataloader = DataLoader(
    test_dataset, collate_fn=default_data_collator, batch_size=eval_batch_size, pin_memory=True
)

# optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
# lr scheduler
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=(len(train_dataloader) * num_epochs),
)

model.cuda()
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    t = tqdm(train_dataloader)
    for step, batch in enumerate(t):
        for k, v in batch.items():
            batch[k] = v.cuda()
        outputs = model(
            input_ids=batch["input_ids"],
            token_type_ids=batch["token_type_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["labels"],
        )
        loss = outputs.loss
        t.set_description("loss:{:.6f}".format(loss.item()))
        total_loss += loss.detach().float()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
    train_epoch_loss = total_loss / len(train_dataloader)
    model.save_pretrained("gpt2-chinese/")
    print(f"epoch:{epoch+1}/{num_epochs} loss:{train_epoch_loss}")

Training output:

loss:2.416899: 100%|██████████| 98/98 [01:51<00:00,  1.14s/it]
epoch:1/10 loss:2.7781832218170166
loss:2.174688: 100%|██████████| 98/98 [01:54<00:00,  1.17s/it]
epoch:2/10 loss:2.3192219734191895
loss:2.123909: 100%|██████████| 98/98 [01:55<00:00,  1.17s/it]
epoch:3/10 loss:2.037835121154785
loss:1.785878: 100%|██████████| 98/98 [01:55<00:00,  1.18s/it]
epoch:4/10 loss:1.7687807083129883
loss:1.466153: 100%|██████████| 98/98 [01:55<00:00,  1.18s/it]
epoch:5/10 loss:1.524872064590454
loss:1.465316: 100%|██████████| 98/98 [01:54<00:00,  1.17s/it]
epoch:6/10 loss:1.3074666261672974
loss:1.150320: 100%|██████████| 98/98 [01:54<00:00,  1.16s/it]
epoch:7/10 loss:1.1217808723449707
loss:1.043044: 100%|██████████| 98/98 [01:53<00:00,  1.16s/it]
epoch:8/10 loss:0.9760875105857849
loss:0.790678: 100%|██████████| 98/98 [01:53<00:00,  1.16s/it]
epoch:9/10 loss:0.8597695827484131
loss:0.879025: 100%|██████████| 98/98 [01:53<00:00,  1.16s/it]
epoch:10/10 loss:0.790839433670044

Prediction can then be done like this:

import random
from transformers import AutoTokenizer, TextGenerationPipeline, AutoModelForCausalLM
from datasets import load_dataset

data_file = "./ChnSentiCorp_htl_all.csv"  # path to the data file; download it in advance
dataset = load_dataset("csv", data_files=data_file)
dataset = dataset.filter(lambda x: x["review"] is not None)
dataset = dataset["train"].train_test_split(0.2, seed=123)

model_name_or_path = "uer/gpt2-chinese-cluecorpussmall"
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
model = AutoModelForCausalLM.from_pretrained("./gpt2-chinese/")  # the checkpoint we just fine-tuned
text_generator = TextGenerationPipeline(model, tokenizer)

examples = dataset["train"]
example = random.choice(examples)
text = example["review"]
print(text)
print(text[:10])
text_generator(text[:10],
               max_length=100,
               do_sample=False,
               top_p=0.8,
               repetition_penalty=10.0,
               temperature=0.95,
               eos_token_id=0,
               )
"""
第一次住在这里儿,我针对大家的意见,特别关注了一下,感觉如下吧!1、标准间虽然有点旧但很干净,被子盖得很舒服,也很暖和,卫生间也蛮大的,因是在商业中心离很多还算很近。2、酒店服务还算可以,没有像这里说的那样,入住时,退房时也挺快的,总的来说我很满意。3、早餐也还可以,环境也不错,有点江南的感觉;菜品种品也不少,挺可口。4、可能是在市或者离火车站的距离很近,稍微有点“热闹”,来找我办事的人不方便停车,但还好这里有地下停车场。总体来说,我感觉很不错,值得推荐!!!
第一次住在这里儿,我
[{'generated_text': '第一次住在这里儿,我 感 觉 很 温 馨 。 房 间 宽 敞 、 干 净 还 有 水 果 送 ( 每 人 10 元 ) ; 饭 菜 也 不 错 ! 价 格 合 理 经 济 实 惠 .'}]
"""

A few things to note:

  • Different models use different tokenizers, and you need to mind the differences between them, especially pad_token_id and eos_token_id. eos_token_id is commonly used to mark the end of the generated text.

  • Some Chinese generative pretrained models still use BERT's tokenizer; when tokenizing, pass add_special_tokens=False to avoid adding [CLS] and [SEP].
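
    A tiny sketch of the difference (using the uer/gpt2-chinese-cluecorpussmall tokenizer from the script above):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("uer/gpt2-chinese-cluecorpussmall")
    print(tok("你好")["input_ids"])                            # starts with 101 ([CLS]) and ends with 102 ([SEP])
    print(tok("你好", add_special_tokens=False)["input_ids"])  # the special tokens are dropped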

  • BertTokenizer's eos_token_id is None, so here we treat [PAD] (index 0) as the end-of-generation symbol. You could of course use another special token from the vocabulary instead, such as [SEP].
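
    For example (same tokenizer as above):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("uer/gpt2-chinese-cluecorpussmall")
    print(tok.eos_token_id)  # None: this Bert-style tokenizer defines no EOS token
    print(tok.pad_token_id)  # 0: [PAD], which is why we pass eos_token_id=0 when generating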

  • For tokens that should not contribute to the loss, we set the label to -100.
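
    The sketch below shows why -100 works: it is the default ignore_index of PyTorch's CrossEntropyLoss, so those positions contribute nothing to the loss:

    import torch
    from torch.nn import CrossEntropyLoss

    logits = torch.randn(4, 10)                # 4 positions, vocabulary of size 10
    labels = torch.tensor([3, 7, -100, -100])  # the last two positions are padding
    loss = CrossEntropyLoss()(logits, labels)  # averaged over the two real positions only
    print(loss)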

  • Why are our labels the same as input_ids? Isn't the model supposed to generate the next token from the previous one? They can be identical because the model shifts them internally, see the code:

    shift_logits = lm_logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

  • There are three ways to run prediction, and many parameters are available to control the diversity of the generated text; see the link at the end if you are interested.

Next, xxxForConditionalGeneration. Download the data from here: https://www.modelscope.cn/datasets/minisnow/couplet_samll.git

Straight to the code:

import pandas as pd
import torch
from tqdm import tqdm
from datasets import Dataset
from transformers import BertTokenizer, BartForConditionalGeneration
from transformers import (
    default_data_collator,
    get_linear_schedule_with_warmup,
)
from torch.utils.data import DataLoader

# =============================
# Load the data
train_path = "couplet_samll/train.csv"
train_dataset = Dataset.from_csv(train_path)
test_path = "couplet_samll/test.csv"
test_dataset = Dataset.from_csv(test_path)

max_len = 24
train_batch_size = 64
eval_batch_size = 64
lr = 3e-4
num_epochs = 1

# Convert to the format the model needs
def tokenize_dataset(tokenizer, dataset, max_len):
    def convert_to_features(batch):
        text1 = batch["text1"]  # first line of the couplet (encoder side)
        text2 = batch["text2"]  # second line of the couplet (decoder side)
        inputs = tokenizer.batch_encode_plus(
            text1,
            max_length=max_len,
            padding="max_length",
            truncation=True,
        )
        targets = tokenizer.batch_encode_plus(
            text2,
            max_length=max_len,
            padding="max_length",
            truncation=True,
        )
        outputs = {
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"],
            "target_ids": targets["input_ids"],
            "target_attention_mask": targets["attention_mask"],
        }
        return outputs

    dataset = dataset.map(convert_to_features, batched=True)
    # Set the tensor type and the columns which the dataset should return
    columns = ['input_ids', 'target_ids', 'attention_mask', 'target_attention_mask']
    # note: with_format returns a new dataset, so as written this line is a no-op;
    # default_data_collator will tensorize the list columns regardless
    dataset.with_format(type='torch', columns=columns)
    dataset = dataset.rename_column('target_ids', 'labels')
    dataset = dataset.rename_column('target_attention_mask', 'decoder_attention_mask')
    dataset = dataset.remove_columns(['text1', 'text2'])
    return dataset

tokenizer = BertTokenizer.from_pretrained("fnlp/bart-base-chinese")
train_data = tokenize_dataset(tokenizer, train_dataset, max_len)
test_data = tokenize_dataset(tokenizer, test_dataset, max_len)

train_dataset = train_data
train_dataloader = DataLoader(
    train_dataset, collate_fn=default_data_collator, shuffle=True, batch_size=train_batch_size, pin_memory=True
)
test_dataset = test_data
test_dataloader = DataLoader(
    test_dataset, collate_fn=default_data_collator, batch_size=eval_batch_size, pin_memory=True
)

model = BartForConditionalGeneration.from_pretrained("fnlp/bart-base-chinese")
# optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
# lr scheduler
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=(len(train_dataloader) * num_epochs),
)

model.cuda()
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    t = tqdm(train_dataloader)
    for step, batch in enumerate(t):
        for k, v in batch.items():
            batch[k] = v.cuda()
        outputs = model(**batch)
        loss = outputs.loss
        t.set_description("loss:{:.6f}".format(loss.item()))
        total_loss += loss.detach().float()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
    train_epoch_loss = total_loss / len(train_dataloader)
    model.save_pretrained("bart-couplet/")
    tokenizer.save_pretrained("bart-couplet/")
    print(f"epoch:{epoch+1}/{num_epochs} loss:{train_epoch_loss}")

The result:

loss:1.593506: 100%|██████████| 4595/4595 [33:28<00:00,  2.29it/s]
epoch:1/1 loss:1.76453697681427

We can predict like this:

from transformers import Text2TextGenerationPipeline

# continues from the training script above (pd, BertTokenizer, BartForConditionalGeneration are already imported)
model_path = "bart-couplet"
# model_path = "fnlp/bart-base-chinese"
model = BartForConditionalGeneration.from_pretrained(model_path)
tokenizer = BertTokenizer.from_pretrained(model_path)
generator = Text2TextGenerationPipeline(model=model, tokenizer=tokenizer)

max_len = 24
test_path = "couplet_samll/test.csv"
test_data = pd.read_csv(test_path)
texts = test_data["text1"].values.tolist()
labels = test_data["text2"].values.tolist()

results = generator(texts, max_length=max_len, eos_token_id=102, pad_token_id=0, do_sample=True)
for text, label, res in zip(texts, labels, results):
    print("上联:", text)
    print("真实下联:", label)
    print("预测下联:", "".join(res["generated_text"].split(" ")))
    print("=" * 100)
"""
上联: 几帧山水关秋路
真实下联: 无奈胭脂点绛唇
预测下联: 天高云淡月光明
====================================================================================================
上联: 许多心事懒收拾
真实下联: 大好青春莫撂荒
预测下联: 何妨明月照寒窗
====================================================================================================
上联: 谁同执手人间老
真实下联: 自愿并肩化外游
预测下联: 心中有梦月当头
====================================================================================================
上联: 画地为牢封自步
真实下联: 齐天大圣悟空行
预测下联: 不妨一世好清闲
====================================================================================================
上联: 布谷携春临五岳
真实下联: 流莺送喜到千家
预测下联: 万家灯火庆丰年
====================================================================================================
上联: 冤家宜解不宜结
真实下联: 穷寇定歼必定追
预测下联: 不因风雨误春秋
====================================================================================================
上联: 汪伦情义人间少
真实下联: 法律条文格外繁
预测下联: 一江春水向东流
====================================================================================================
上联: 泼墨吟诗,银发人生添雅兴
真实下联: 手机短信,古稀老叟逐新潮
预测下联: 春风得意,万里千帆逐浪高
====================================================================================================
上联: 刊岫展屏山,云凝罨画
真实下联: 平湖环镜槛,波漾空明
预测下联: 千年古邑,百花芳草淹春
====================================================================================================
上联: 且向人间赊一醉
真实下联: 直如岛外泛孤舟
预测下联: 春风得意乐逍遥
====================================================================================================
"""

Things to note:

  • Here the input is no longer a single text but a text pair.

  • We need to build the encoder input_ids, encoder attention_mask, decoder input_ids, and decoder attention_mask.

  • One trick is used here: sampling-based generation, with do_sample=True. If you try setting it to False, you will find the generated results may not be as good.

  • Again the tokenizer is BERT's, and this time we keep BERT's [CLS] and [SEP] during tokenization. To see what is going on more directly, let's generate the results in a more explicit way:

    model = BartForConditionalGeneration.from_pretrained(model_path)
    model = model.to("cuda")
    model.eval()
    inputs = tokenizer(
        texts,
        padding="max_length",
        truncation=True,
        max_length=max_len,
        return_tensors="pt",
    )
    input_ids = inputs.input_ids.to(model.device)
    attention_mask = inputs.attention_mask.to(model.device)
    # generate
    outputs = model.generate(input_ids,
                             attention_mask=attention_mask,
                             max_length=max_len,
                             do_sample=True,
                             pad_token_id=0,
                             eos_token_id=102)
    # convert the tokens back to text
    output_str = tokenizer.batch_decode(outputs, skip_special_tokens=False)
    output_str = [s.replace(" ", "") for s in output_str]
    for text, label, pred in zip(texts, labels, output_str):
        print("上联:", text)
        print("真实下联:", label)
        print("预测下联:", pred)
        print("=" * 100)

The results:

上联: 几帧山水关秋路
真实下联: 无奈胭脂点绛唇
预测下联: [SEP][CLS]春风送暖柳含烟[SEP][PAD][PAD][PAD][PAD][PAD]
====================================================================================================
上联: 许多心事懒收拾
真实下联: 大好青春莫撂荒
预测下联: [SEP][CLS]无私奉献为人民[SEP][PAD][PAD][PAD][PAD][PAD]
====================================================================================================
上联: 谁同执手人间老
真实下联: 自愿并肩化外游
预测下联: [SEP][CLS]清风明月是知音[SEP][PAD][PAD][PAD][PAD][PAD]
====================================================================================================
上联: 画地为牢封自步
真实下联: 齐天大圣悟空行
预测下联: [SEP][CLS]月明何处不相逢[SEP][PAD][PAD][PAD][PAD][PAD]
====================================================================================================
上联: 布谷携春临五岳
真实下联: 流莺送喜到千家
预测下联: [SEP][CLS]一壶老酒醉春风[SEP][PAD][PAD][PAD][PAD][PAD]
====================================================================================================
上联: 冤家宜解不宜结
真实下联: 穷寇定歼必定追
预测下联: [SEP][CLS]风流人物不虚名[SEP][PAD][PAD][PAD][PAD][PAD]
====================================================================================================
上联: 汪伦情义人间少
真实下联: 法律条文格外繁
预测下联: [SEP][CLS]万里江山万里春[SEP][PAD][PAD][PAD][PAD][PAD]
====================================================================================================
上联: 泼墨吟诗,银发人生添雅兴
真实下联: 手机短信,古稀老叟逐新潮
预测下联: [SEP][CLS]和谐社会,和谐和谐幸福家[SEP]
====================================================================================================
上联: 刊岫展屏山,云凝罨画
真实下联: 平湖环镜槛,波漾空明
预测下联: [SEP][CLS]天下无双,人寿年丰[SEP][PAD][PAD][PAD]
====================================================================================================
上联: 且向人间赊一醉
真实下联: 直如岛外泛孤舟
预测下联: [SEP][CLS]不知何处有闲人[SEP][PAD][PAD][PAD][PAD][PAD]
====================================================================================================
  • We set skip_special_tokens=False so the special tokens are not dropped when decoding the generated ids.

  • Take "无奈胭脂点绛唇" as an example: fed [SEP], the decoder predicts [CLS]; fed [SEP][CLS], it produces the actual text, and it finally ends with [SEP]. This is because both our encoder_input_ids and decoder_input_ids include the special tokens; you could of course leave them out or use other special tokens of your own.
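
    The sketch below mirrors transformers' shift_tokens_right from modeling_bart, which is how the decoder inputs are built from the labels. Assuming [SEP] (id 102) is the decoder start token here, it explains the [SEP][CLS] prefix seen above; the token ids are illustrative:

    import torch

    def shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id):
        # shift the labels one position to the right and prepend the decoder start token
        shifted = input_ids.new_zeros(input_ids.shape)
        shifted[:, 1:] = input_ids[:, :-1].clone()
        shifted[:, 0] = decoder_start_token_id
        shifted.masked_fill_(shifted == -100, pad_token_id)  # -100 placeholders become real padding
        return shifted

    labels = torch.tensor([[101, 872, 1962, 102]])  # [CLS] 你 好 [SEP]
    print(shift_tokens_right(labels, pad_token_id=0, decoder_start_token_id=102))
    # tensor([[ 102,  101,  872, 1962]]): the decoder sees [SEP][CLS]... one step ahead of the labels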


At this point you have seen the models that ship with the transformers library and the tasks built around them, and in particular you should have a deeper understanding of the generative ones. Go give them a try!

Finally, some related material:

https://zhuanlan.zhihu.com/p/624845975

Part 3: References

transformers.models.auto.modeling_auto — transformers 4.4.2 documentation (huggingface.co)