def create_dataset():np.random.seed(1)m = 400 # 数据量N = int(m/2) # 每个标签的实例数D = 2 # 数据维度X = np.zeros((m,D)) # 数据矩阵Y = np.zeros((m,1), dtype='uint8') # 标签维度a = 4 for j in range(2):ix = range(N*j,N*(j+1))t = np.linspace(j*3.12,(j+1)*3.12,N) + np.random.randn(N)*0.2 # thetar = a*np.sin(4*t) + np.random.randn(N)*0.2 # radiusX[ix] = np.c_[r*np.sin(t), r*np.cos(t)]Y[ix] = jX = X.TY = Y.Treturn X, Y
这个函数 create_dataset
生成一个带有二维数据和对应标签的螺旋形数据集。螺旋形数据集是用于分类任务的典型数据集之一,尤其在测试复杂分类模型(如神经网络)时经常使用。以下是对这个函数的详细解释:
1. 参数与初始化
np.random.seed(1)
m = 400 # 数据集的总数量
N = int(m/2) # 每个类别(标签)的样本数量
D = 2 # 数据的维度(二维数据)
X = np.zeros((m,D)) # 初始化数据矩阵X,大小为 m x D
Y = np.zeros((m,1), dtype='uint8') # 初始化标签矩阵Y,大小为 m x 1,数据类型为无符号8位整数
a = 4 # 控制数据螺旋的半径
np.random.seed(1)
:确保随机数的可重复性,每次运行生成的数据集是相同的。m = 400
:数据集的总样本数。N = int(m/2)
:每个类别的样本数为一半(即 200 个样本属于类别 0,另外 200 个样本属于类别 1)。D = 2
:数据的维度为 2(表示二维数据)。X = np.zeros((m,D))
:初始化大小为 m × D m \times D m×D 的零矩阵,用于存储样本特征。Y = np.zeros((m,1), dtype='uint8')
:初始化大小为 m × 1 m \times 1 m×1 的零矩阵,用于存储样本标签。
2. 生成数据
for j in range(2):ix = range(N*j,N*(j+1)) # 生成当前类的索引范围t = np.linspace(j*3.12,(j+1)*3.12,N) + np.random.randn(N)*0.2 # 角度 thetar = a*np.sin(4*t) + np.random.randn(N)*0.2 # 半径 rX[ix] = np.c_[r*np.sin(t), r*np.cos(t)] # 生成二维坐标 (x1, x2)Y[ix] = j # 为当前生成的数据赋予标签 j
这个循环迭代两次(一次生成类别 0 的数据,一次生成类别 1 的数据),生成螺旋形数据。具体步骤如下:
-
ix = range(N*j, N*(j+1))
:当前类别样本的索引范围。第一次循环时生成类别 0 的样本,索引为 0 到 199;第二次循环时生成类别 1 的样本,索引为 200 到 399。 -
t = np.linspace(j*3.12, (j+1)*3.12, N) + np.random.randn(N)*0.2
:生成从 j × 3.12 j \times 3.12 j×3.12 到 ( j + 1 ) × 3.12 (j+1) \times 3.12 (j+1)×3.12 的角度theta
,并在每个点上添加一些随机噪声(np.random.randn(N)*0.2
)。这些角度用于控制螺旋形的弯曲程度。 -
r = a*np.sin(4*t) + np.random.randn(N)*0.2
:生成半径r
,即样本离原点的距离。半径是一个基于sin(4t)
的函数,并添加了随机噪声(np.random.randn(N)*0.2
)以增加数据集的多样性。这个函数生成螺旋的曲线形状。 -
X[ix] = np.c_[r*np.sin(t), r*np.cos(t)]
:利用极坐标 ( r , t ) (r, t) (r,t) 计算二维笛卡尔坐标 ( x 1 , x 2 ) (x_1, x_2) (x1,x2),并将其存储在数据矩阵X
中。 -
Y[ix] = j
:为当前类别的样本赋值为j
(即当前类别的标签)。
3. 返回值
X = X.T
Y = Y.T
return X, Y
X.T
:转置后的数据矩阵,输出大小为 2 × m 2 \times m 2×m,表示 2 个特征和 m m m 个样本。Y.T
:转置后的标签矩阵,输出大小为 1 × m 1 \times m 1×m,每个样本对应一个标签。
4. 总结
该函数 create_dataset()
生成了一个螺旋形数据集,数据集具有以下特点:
- 数据集分为两个类别,每个类别各有 200 个样本。
- 数据点以螺旋形状分布,这对线性分类器(如线性支持向量机、感知机等)来说是一个较为复杂的分类任务,因为螺旋形数据通常是非线性可分的。
- 随机噪声的加入使得数据更具挑战性,有助于测试复杂模型(如神经网络)的分类能力。