前两讲我们学习了使用tensorflow原生代码搭建神经网络,本讲主要学习使用Tensorflow API:tf.keras搭建神经网络
一、搭建网络八股Sequential
六步法:
1.import:import 相关模块,如 import tensorflow as tf
2.train, test:指定输入网络的训练集和测试集,如指定训练集的输入 x_train 和标签
y_train,测试集的输入 x_test 和标签 y_test。
3.model = tf.keras.models.Sequential:逐层搭建网络结构
4.model.compile:在 model.compile()中配置训练方法,选择训练时使用的优化器、损失
函数和最终评价指标。
5.model.fit:在 model.fit()中执行训练过程,告知训练集和测试集的输入值和标签、
每个 batch 的大小(batchsize)和数据集的迭代次数(epoch)
6.model.summary:使用 model.summary()打印网络结构,统计参数数目。
model = tf.keras.models.Sequential的使用:
model.compile的使用
注:from_logits=False:神经网络末端如果使用了softmax函数,输出为概率分布而不是原始输出,from_logits就为false,否则为True
model.fit()的使用
model.summary()的使用
二、搭建网络八股class
用Sequential能搭建上层输入就是下层输出的顺序网络结构,但是无法写出一些带有跳连的非顺序网络结构,这个时候我们可以选择用类class搭建神经网络结构。
class的使用 :
对比 Sequential和class搭建神经网络的过程:
以实现鸢尾花分类为例
Sequential
import tensorflow as tf
from sklearn import datasets
import numpy as npx_train = datasets.load_iris().data
y_train = datasets.load_iris().targetnp.random.seed(116)
np.random.shuffle(x_train)
np.random.seed(116)
np.random.shuffle(y_train)
tf.random.set_seed(116)model = tf.keras.models.Sequential([tf.keras.layers.Dense(3, activation='softmax', kernel_regularizer=tf.keras.regularizers.l2())
])model.compile(optimizer=tf.keras.optimizers.SGD(lr=0.1),loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])model.fit(x_train, y_train, batch_size=32, epochs=500, validation_split=0.2, validation_freq=20)model.summary()
class
import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras import Model
from sklearn import datasets
import numpy as npx_train = datasets.load_iris().data
y_train = datasets.load_iris().targetnp.random.seed(116)
np.random.shuffle(x_train)
np.random.seed(116)
np.random.shuffle(y_train)
tf.random.set_seed(116)class IrisModel(Model):def __init__(self):super(IrisModel, self).__init__()self.d1 = Dense(3, activation='softmax', kernel_regularizer=tf.keras.regularizers.l2())def call(self, x):y = self.d1(x)return ymodel = IrisModel()model.compile(optimizer=tf.keras.optimizers.SGD(lr=0.1),loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])model.fit(x_train, y_train, batch_size=32, epochs=500, validation_split=0.2, validation_freq=20)
model.summary()
三、MNIST数据集 —手写数字识别训练
1.数据集的介绍
(1)MNIST数据集:
提供 6万张 28*28 像素点的0~9手写数字图片和标签,用于训练。
提供 1万张 28*28 像素点的0~9手写数字图片和标签,用于测试。
(2)导入MNIST数据集:
mnist = tf.keras.datasets.mnist
(x_train, y_train) , (x_test, y_test) = mnist.load_data()
(3)作为输入特征,输入神经网络时,将数据拉伸为一维数组:
tf.keras.layers.Flatten( )
[ 0 0 0 48 238 252 252 …… …… …… 253 186 12 0 0 0 0 0]
注:不知道这里大家有没有这样一个疑问,为什么鸢尾花的数据集不需要拉伸:
原因:鸢尾花数据集不需要拉直为一维是因为它的特征已经是数值型的,可以直接用于机器学习模型的训练和预测。而手写数字数据需要拉直为一维是因为它们的原始数据是图像形式的,需要通过转换才能被机器学习算法处理。
(4)观察数据集
2.代码实现书写数字识别
import tensorflow as tfmnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0model = tf.keras.models.Sequential([tf.keras.layers.Flatten(),tf.keras.layers.Dense(128, activation='relu'),tf.keras.layers.Dense(10, activation='softmax')
])model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1)
model.summary()
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras import Modelmnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0class MnistModel(Model):def __init__(self):super(MnistModel, self).__init__()self.flatten = Flatten()self.d1 = Dense(128, activation='relu')self.d2 = Dense(10, activation='softmax')def call(self, x):x = self.flatten(x)x = self.d1(x)y = self.d2(x)return ymodel = MnistModel()model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),metrics=['sparse_categorical_accuracy'])model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1)
model.summary()
后面还有FASHION数据集数据集,与MNIST数据集处理方式类似,就不再赘述。