领域命名实体NER实现:Bert+BiLSTM+CRF

发布时间:2022-09-27 10:30

领域命名实体NER实现:Bert+BiLSTM+CRF_第1张图片

以前通过模板规则的方式进行命名实体的提取,优点是提取速度非常高,但模板规则存在冲突的情况,尝试过使用百度LAC通过词性模板规则进行命名实体的提取,好处是少量规则可以覆盖大部分情况,但也存在规则冲突的情况。本文尝试采用Bert+BiLSTM+CRF的方式进行命名实体的提取。使用Bert的好处是能够学习到语料的语义特征,BiLSTM能学习到词之间较长的上下文关系,CRF能纠正BiLSTM预测的顺序错误。Bert的好处是准确率非常高,缺点也很明显,推理速度低,可以通过部署的方式来提升推理性能,如:使用ONNX 运行环境。

主要步骤如下:

1)准备标注语料(自行准备了224个标注),生成和人民日报语料一样的格式(语料生成代码来自互联网),可以自定义领域命名实体;

#生成的训练语料,一个字一行,格式同人民日报语料
import re

# txt2ner_train_data turn label str into ner trainable data
# s :labeled str  eg.'我来到[@1999年#YEAR*]的[@上海#LOC*]的[@东华大学#SCHOOL*]'
# save_path: ner_trainable_txt name
def str2ner_train_data(s, save_path):
    ner_data = []
    result_1 = re.finditer(r'\[\@', s)
    result_2 = re.finditer(r'\*\]', s)
    begin = []
    end = []
    for each in result_1:
        begin.append(each.start())
    for each in result_2:
        end.append(each.end())
    print(len(begin) ,len(end))
    assert len(begin) == len(end)
    i = 0
    j = 0
    while i < len(s):
        if i not in begin:
            ner_data.append([s[i], 'O'])
            i = i + 1
        else:
            ann = s[i + 2:end[j] - 2]
            entity, ner = ann.rsplit('#')
            if (len(entity) == 1):
                ner_data.append([entity, 'B-' + ner])
                # ner_data.append([entity, 'S-' + ner])
            else:
                if (len(entity) == 2):
                    ner_data.append([entity[0], 'B-' + ner])
                    ner_data.append([entity[1], 'I-' + ner])
                    # ner_data.append([entity[1], 'E-' + ner])
                else:
                    ner_data.append([entity[0], 'B-' + ner])
                    for n in range(1, len(entity)):
                        ner_data.append([entity[n], 'I-' + ner])
                    # ner_data.append([entity[-1], 'E-' + ner])

            i = end[j]
            j = j + 1

    f = open(save_path, 'a', encoding='utf-8')
    for each in ner_data:
        f.write(each[0] + ' ' + str(each[1]))
        if each[0] == '。' or each[0] == '?' or each[0] == '!':
            f.write('\n')
            f.write('\n')
        else:
            f.write('\n')
    f.close()


# txt2ner_train_data turn label str into ner trainable data
# file_path :labeled multi lines' txt  eg.'我来到[@1999年#YEAR*]的[@上海#LOC*]的[@东华大学#SCHOOL*]'
# save_path: ner_trainable_txt name
def txt2ner_train_data(file_path, save_path):
    fr = open(file_path, 'r', encoding='utf-8')
    lines = fr.readlines()
    s = ''
    for line in lines:
        line = line.replace('\n', '')
        line = line.replace(' ', '')
        s = s + line
    fr.close()
    str2ner_train_data(s, save_path)


if (__name__ == '__main__'):
    train_path = './train.txt' #生成的训练语料,一个字一行,格式同人民日报语料
    corpus_path = './middle_corpus.txt'#根据领域特征标注语料,可以自定义NER标签,不限于PER(人名),LOC(地名),ORG(机构名)
    txt2ner_train_data(corpus_path, train_path)
# 读取自己的预料’
train_path = './train.txt'
test_path = './test.txt'

def get_sequenct_tagging_data(file_path):
    data_x, data_y = [], []

    with open(file_path, 'r', encoding='utf-8') as f:
        lines = f.read().splitlines()

        x, y = [], []
        for line in lines:
            rows = line.split(' ')
            if len(rows) == 1:
                data_x.append(x)
                data_y.append(y)
                x = []
                y = []
            else:
                x.append(rows[0])
                y.append(rows[1])
    return data_x, data_y

train_x, train_y = get_sequenct_tagging_data(train_path)
validate_x, validate_y = get_sequenct_tagging_data(test_path)

2)使用kashgari2.0.1用于快速使用模型进行训练,包括使用Bert作为特征提取,使用中文预训练模型chinese_L-12_H-768_A-12(需要自行下载到本地);

领域命名实体NER实现:Bert+BiLSTM+CRF_第2张图片领域命名实体NER实现:Bert+BiLSTM+CRF_第3张图片

3)模型的保存与装载;

领域命名实体NER实现:Bert+BiLSTM+CRF_第4张图片

4)使用模型进行推理,推理效果相当不错,比百度LAC的效果好。

领域命名实体NER实现:Bert+BiLSTM+CRF_第5张图片

ItVuer - 免责声明 - 关于我们 - 联系我们

本网站信息来源于互联网,如有侵权请联系:561261067@qq.com

桂ICP备16001015号