其他分享
首页 > 其他分享> > tensorflow.keras.datasets 中关于imdb.load_data的使用说明

tensorflow.keras.datasets 中关于imdb.load_data的使用说明

作者:互联网

在tensorflow2.x的keras中内置了7种类型的数据集:

数据集名称数据集描述
boston_housing波士顿房价数据
cifar1010种类别图片集
cifar100100种类别图片集
fashion_mnist10种时尚类别图片集
imdb电影评论情感分类数据集
mnist手写数字图片集
reuters路透社新闻主题分类数据集

这些数据的读取都可以使用load_data()方法。不过2种关于文本的数据集imdb和reuters比较特殊,他们的load_data中包含了过滤参数。本文将介绍imdb的load_data参数以及用法。
imdb.load_data的定义如下:

tf.keras.datasets.imdb.load_data(
    path='imdb.npz', num_words=None, skip_top=0, maxlen=None, seed=113,
    start_char=1, oov_char=2, index_from=3, **kwargs
)
from tensorflow.keras import datasets

(x,y),(tx,ty) = datasets.imdb.load_data()
print("全部数据:",len(x),' 第一个评论:',len(x[0]))
print('第一个评论内容:',x[0][0:10])
(x100,y100),(tx100,ty100) = datasets.imdb.load_data(num_words=100)
print("前100词频:",len(x100),' 第一个评论【100】:',len(x100[0]))
print('第一个评论内容【100】:',x100[0][0:10])

结果如下:

全部数据: 25000  第一个评论: 218
第一个评论内容: [1, 14, 22, 16, 43, 530, 973, 1622, 1385, 65]
前100词频: 25000  第一个评论【100】: 218
第一个评论内容【100】: [1, 14, 22, 16, 43, 2, 2, 2, 2, 65]

对比可以发现,索引大于100的都被2替代了。

from tensorflow.keras import datasets

(x,y),(tx,ty) = datasets.imdb.load_data()
print("全部数据:",len(x),' 第一个评论:',len(x[0]))
print('第一个评论内容:',x[0][0:10])
(x100,y100),(tx100,ty100) = datasets.imdb.load_data(skip_top=100)
print("跳过前100词频:",len(x100),' 第一个评论【100】:',len(x100[0]))
print('第一个评论内容【100】:',x100[0][0:10])

结果如下:

全部数据: 25000  第一个评论: 218
第一个评论内容: [1, 14, 22, 16, 43, 530, 973, 1622, 1385, 65]
跳过前100词频: 25000  第一个评论【100】: 218
第一个评论内容【100】: [2, 2, 2, 2, 2, 530, 973, 1622, 1385, 2]

对比可以发现,索引小于100的都被2替代了。

from tensorflow.keras import datasets
(x,y),(tx,ty) = datasets.imdb.load_data()
print("全部数据:",len(x),' 第一个评论:',len(x[0]))
(x100,y100),(tx100,ty100) = datasets.imdb.load_data(maxlen=100)
print("长度小于100:",len(x100),' 第一个评论【100】:',len(x100[0]))

结果:

全部数据: 25000  第一个评论: 218
长度小于100词频: 5736  第一个评论【100】: 43

可以看出定义了maxlen之后,读入的数据少了。

from tensorflow.keras import datasets
(x,y),(tx,ty) = datasets.imdb.load_data()
print("全部数据:",len(x),' 第一个评论:',len(x[0]))
print("第一条评论内容:",x[0][0:10])
(x100,y100),(tx100,ty100) = datasets.imdb.load_data(start_char=100)
print("起始索引:",len(x100),' 第一个评论【100】:',len(x100[0]))
print("第一条评论内容【100】:",x100[0][0:10])

结果如下:

全部数据: 25000  第一个评论: 218
第一条评论内容: [1, 14, 22, 16, 43, 530, 973, 1622, 1385, 65]
起始索引: 25000  第一个评论【100】: 218
第一条评论内容【100】: [100, 14, 22, 16, 43, 530, 973, 1622, 1385, 65]

对比可以发现,评论开始的数值被换为100了。

from tensorflow.keras import datasets
(x,y),(tx,ty) = datasets.imdb.load_data()
print("全部数据:",len(x),' 第一个评论:',len(x[0]))
print("第一条评论内容:",x[0][0:10])
(x100,y100),(tx100,ty100) = datasets.imdb.load_data(oov_char=100,skip_top=20)
print("替换索引=100:",len(x100),' 第一个评论【100】:',len(x100[0]))
print("第一条评论内容【100】:",x100[0][0:10])

结果如下:

全部数据: 25000  第一个评论: 218
第一条评论内容: [1, 14, 22, 16, 43, 530, 973, 1622, 1385, 65]
替换索引=100: 25000  第一个评论【100】: 218
第一条评论内容【100】: [100, 100, 22, 100, 43, 530, 973, 1622, 1385, 65]

可以发现替换索引是100。起始索引也变为100了。即使定义了start_char也没有作用,这一点一定要注意。

from tensorflow.keras import datasets

(x,y),(tx,ty) = datasets.imdb.load_data()
print("全部数据:",len(x),' 第一个评论:',len(x[0]))
print("第一条评论内容:",x[0][0:10])
(x100,y100),(tx100,ty100) = datasets.imdb.load_data(index_from=100)
print("index_from=100:",len(x100),' 第一个评论【100】:',len(x100[0]))
print("第一条评论内容【100】:",x100[0][0:10])

结果:

全部数据: 25000  第一个评论: 218
第一条评论内容: [1, 14, 22, 16, 43, 530, 973, 1622, 1385, 65]
index_from=100: 25000  第一个评论【100】: 218
第一条评论内容【100】: [1, 111, 119, 113, 140, 627, 1070, 1719, 1482, 162]

对比可以发现,每个单词都被增加了100-3=97当index_from =100的时候。之所以要减去3是因为默认参数index_from=3,因此不带任何参数的load_data()实际上是在原始的索引上增加了3。

源代码分析

从Github上可以看到此函数的代码:

  if 'nb_words' in kwargs:
    logging.warning('The `nb_words` argument in `load_data` '
                    'has been renamed `num_words`.')
    num_words = kwargs.pop('nb_words')
  if kwargs:
    raise TypeError('Unrecognized keyword arguments: ' + str(kwargs))

  origin_folder = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/'
  path = get_file(
      path,
      origin=origin_folder + 'imdb.npz',
      file_hash=
      '69664113be75683a8fe16e3ed0ab59fda8886cb3cd7ada244f7d9544e4676b9f')
  with np.load(path, allow_pickle=True) as f:
    x_train, labels_train = f['x_train'], f['y_train']
    x_test, labels_test = f['x_test'], f['y_test']

  rng = np.random.RandomState(seed)
  indices = np.arange(len(x_train))
  rng.shuffle(indices)
  x_train = x_train[indices]
  labels_train = labels_train[indices]

  indices = np.arange(len(x_test))
  rng.shuffle(indices)
  x_test = x_test[indices]
  labels_test = labels_test[indices]

  if start_char is not None:
    x_train = [[start_char] + [w + index_from for w in x] for x in x_train]
    x_test = [[start_char] + [w + index_from for w in x] for x in x_test]
  elif index_from:
    x_train = [[w + index_from for w in x] for x in x_train]
    x_test = [[w + index_from for w in x] for x in x_test]

  if maxlen:
    x_train, labels_train = _remove_long_seq(maxlen, x_train, labels_train)
    x_test, labels_test = _remove_long_seq(maxlen, x_test, labels_test)
    if not x_train or not x_test:
      raise ValueError('After filtering for sequences shorter than maxlen=' +
                       str(maxlen) + ', no sequence was kept. '
                       'Increase maxlen.')

  xs = np.concatenate([x_train, x_test])
  labels = np.concatenate([labels_train, labels_test])

  if not num_words:
    num_words = max(max(x) for x in xs)

  # by convention, use 2 as OOV word
  # reserve 'index_from' (=3 by default) characters:
  # 0 (padding), 1 (start), 2 (OOV)
  if oov_char is not None:
    xs = [
        [w if (skip_top <= w < num_words) else oov_char for w in x] for x in xs
    ]
  else:
    xs = [[w for w in x if skip_top <= w < num_words] for x in xs]

  idx = len(x_train)
  x_train, y_train = np.array(xs[:idx]), np.array(labels[:idx])
  x_test, y_test = np.array(xs[idx:]), np.array(labels[idx:])

  return (x_train, y_train), (x_test, y_test)

seed是用来初始随机数的:

 rng = np.random.RandomState(seed)

start_char是额外添加的:

x_train = [[start_char] + [w + index_from for w in x] for x in x_train]

下面这段代码是取得评论的最大词频:

  if not num_words:
    num_words = max(max(x) for x in xs)

这段代码实现了oov_char替换:

  if oov_char is not None:
    xs = [
        [w if (skip_top <= w < num_words) else oov_char for w in x] for x in xs
    ]
  else:
    xs = [[w for w in x if skip_top <= w < num_words] for x in xs]

需要注意的是,由于oov_char是全替换索引,也包括start_char。因此在更改oov_char的时候,还要注意start_char也被修改了。这应该是个小bug。

标签:load,datasets,x100,keras,len,char,train,评论,100
来源: https://blog.csdn.net/weixin_42272768/article/details/112093163