
paddle13 - ERNIE 2.0 Source Code Analysis 2: tokenizing_ernie

Tokenizing-related code

#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import division
from __future__ import absolute_import
from __future__ import print_function
from __future__ import unicode_literals

import sys
import os
import six
import re
import logging
import tempfile
from functools import partial
if six.PY2:
    from pathlib2 import Path
else:
    from pathlib import Path

from tqdm import tqdm
import numpy as np

from ernie.file_utils import _fetch_from_remote
import io

open = partial(io.open, encoding='utf8')

log = logging.getLogger(__name__)

_max_input_chars_per_word = 100
# helper used by ErnieTokenizer.tokenize()
def _wordpiece(token, vocab, unk_token, prefix='##', sentencepiece_prefix=''):
    """ wordpiece: helloworld => [hello, ##world] """
    '''
    |  第一轮 |
    helloworld
    |   |
    world

    return : 
        1. sub_tokens: 切分次结果列表
        2. sub_pos: 对应得起始位置列表
    '''
    chars = list(token)
    if len(chars) > _max_input_chars_per_word:
        return [unk_token], [(0, len(chars))]

    is_bad = False
    start = 0
    sub_tokens = []
    sub_pos = []
    while start < len(chars):
        end = len(chars)
        cur_substr = None

        # greedy longest match against the vocab: keep `start` fixed and
        # shrink `end` from the right until a vocab entry is found
        while start < end:
            substr = "".join(chars[start:end])
            if start == 0:
                substr = sentencepiece_prefix + substr
            if start > 0:
                substr = prefix + substr
            if substr in vocab:
                cur_substr = substr
                break
            end -= 1

        # no substring starting at `start` is in the vocab: give up on the
        # whole token and fall back to [UNK]
        if cur_substr is None:
            is_bad = True
            break
        
        # record the matched sub-token and its (start, end) span
        sub_tokens.append(cur_substr)
        sub_pos.append((start, end))
        # continue matching from the end of this match (left to right)
        start = end
    if is_bad:
        return [unk_token], [(0, len(chars))]
    else:
        return sub_tokens, sub_pos
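
# A rough sketch of what _wordpiece returns, assuming a toy vocab (the real
# vocab comes from the pretrained vocab.txt):
#
#   toy_vocab = {'hello': 0, '##world': 1, '[UNK]': 2}
#   _wordpiece('helloworld', toy_vocab, unk_token='[UNK]')
#   # -> (['hello', '##world'], [(0, 5), (5, 10)])
#   _wordpiece('foobar', toy_vocab, unk_token='[UNK]')
#   # -> (['[UNK]'], [(0, 6)])    # no prefix of 'foobar' is in the vocab
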
class ErnieTokenizer(object):
    bce = 'https://ernie-github.cdn.bcebos.com/'
    resource_map = {
        'ernie-1.0': bce + 'model-ernie1.0.1.tar.gz',
        'ernie-2.0-en': bce + 'model-ernie2.0-en.1.tar.gz',
        'ernie-2.0-large-en': bce + 'model-ernie2.0-large-en.1.tar.gz',
        'ernie-tiny': bce + 'model-ernie_tiny.1.tar.gz',
        'ernie-gen-base-en': bce + 'model-ernie-gen-base-en.1.tar.gz',
        'ernie-gen-large-en': bce + 'model-ernie-gen-large-en.1.tar.gz',
    }

    # load the pretrained vocab.txt into a {token: id} dict
    @classmethod
    def from_pretrained(cls,
                        pretrain_dir_or_url,
                        force_download=False,
                        **kwargs):
        if not Path(pretrain_dir_or_url).exists() and str(
                pretrain_dir_or_url) in cls.resource_map:
            url = cls.resource_map[str(pretrain_dir_or_url)]
            log.info('get pretrain dir from %s' % url)
            pretrain_dir = _fetch_from_remote(
                url, force_download=force_download)
        else:
            log.info('pretrain dir %s not in %s, read from local' %
                     (pretrain_dir_or_url, repr(cls.resource_map)))
            pretrain_dir = Path(pretrain_dir_or_url)
        if not pretrain_dir.exists():
            raise ValueError('pretrain dir not found: %s' % pretrain_dir)
        vocab_path = pretrain_dir / 'vocab.txt'
        if not vocab_path.exists():
            raise ValueError('no vocab file in pretrain dir: %s' %
                             pretrain_dir)
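        # vocab.txt is assumed to hold one token per line (its id is the line
        # number); a line may carry tab-separated extras, so only the first
        # tab-separated field is kept below.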
        vocab_dict = {
            j.strip().split('\t')[0]: i
            for i, j in enumerate(
                vocab_path.open(encoding='utf8').readlines())
        }
        t = cls(vocab_dict, **kwargs)
        return t

    def __init__(self,
                 vocab,                 # dict
                 unk_token='[UNK]',
                 sep_token='[SEP]',
                 cls_token='[CLS]',
                 pad_token='[PAD]',
                 mask_token='[MASK]',
                 wordpiece_prefix='##',
                 sentencepiece_prefix='',
                 lower=True,
                 encoding='utf8',
                 special_token_list=[]):
        if not isinstance(vocab, dict):
            raise ValueError('expect `vocab` to be instance of dict, got %s' %
                             type(vocab))
        self.vocab = vocab
        self.lower = lower
        self.prefix = wordpiece_prefix
        self.sentencepiece_prefix = sentencepiece_prefix
        self.pad_id = self.vocab[pad_token]
        self.cls_id = cls_token and self.vocab[cls_token]
        self.sep_id = sep_token and self.vocab[sep_token]
        self.unk_id = unk_token and self.vocab[unk_token]
        self.mask_id = mask_token and self.vocab[mask_token]
        self.unk_token = unk_token
        special_tokens = {
            pad_token, cls_token, sep_token, unk_token, mask_token
        } | set(special_token_list)
        pat_str = ''
        for t in special_tokens:
            if t is None:
                continue
            pat_str += '(%s)|' % re.escape(t)
        # fallback pattern: a run of ASCII letters/digits, or any single
        # non-whitespace character (one Chinese character at a time)
        pat_str += r'([a-zA-Z0-9]+|\S)'
        log.debug('regex: %s' % pat_str)
        self.pat = re.compile(pat_str)
        self.encoding = encoding

    # used by encode(): split text with the regex and run wordpiece on each piece
    def tokenize(self, text):
        if len(text) == 0:
            return []
        if six.PY3 and not isinstance(text, six.string_types):
            text = text.decode(self.encoding)
        if six.PY2 and isinstance(text, str):
            text = text.decode(self.encoding)

        res = []
        # regex scan over text: special tokens, alphanumeric runs, and single
        # characters (e.g. individual Chinese characters)
        for match in self.pat.finditer(text):
            # match_group: the text of this regex match
            match_group = match.group(0)

            # match.group(0) is the whole match; match.groups()[-1] is the last
            # capture group, i.e. the generic `[a-zA-Z0-9]+|\S` branch appended
            # at the end of the pattern. If it matched, this piece is ordinary
            # text and goes through wordpiece; otherwise one of the special-token
            # groups matched and the piece is kept verbatim (else branch below).
            # Chinese text reaches here one character per match via `\S`, and the
            # Chinese vocab contains individual characters, so each character is
            # normally looked up directly.
            if match.groups()[-1]:
                if self.lower:
                    match_group = match_group.lower()
                
                # wordpiece segmentation of an ordinary piece;
                # words: the resulting list of sub-tokens
                words, _ = _wordpiece(
                    match_group,
                    vocab=self.vocab,
                    unk_token=self.unk_token,
                    prefix=self.prefix,
                    sentencepiece_prefix=self.sentencepiece_prefix)
            else:
                # a special token (e.g. [SEP]) matched: keep it as one token
                words = [match_group]
            res += words
        return res
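
    # Roughly, with the default special tokens, tokenize() behaves like this
    # (an illustrative sketch; actual sub-tokens depend on the loaded vocab):
    #
    #   tokenize('hello world [SEP] 你好123')
    #   # regex pieces: 'hello', 'world', '[SEP]', '你', '好', '123'
    #   # '[SEP]' hits a special-token group and is kept as-is; the other
    #   # pieces go through _wordpiece, Chinese text one character at a time.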

    # used by encode()
    def convert_tokens_to_ids(self, tokens):
        return [self.vocab.get(t, self.unk_id) for t in tokens]

    # used by encode()
    def truncate(self, id1, id2, seqlen):
        len1 = len(id1)
        len2 = len(id2)
        half = seqlen // 2
        if len1 > len2:
            len1_truncated, len2_truncated = max(half, seqlen - len2), min(
                half, len2)
        else:
            len1_truncated, len2_truncated = min(half, seqlen - len1), max(
                half, seqlen - len1)
        return id1[:len1_truncated], id2[:len2_truncated]
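
    # Worked example of the budget split: truncate(id1, id2, seqlen=10) with
    # len(id1) == 8 and len(id2) == 4 gives half = 5, then
    #   len1_truncated = max(5, 10 - 4) = 6, len2_truncated = min(5, 4) = 4,
    # so the longer input absorbs the unused budget and 6 + 4 == seqlen.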

    # used by encode()
    def build_for_ernie(self, text_id, pair_id=[]):
        """build sentence type id, add [CLS] [SEP]"""
        text_id_type = np.zeros_like(text_id, dtype=np.int64)
        ret_id = np.concatenate([[self.cls_id], text_id, [self.sep_id]], 0)
        ret_id_type = np.concatenate([[0], text_id_type, [0]], 0)

        # sentence B (optional text pair)
        if len(pair_id):
            pair_id_type = np.ones_like(pair_id, dtype=np.int64)
            ret_id = np.concatenate([ret_id, pair_id, [self.sep_id]], 0)
            ret_id_type = np.concatenate([ret_id_type, pair_id_type, [1]], 0)
        return ret_id, ret_id_type
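
    # Resulting layout for a text pair (sentence-type id below each token id):
    #   ret_id      : [CLS] t1 ... tn [SEP] p1 ... pm [SEP]
    #   ret_id_type :   0    0 ...  0    0    1 ...  1    1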

    def encode(self, text, pair=None, truncate_to=None):
        # ids of the first sentence: tokenize, then map tokens to vocab ids
        text_id = np.array(
            self.convert_tokens_to_ids(self.tokenize(text)), dtype=np.int64)
        # sentence-type ids, all zeros (note: recomputed inside build_for_ernie)
        text_id_type = np.zeros_like(text_id, dtype=np.int64)
        if pair is not None:
            pair_id = np.array(
                self.convert_tokens_to_ids(self.tokenize(pair)),
                dtype=np.int64)
        else:
            pair_id = []
        if truncate_to is not None:
            text_id, pair_id = self.truncate(text_id, [] if pair_id is None
                                             else pair_id, truncate_to)
        # token_id, sentence_type_id
        ret_id, ret_id_type = self.build_for_ernie(text_id, pair_id)
        return ret_id, ret_id_type
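
# A hedged usage sketch for ErnieTokenizer (assumes the 'ernie-1.0' package
# can be downloaded via _fetch_from_remote):
#
#   tokenizer = ErnieTokenizer.from_pretrained('ernie-1.0')
#   ids, type_ids = tokenizer.encode('百度 ERNIE', pair='效果如何', truncate_to=128)
#   # ids:      [CLS] + sentence A ids + [SEP] + sentence B ids + [SEP]
#   # type_ids: 0 up to and including the first [SEP], 1 for sentence B and
#   #           the final [SEP]
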
# jieba word segmentation + a SentencePiece sub-word model (used by ernie-tiny)
class ErnieTinyTokenizer(ErnieTokenizer):
    bce = 'https://ernie-github.cdn.bcebos.com/'
    resource_map = {'ernie-tiny': bce + 'model-ernie_tiny.1.tar.gz'}

    @classmethod
    def from_pretrained(cls,
                        pretrain_dir_or_url,
                        force_download=False,
                        **kwargs):
        if not Path(pretrain_dir_or_url).exists() and str(
                pretrain_dir_or_url) in cls.resource_map:
            url = cls.resource_map[str(pretrain_dir_or_url)]
            log.info('get pretrain dir from %s' % url)
            pretrain_dir = _fetch_from_remote(url, force_download)
        else:
            log.info('pretrain dir %s not in %s, read from local' %
                     (pretrain_dir_or_url, repr(cls.resource_map)))
            pretrain_dir = Path(pretrain_dir_or_url)
        if not pretrain_dir.exists():
            raise ValueError('pretrain dir not found: %s' % pretrain_dir)
        vocab_path = pretrain_dir / 'vocab.txt'
        sp_model_path = pretrain_dir / 'subword/spm_cased_simp_sampled.model'

        if not vocab_path.exists():
            raise ValueError('no vocab file in pretrain dir: %s' %
                             pretrain_dir)
        vocab_dict = {
            j.strip().split('\t')[0]: i
            for i, j in enumerate(
                vocab_path.open(encoding='utf8').readlines())
        }

        t = cls(vocab_dict, sp_model_path, **kwargs)
        return t

    def __init__(self, vocab, sp_model_path, **kwargs):
        super(ErnieTinyTokenizer, self).__init__(vocab, **kwargs)
        import sentencepiece as spm
        import jieba as jb
        self.sp_model = spm.SentencePieceProcessor()
        self.window_size = 5
        self.sp_model.Load(sp_model_path)
        self.jb = jb

    def cut(self, sentence):
        return self.jb.cut(sentence)

    def tokenize(self, text):
        if len(text) == 0:
            return []
        if not isinstance(text, six.string_types):
            text = text.decode(self.encoding)
        if self.lower:
            text = text.lower()

        res = []
        for match in self.cut(text):
            res += self.sp_model.EncodeAsPieces(match)
        return res
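
For completeness, here is a small end-to-end sketch of ErnieTinyTokenizer. It is not part of tokenizing_ernie.py; it assumes the package layout ernie.tokenizing_ernie from the ERNIE repo, that the 'ernie-tiny' package can be downloaded, and that jieba and sentencepiece are installed:

from ernie.tokenizing_ernie import ErnieTinyTokenizer

tokenizer = ErnieTinyTokenizer.from_pretrained('ernie-tiny')

# jieba first cuts the sentence into words, then each word is re-split
# into sub-word pieces by the SentencePiece model
pieces = tokenizer.tokenize('百度的预训练模型效果不错')
print(pieces)

# encode() is inherited from ErnieTokenizer: it adds [CLS]/[SEP] and
# returns numpy arrays of token ids and sentence-type ids
ids, type_ids = tokenizer.encode('百度的预训练模型效果不错')
print(ids.shape, type_ids.shape)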

Source: https://www.cnblogs.com/zhangtao-0001/p/14737544.html