UniLM - 用于摘要的统一语言模型

数据挖掘 nlp 自动总结
2022-02-27 13:47:56

UniLM 声称是总结任务的最佳方法。但是 README.md 或任何其他博客中似乎没有任何教程或操作方法部分。我究竟如何使用这个最先进的库来生成抽象摘要?

Github 链接

PS NLP 的新手。对不起,如果这是一个愚蠢的问题。

1个回答

这是你应该做的

  1. 准备你的数据集:按照论文中描述的类似说明,预处理你的数据集。这将是您的主要任务,因为在此之后您只需要微调模型。如果没有数据集,可以使用本研究论文中使用的数据集,可以从这里下载。

  2. 下载预训练模型或者您可以选择从提供的微调模型检查点开始(来自链接)。您必须检查哪个版本的模型最适合您的数据集。如果您为摘要任务选择微调模型,并且您的数据集类似于CNN/DailyMail 数据集 [37]Gigaword [36],则可以跳过微调。

  3. 微调模型:在这一步中,您将使用 Github 存储库的自述文件中提到的命令。请注意,有些参数应该根据您在上一步中下载的语言模型。根据数据集的大小,您可以在以下命令中更改时期数。您还应该注意,这将需要 GPU。存储库自述文件建议使用 2 或 4 个 v100-32G GPU 卡来微调模型。

    OUTPUT_DIR=/{path_of_fine-tuned_model}/
    MODEL_RECOVER_PATH=/{path_of_pre-trained_model}/unilmv1-large-cased.bin
    export PYTORCH_PRETRAINED_BERT_CACHE=/{tmp_folder}/bert-cased-pretrained-cache
    export CUDA_VISIBLE_DEVICES=0,1,2,3
    python biunilm/run_seq2seq.py --do_train --fp16 --amp --num_workers 0 \
      --bert_model bert-large-cased --new_segment_ids --tokenized_input \
      --data_dir ${DATA_DIR} \
      --output_dir ${OUTPUT_DIR}/bert_save \
      --log_dir ${OUTPUT_DIR}/bert_log \
      --model_recover_path ${MODEL_RECOVER_PATH} \
      --max_seq_length 192 --max_position_embeddings 192 \
      --trunc_seg a --always_truncate_tail --max_len_a 0 --max_len_b 64 \
      --mask_prob 0.7 --max_pred 48 \
      --train_batch_size 128 --gradient_accumulation_steps 1 \
      --learning_rate 0.00003 --warmup_proportion 0.1 --label_smoothing 0.1 \
      --num_train_epochs 30
  1. 评估您的模型biunilm/decode_seq2seq.py用于解码(预测评估数据集的输出)并使用提供的评估脚本来评估经过训练的模型。

  2. 使用训练好的模型:为了使用这个模型进行预测,您可以简单地编写自己的 python 代码:

    • 使用文件中使用pytorch_pretrained_bert的库加载 Pytorch 预训练模型decode_seq2seq.py
    • 标记您的输入
    • 预测输出并对输出进行去标记化。

这是您可以使用的逻辑:

model = BertForSeq2SeqDecoder.from_pretrained(long_list_of_arguments)
batch = seq2seq_loader.batch_list_to_batch_tensors(input_batch)
input_ids, token_type_ids, position_ids, input_mask, mask_qkv, task_idx = batch
traces = model(input_ids, token_type_ids,position_ids, input_mask, task_idx=task_idx, mask_qkv=mask_qkv)

请注意,这不是完整的逻辑。此代码仅显示 github 存储库代码如何处理保存的模型并使用它进行预测。用于traces将 id 转换为令牌并对输出令牌进行去令牌化(如此处代码中使用的那样去标记化步骤是必要的,因为输入序列被 WordPiece 标记为子字单元。

这里的参考是加载预训练模型的代码。您可以通过循环来了解逻辑并在您的情况下实现它。我希望这有帮助。