【机器学习】 mean-shift聚类算法
1.mean-shift聚类算法的原理;2.python实现的mean-shift聚类算法;3.sklearn中的mean-shift聚类算法
·
【机器学习】 mean-shift聚类算法
mean-shift算法原理
- 在,未被标记的数据点中,随机选择一个点作为,起始中心点(记作center)
- 找出以center为中心,半径为radius,的区域中出现的所有数据点,认为这些点同属于一个聚类 C 1 C_1 C1。同时,这些点增加记录:在类 C 1 C_1 C1中出现的次数加一
- 以center为中心点,计算从center到区域中每个元素的向量,将这些,向量求和,得到向量shift
- center(向量) = center(向量) + shift(向量)。即center沿着shift的方向移动,移动距离是||shift||。
- 重复步骤2、3、4,直到shift的很小(就是迭代到收敛),记住此时的center。
- 如果收敛时当前簇 C x C_x Cx的center与其它已经存在的簇 C y C_y Cy中心的距离小于阈值,那么把 C x C_x Cx和 C y C_y Cy合并,数据点出现次数也对应合并。否则,把 C x C_x Cx作为新的聚类。
- 重复1、2、3、4、5,直到所有的点都被标记为已访问(归过类的点,就算已访问)
- 分类:这个点被分为,记录出现次数最多的那个类。
python实现mean-shift
# -*- coding:utf-8 -*-
import numpy as np
import matplotlib.pyplot as plt
import random
from sklearn.cluster import MeanShift
from sklearn.cluster import estimate_bandwidth
STANDARD_COLORS = [
'AliceBlue', 'Chartreuse', 'Aqua', 'Aquamarine', 'Azure', 'Beige', 'Bisque',
'BlanchedAlmond', 'BlueViolet', 'BurlyWood', 'CadetBlue', 'AntiqueWhite',
'Chocolate', 'Coral', 'CornflowerBlue', 'Cornsilk', 'Crimson', 'Cyan',
'DarkCyan', 'DarkGoldenRod', 'DarkGrey', 'DarkKhaki', 'DarkOrange',
'DarkOrchid', 'DarkSalmon', 'DarkSeaGreen', 'DarkTurquoise', 'DarkViolet',
'DeepPink', 'DeepSkyBlue', 'DodgerBlue', 'FireBrick', 'FloralWhite',
'ForestGreen', 'Fuchsia', 'Gainsboro', 'GhostWhite', 'Gold', 'GoldenRod',
'Salmon', 'Tan', 'HoneyDew', 'HotPink', 'IndianRed', 'Ivory', 'Khaki',
'Lavender', 'LavenderBlush', 'LawnGreen', 'LemonChiffon', 'LightBlue',
'LightCoral', 'LightCyan', 'LightGoldenRodYellow', 'LightGray', 'LightGrey',
'LightGreen', 'LightPink', 'LightSalmon', 'LightSeaGreen', 'LightSkyBlue',
'LightSlateGray', 'LightSlateGrey', 'LightSteelBlue', 'LightYellow', 'Lime',
'LimeGreen', 'Linen', 'Magenta', 'MediumAquaMarine', 'MediumOrchid',
'MediumPurple', 'MediumSeaGreen', 'MediumSlateBlue', 'MediumSpringGreen',
'MediumTurquoise', 'MediumVioletRed', 'MintCream', 'MistyRose', 'Moccasin',
'NavajoWhite', 'OldLace', 'Olive', 'OliveDrab', 'Orange', 'OrangeRed',
'Orchid', 'PaleGoldenRod', 'PaleGreen', 'PaleTurquoise', 'PaleVioletRed',
'PapayaWhip', 'PeachPuff', 'Peru', 'Pink', 'Plum', 'PowderBlue', 'Purple',
'Red', 'RosyBrown', 'RoyalBlue', 'SaddleBrown', 'Green', 'SandyBrown',
'SeaGreen', 'SeaShell', 'Sienna', 'Silver', 'SkyBlue', 'SlateBlue',
'SlateGray', 'SlateGrey', 'Snow', 'SpringGreen', 'SteelBlue', 'GreenYellow',
'Teal', 'Thistle', 'Tomato', 'Turquoise', 'Violet', 'Wheat', 'White',
'WhiteSmoke', 'Yellow', 'YellowGreen'
]
def create_the_dataset(size, min, max):
"""
随机生成数据集
:param size: 数据集的大小
:param min: 点坐标的最小值
:param max: 点坐标的最大值
:return: [{"data":[x, y], "status": False}, {}, {}, ...]
"""
dataset = []
sub_dataset = []
for i in range(size):
sub_dataset.append(random.uniform(min, max))
sub_dataset.append(random.uniform(min, max))
dataset.append({"data": sub_dataset, "status": False})
sub_dataset = []
return dataset
def mean_shift(dataset, radius):
random.shuffle(dataset)
# 存放聚类结果, 一个一个的字典
clusters = []
while True:
# 存放没有标记被访问的点, 一个一个的字典
dataset_list = []
for each in dataset:
if each["status"] is False:
dataset_list.append(each)
# 当所有点被标记访问后, 跳出循环
if len(dataset_list) == 0:
break
# 随机选取一个点作为center
center_index = random.randint(0, len(dataset_list)-1)
cluster_centroid = dataset_list[center_index] # 字典
cluster_frequency = np.zeros(len(dataset))
old_centroid = np.array(cluster_centroid["data"])
while True:
# temp_data里存放的是, 以当前center中心, radius范围内所有的点, 包括中心点
temp_data = []
for j in range(len(dataset)):
v = dataset[j]["data"]
if np.linalg.norm(np.array(v) - np.array(old_centroid)) <= radius:
temp_data.append(np.array(v))
cluster_frequency[j] += 1
dataset[j]["status"] = True
new_centroid = np.average(temp_data, axis=0)
if np.array_equal(new_centroid, old_centroid):
break
old_centroid = new_centroid
has_same_cluster = False
same_cluster = []
scores = []
for cluster in clusters:
if np.linalg.norm(cluster["centroid"] - new_centroid) <= radius:
has_same_cluster = True
same_cluster.append(cluster)
scores.append(np.linalg.norm(cluster["centroid"] - new_centroid))
if len(same_cluster) != 0:
same_cluster[scores.index(min(scores))]["frequency"] = \
same_cluster[scores.index(min(scores))]["frequency"] + cluster_frequency
if not has_same_cluster:
clusters.append({
"centroid": new_centroid,
"frequency": cluster_frequency
})
print("clusters : ", len(clusters))
clustering(dataset, clusters)
show_clusters(clusters, radius)
def clustering(data, clusters):
t = []
for cluster in clusters:
cluster["data"] = []
t.append(cluster["frequency"])
# 中心个数行, 数据个数列
t = np.array(t)
for i in range(len(data)):
column_frequency = t[:, i]
cluster_index = np.where(column_frequency == np.max(column_frequency))[0][0]
clusters[cluster_index]["data"].append(data[i]["data"])
def show_clusters(clusters, radius):
theta = np.linspace(0, 2 * np.pi, 800)
for i in range(len(clusters)):
cluster = clusters[i]
data = np.array(cluster["data"])
plt.scatter(data[:, 0], data[:, 1], color=STANDARD_COLORS[i], s=20)
centroid = cluster["centroid"]
plt.scatter(centroid[0], centroid[1], color=STANDARD_COLORS[i], marker='x', s=30)
x, y = np.cos(theta) * radius + centroid[0], np.sin(theta) * radius + centroid[1]
plt.plot(x, y, linewidth=1, color=STANDARD_COLORS[i])
plt.savefig("result_1.png")
plt.show()
X = create_the_dataset(200, -10, 10)
data_list = []
for each in X:
data_list.append(each["data"])
dataset_ = np.array(data_list)
radius_ = estimate_bandwidth(dataset_, quantile=0.1, n_samples=200)
print("radius : ", radius_)
# 自己实现的mean-shift
mean_shift(X, radius_)
# sklearn中的mean-shift
meanshift = MeanShift(radius_)
result = meanshift.fit_predict(dataset_)
plt.scatter(dataset_[:, 0].tolist(), dataset_[:, 1].tolist(), c=result)
plt.savefig("result_2.png")
plt.show()
自己实现的mean-shift聚类结果:
skleran中的mean-shift聚类结果:
结语
如果您有修改意见或问题,欢迎留言或者通过邮箱和我联系。
手打很辛苦,如果我的文章对您有帮助,转载请注明出处。
更多推荐
所有评论(0)