• 2024年4月26日 周五

5G编程聚合网

5G时代下一个聚合的编程学习网

热门标签

文本分类(五):transformers库BERT实战,基于BertForSequenceClassification

admin

2021年11月28日

一、代码一

import pandas as pd
import codecs
from config.root_path import root
import os
from utils.data_process import get_label,text_preprocess
import json
from transformers import BertTokenizer
from torch.utils.data import Dataset, DataLoader, TensorDataset
import torch
import re
import numpy as np
from transformers import BertForSequenceClassification, AdamW, get_linear_schedule_with_warmup
import torch.nn as nn


class NewsDataset(Dataset):
    """Map-style dataset pairing tokenizer encodings with integer labels.

    ``encodings`` maps a field name (e.g. ``input_ids``) to a per-sample
    sequence; ``labels`` holds one label per sample, converted with ``int``.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict of tensors plus a ``labels`` tensor."""
        sample = dict()
        for field, column in self.encodings.items():
            sample[field] = torch.tensor(column[idx])
        label_value = int(self.labels[idx])
        sample['labels'] = torch.tensor(label_value)
        return sample

    def __len__(self):
        """Number of samples, taken from the label sequence."""
        return len(self.labels)

# 精度计算
def flat_accuracy(preds, labels):
    pred_flat = np.argmax(preds, axis=1).flatten()
    labels_flat = labels.flatten()
    return np.sum(pred_flat == labels_flat) / len(labels_flat)

class EarlyStopper(object):
    """Track the best accuracy, checkpoint on improvement, and signal when to stop.

    Args:
        num_trials: number of consecutive non-improving evaluations tolerated;
            the ``num_trials``-th one makes ``is_continuable`` return False.
        save_path: file path passed to ``torch.save`` for the best weights.
    """

    def __init__(self, num_trials, save_path):
        self.num_trials = num_trials
        self.trial_counter = 0
        self.best_accuracy = 0
        self.save_path = save_path

    def is_continuable(self, model, accuracy):
        """Record ``accuracy`` and return whether training should continue.

        On a strict improvement over ``best_accuracy`` the model's
        ``state_dict`` is saved to ``save_path`` and the patience counter is
        reset. Otherwise the counter is incremented until patience runs out.
        """
        if accuracy > self.best_accuracy:
            self.best_accuracy = accuracy
            self.trial_counter = 0
            # Format the value into the placeholder; passing it as a second
            # print() argument would print "{}" literally.
            print("保存模型,指标:{}".format(accuracy))
            torch.save(model.state_dict(), self.save_path)
            return True
        if self.trial_counter + 1 < self.num_trials:
            self.trial_counter += 1
            return True
        return False

class run_bert():
    """Fine-tune ``BertForSequenceClassification`` on tab-separated text files.

    Reads ``<root>/data/{train,val,test}.txt`` (one ``text<TAB>code`` pair per
    line) and ``<root>/code_to_label.json`` (created via ``get_label()`` when
    missing). It loads the pretrained model/tokenizer from
    ``<root>/chkpt/bert-base-chinese``.
    """

    def __init__(self):
        data_path = os.path.join(root, "data")
        self.train_path = os.path.join(data_path, "train.txt")
        self.val_path = os.path.join(data_path, "val.txt")
        self.test_path = os.path.join(data_path, "test.txt")
        code_label_path = os.path.join(root, "code_to_label.json")
        if not os.path.exists(code_label_path):
            get_label()
        with open(code_label_path, "r", encoding="utf8") as f:
            self.code_label = json.load(f)
        self.model_name = os.path.join(root, "chkpt", "bert-base-chinese")
        self.num_label = len(self.code_label)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.batch_size = 16
        # Training hyper-parameters used by main()/s_encoding().
        self.max_length = 40
        self.learning_rate = 2e-5
        self.num_epochs = 100
        self.patience = 5
        self.save_path = os.path.join(root, "chkpt", "bert_classification.pt")
        # Tokenizer is loaded lazily on first use and then reused.
        self._tokenizer = None

    def read_file(self, path):
        """Read ``text<TAB>code`` lines from ``path``.

        Returns:
            (sentences, labels): each text passed through ``text_preprocess``,
            and element ``[2]`` of ``self.code_label[code]`` for each line.
            Blank lines are skipped.
        """
        sentences = []
        labels = []
        with open(path, "r", encoding="utf8") as f:
            for raw_line in f:
                fields = raw_line.strip().split("\t")
                if fields == [""]:
                    # Blank line (e.g. trailing newline at EOF): there is no
                    # fields[1], so skip it instead of raising IndexError.
                    continue
                sentences.append(text_preprocess(fields[0]))
                labels.append(self.code_label[fields[1]][2])
        return sentences, labels

    def get_datas(self):
        """Return (sentences, labels) for train, val and test, flattened into a 6-tuple."""
        train_s, train_l = self.read_file(self.train_path)
        val_s, val_l = self.read_file(self.val_path)
        test_s, test_l = self.read_file(self.test_path)
        return train_s, train_l, val_s, val_l, test_s, test_l

    def s_encoding(self, s):
        """Tokenize sentences ``s``, truncating to ``self.max_length`` and
        padding to the longest sequence in ``s``."""
        if self._tokenizer is None:
            # Loading from disk is slow; do it once per instance, not per call.
            self._tokenizer = BertTokenizer.from_pretrained(self.model_name)
        return self._tokenizer(s, truncation=True, padding=True, max_length=self.max_length)

    def train(self, model, train_loader, optim, device, scheduler, epoch, loss_fn):
        """Run one training epoch, stepping the optimizer and scheduler per batch."""
        model.train()
        total_train_loss = 0
        total_iter = len(train_loader)
        for iter_num, batch in enumerate(train_loader, start=1):
            optim.zero_grad()
            # Forward pass
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            logits = outputs[1]
            loss = loss_fn(logits, labels)
            total_train_loss += loss.item()

            # Backward pass with gradient clipping, then parameter and LR update
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optim.step()
            scheduler.step()

            if iter_num % 10 == 0:
                print("epoch: %d, iter_num: %d, loss: %.4f, %.2f%%" % (
                    epoch, iter_num, loss.item(), iter_num / total_iter * 100))

        print("Epoch: %d, Average training loss: %.4f" % (epoch, total_train_loss / max(total_iter, 1)))

    def validation(self, model, val_dataloader, device):
        """Evaluate ``model`` and return its accuracy over all samples.

        Accuracy and loss are weighted by batch size, so a smaller final batch
        does not skew the averages.

        Raises:
            ValueError: if ``val_dataloader`` yields no samples.
        """
        model.eval()
        n_correct = 0
        n_seen = 0
        total_eval_loss = 0.0
        with torch.no_grad():
            for batch in val_dataloader:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)
                outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
                loss = outputs[0]
                logits = outputs[1]
                batch_len = labels.size(0)
                # The model's loss is a per-batch mean; re-weight it by batch size.
                total_eval_loss += loss.item() * batch_len
                n_correct += (logits.argmax(dim=1) == labels).sum().item()
                n_seen += batch_len
        if n_seen == 0:
            raise ValueError("validation dataloader is empty")
        avg_val_accuracy = n_correct / n_seen
        print("Accuracy: %.4f" % (avg_val_accuracy))
        print("Average testing loss: %.4f" % (total_eval_loss / n_seen))
        print("-------------------------------")
        return avg_val_accuracy

    def main(self):
        """Train with early stopping, then evaluate the best checkpoint on the test set."""
        train_s, train_l, val_s, val_l, test_s, test_l = self.get_datas()
        train_dataset = NewsDataset(self.s_encoding(train_s), train_l)
        val_dataset = NewsDataset(self.s_encoding(val_s), val_l)

        model = BertForSequenceClassification.from_pretrained(
            self.model_name, num_labels=self.num_label)
        model.to(self.device)
        train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
        # Metrics do not depend on evaluation order; keep it deterministic.
        val_dataloader = DataLoader(val_dataset, batch_size=self.batch_size, shuffle=False)
        optim = AdamW(model.parameters(), lr=self.learning_rate)
        loss_fn = nn.CrossEntropyLoss()
        # The linear schedule must span every epoch that may run. Sizing it to
        # one epoch would drive the LR to 0 after the first epoch, leaving all
        # later epochs with no parameter updates.
        total_steps = len(train_loader) * self.num_epochs
        scheduler = get_linear_schedule_with_warmup(optim,
                                                    num_warmup_steps=0,  # Default value in run_glue.py
                                                    num_training_steps=total_steps)
        early_stopper = EarlyStopper(num_trials=self.patience, save_path=self.save_path)
        for epoch in range(self.num_epochs):
            print("------------Epoch: %d ----------------" % epoch)
            self.train(model, train_loader, optim, self.device, scheduler, epoch, loss_fn)
            acc = self.validation(model, val_dataloader, self.device)
            if not early_stopper.is_continuable(model, acc):
                print(f'validation: best accuracy: {early_stopper.best_accuracy}')
                break

        # Test the best checkpoint, not the weights from the final
        # non-improving epochs. best_accuracy > 0 implies a save happened.
        if early_stopper.best_accuracy > 0:
            model.load_state_dict(torch.load(self.save_path, map_location=self.device))
        test_dataset = NewsDataset(self.s_encoding(test_s), test_l)
        test_loader = DataLoader(test_dataset, batch_size=self.batch_size, shuffle=False)
        acc = self.validation(model, test_loader, self.device)
        print(f'test acc: {acc}')

if __name__ == '__main__':
    # Script entry point: build the runner, then train and evaluate.
    runner = run_bert()
    runner.main()

二、分类效果

模型准确率82%,效果不好。

 

发表回复

您的电子邮箱地址不会被公开。 必填项已用*标注