基于Keras的路透社新闻数据集多分类问题

阅读: 评论:0

基于Keras的路透社新闻数据集多分类问题

基于Keras的路透社新闻数据集多分类问题

原创不易,如果有转载需要的话,请在首行附上本文地址,谢谢。

第一步加载IMDB数据集,如若数据库加载不成功,这里提供一种解决方法:点开reuters 新闻分类数据集下载(喜欢的话给个小星星和follow一下),fork到自己的仓库中下载reuters.pnz,或者直接下载reuters.pnz。下载好后将reuters.pnz放入你的keras数据库中如:~/.keras/datasets/,即可正常运行

具体代码如下(复制到编译器中可直接运行):

部分代码有注释,便于读者理解


#加载reuters数据集
from keras.datasets import reuters
(train_data,train_labels),(test_data,test_labels)=reuters.load_data(num_words=10000)import numpy as np
def vectorize_sequences(sequences,dimension=10000):results&#s((len(sequences),dimension))for i,sequence in enumerate(sequences):results[i,sequence]&#urn results#将训练数据和测试数据向量化
x_train =vectorize_sequences(train_data)
x_test=vectorize_sequences(test_data)def to_one_hot (labels,dimension=46):results&#s((len(labels),dimension))for i,label in enumerate(labels):results[i,label]&#urn resultsone_hot_train_labels =to_one_hot(train_labels)
one_hot_test_labels =to_one_hot(test_labels)# =============================================================================
# #Keras 内置方法实现分类编码
#  from keras.utils.np_utils import to_categorical
#  one_hot_train_labels=to_categorical(train_labels)
#  one_hot_test_labels=to_categorical(test_labels)
# =============================================================================#编译模型
from keras import models
from keras import layers
model=models.Sequential()
model.add(layers.Dense(64,activation='relu',input_shape=(10000,)))
model.add(layers.Dense(64,activation='relu'))
model.add(layers.Dense(46,activation='softmax'))
modelpile(optimizer='rmsprop',loss='categorical_crossentropy',metrics=['accuracy'])# ============================================================================
#留出验证集
x_val=x_train[:1000]
partial_x_train=x_train[1000:]y_val=one_hot_train_labels[:1000]
partial_y_train=one_hot_train_labels[1000:]#训练模型
#history =model.fit(partial_x_train,partial_y_train,epochs=20,batch_size=512,validation_data=(x_val,y_val))# =============================================================================
# #绘制训练损失和验证损失
# import matplotlib.pyplot as plt
# loss =history.history['loss']
# val_loss=history.history['val_loss']
# 
# epochs=range(1,len(loss)+1)
# plt.plot(epochs,loss,'bo',label='Training loss')
# plt.plot(epochs,val_loss,'b',label='Validation loss')
# plt.title('Trainning and validation loss')
# plt.xlabel('Epochs')
# plt.ylabel('Loss')
# plt.legend()
# plt.show()
# =============================================================================# =============================================================================
# #绘制训练精度和验证精度
# plt.clf()
# acc=history.history['acc']
# val_acc=history.history['val_acc']
# plt.plot(epochs,acc,'bo',label='Trainning acc')
# plt.plot(epochs,val_acc,'b',label='Validation acc')
# plt.title('Training and validation accuracy')
# plt.xlabel('Epochs')
# plt.ylabel('Accuracy')
# plt.legend()
# plt.show()
# =============================================================================# =============================================================================
# #根据上图,可以看出第九轮后开始过拟合,从头开始训练一个网络,共九个轮次,注意运行此段代码时,注释前面部分代码
# model=models.Sequential()
# model.add(layers.Dense(64,activation='relu',input_shape=(10000,)))
# model.add(layers.Dense(64,activation='relu'))
# model.add(layers.Dense(46,activation='softmax'))
# 
# modelpile(optimizer='rmsprop',loss='categorical_crossentropy',metrics=['accuracy'])
# model.fit(partial_x_train,partial_y_train,epochs=9,batch_size=512,validation_data=(x_val,y_val))
# =============================================================================results=model.evaluate(x_test,one_hot_test_labels)#在新数据上生成预测结果,每个元素的最大概率值即为类别,启动
predictions =model.predict(x_test)print(np.argmax(predictions[0]))

 

本文发布于:2024-02-02 04:44:02,感谢您对本站的认可!

本文链接:https://www.4u4v.net/it/170682024141432.html

版权声明:本站内容均来自互联网,仅供演示用,请勿用于商业和其他非法用途。如果侵犯了您的权益请与我们联系,我们将在24小时内删除。

标签:路透社   数据   新闻   Keras
留言与评论(共有 0 条评论)
   
验证码:

Copyright ©2019-2022 Comsenz Inc.Powered by ©

网站地图1 网站地图2 网站地图3 网站地图4 网站地图5 网站地图6 网站地图7 网站地图8 网站地图9 网站地图10 网站地图11 网站地图12 网站地图13 网站地图14 网站地图15 网站地图16 网站地图17 网站地图18 网站地图19 网站地图20 网站地图21 网站地图22/a> 网站地图23