前期准备

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# 以下代码是导入这个数据集
# 注意data返回的是一个字典
from sklearn.datasets import load_iris
import numpy as np
import pandas as pd

data = load_iris()
x = data['data']
y = data['target']
iris = pd.DataFrame(x,columns = ['Sepal.Length', 'Sepal.Width', 'Petal.Length', 'Petal.Width']) # 加一下列名
iris['Species'] = y

# 提取数据,提取前两列的数据作为例子
X = iris[['Sepal.Length','Sepal.Width']]
# 标准化一下
Xstd = (X - X.mean()) / X.std()
Xstd

确定K值

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
'''
手肘法:
判断标准——明显的拐点,曲率最大的地方
'''
from sklearn.cluster import KMeans # 导入聚类模块

model = KMeans(n_clusters = 2) # 设置聚成两类
model.fit(Xstd)

# 取出所有聚类中心均值向量的总和
model.inertia_

result_list = [] # 存储所有聚类中心均值向量的总和
# 假设分类类别是2-11
for i in range(2,12):
# 新建聚类模型,改变聚类中心个数
model = KMeans(n_clusters = i,random_state = 1) # 设置种子,防止种子带来的随机性
# 训练模型
model.fit(Xstd)
# 取出所有聚类中心均值向量的总和
result_list.append(model.inertia_)

# 手肘图
import matplotlib.pyplot as plt
xs = list(range(2,12))
plt.plot(xs, result_list)
plt.show() # 找到变平滑的那个k

'''
轮廓系数:
轮廓系数越大越好。
'''
from sklearn import metrics
# 假设分类类别是2-11
for i in range(2,12):
# 新建聚类模型,改变聚类中心个数
model = KMeans(n_clusters = i,random_state = 1)
# 训练模型
model.fit(Xstd)
# 获取kmeans聚类的结果
labels = model.labels_
# 计算轮廓系数
res = metrics.silhouette_score(Xstd,labels)
print(i,':',res)

# 综合两种方式选择一个合适的K。

'''
展示一下结果(K=3)
'''
# 新建聚类模型,改变聚类中心个数
model = KMeans(n_clusters = 3,random_state = 1)
# 训练模型
model.fit(Xstd)
# 获取kmeans聚类的结果
labels = model.labels_

#绘制下散点图
iris['labels'] = labels

list(set(list(iris['Species']))) # 确定是什么类

# 真实的结果散点图
for label in list(set(list(iris['Species']))):
df = iris[iris['Species'] == label]
plt.plot(df['Sepal.Length'],df['Sepal.Width'],'o',label = label)
plt.legend()
plt.show()

# 聚类的结果散点图
for label in list(set(list(iris['labels']))):
df = iris[iris['labels'] == label]
plt.plot(df['Sepal.Length'],df['Sepal.Width'],'o',label = label)
plt.legend()
plt.show()