【机器学习】 mean-shift聚类算法

mean-shift算法原理

  1. 在,未被标记的数据点中,随机选择一个点作为,起始中心点(记作center)
  2. 找出以center为中心,半径为radius,的区域中出现的所有数据点,认为这些点同属于一个聚类 C 1 C_1 C1同时,这些点增加记录:在类 C 1 C_1 C1中出现的次数加一
  3. 以center为中心点,计算从center到区域中每个元素的向量,将这些,向量求和,得到向量shift
  4. center(向量) = center(向量) + shift(向量)。即center沿着shift的方向移动,移动距离是||shift||。
  5. 重复步骤2、3、4,直到shift的很小(就是迭代到收敛),记住此时的center。
  6. 如果收敛时当前簇 C x C_x Cx的center与其它已经存在的簇 C y C_y Cy中心的距离小于阈值,那么把 C x C_x Cx C y C_y Cy合并,数据点出现次数也对应合并。否则,把 C x C_x Cx作为新的聚类。
  7. 重复1、2、3、4、5,直到所有的点都被标记为已访问(归过类的点,就算已访问)
  8. 分类:这个点被分为,记录出现次数最多的那个类

python实现mean-shift

# -*- coding:utf-8 -*-

import numpy as np
import matplotlib.pyplot as plt
import random
from sklearn.cluster import MeanShift
from sklearn.cluster import estimate_bandwidth


STANDARD_COLORS = [
    'AliceBlue', 'Chartreuse', 'Aqua', 'Aquamarine', 'Azure', 'Beige', 'Bisque',
    'BlanchedAlmond', 'BlueViolet', 'BurlyWood', 'CadetBlue', 'AntiqueWhite',
    'Chocolate', 'Coral', 'CornflowerBlue', 'Cornsilk', 'Crimson', 'Cyan',
    'DarkCyan', 'DarkGoldenRod', 'DarkGrey', 'DarkKhaki', 'DarkOrange',
    'DarkOrchid', 'DarkSalmon', 'DarkSeaGreen', 'DarkTurquoise', 'DarkViolet',
    'DeepPink', 'DeepSkyBlue', 'DodgerBlue', 'FireBrick', 'FloralWhite',
    'ForestGreen', 'Fuchsia', 'Gainsboro', 'GhostWhite', 'Gold', 'GoldenRod',
    'Salmon', 'Tan', 'HoneyDew', 'HotPink', 'IndianRed', 'Ivory', 'Khaki',
    'Lavender', 'LavenderBlush', 'LawnGreen', 'LemonChiffon', 'LightBlue',
    'LightCoral', 'LightCyan', 'LightGoldenRodYellow', 'LightGray', 'LightGrey',
    'LightGreen', 'LightPink', 'LightSalmon', 'LightSeaGreen', 'LightSkyBlue',
    'LightSlateGray', 'LightSlateGrey', 'LightSteelBlue', 'LightYellow', 'Lime',
    'LimeGreen', 'Linen', 'Magenta', 'MediumAquaMarine', 'MediumOrchid',
    'MediumPurple', 'MediumSeaGreen', 'MediumSlateBlue', 'MediumSpringGreen',
    'MediumTurquoise', 'MediumVioletRed', 'MintCream', 'MistyRose', 'Moccasin',
    'NavajoWhite', 'OldLace', 'Olive', 'OliveDrab', 'Orange', 'OrangeRed',
    'Orchid', 'PaleGoldenRod', 'PaleGreen', 'PaleTurquoise', 'PaleVioletRed',
    'PapayaWhip', 'PeachPuff', 'Peru', 'Pink', 'Plum', 'PowderBlue', 'Purple',
    'Red', 'RosyBrown', 'RoyalBlue', 'SaddleBrown', 'Green', 'SandyBrown',
    'SeaGreen', 'SeaShell', 'Sienna', 'Silver', 'SkyBlue', 'SlateBlue',
    'SlateGray', 'SlateGrey', 'Snow', 'SpringGreen', 'SteelBlue', 'GreenYellow',
    'Teal', 'Thistle', 'Tomato', 'Turquoise', 'Violet', 'Wheat', 'White',
    'WhiteSmoke', 'Yellow', 'YellowGreen'
]


def create_the_dataset(size, min, max):
    """
    随机生成数据集

    :param size: 数据集的大小
    :param min: 点坐标的最小值
    :param max: 点坐标的最大值
    :return: [{"data":[x, y], "status": False}, {}, {}, ...]
    """
    dataset = []
    sub_dataset = []
    for i in range(size):
        sub_dataset.append(random.uniform(min, max))
        sub_dataset.append(random.uniform(min, max))
        dataset.append({"data": sub_dataset, "status": False})
        sub_dataset = []

    return dataset


def mean_shift(dataset, radius):
    random.shuffle(dataset)

    # 存放聚类结果, 一个一个的字典
    clusters = []
    while True:
        # 存放没有标记被访问的点, 一个一个的字典
        dataset_list = []
        for each in dataset:
            if each["status"] is False:
                dataset_list.append(each)

        # 当所有点被标记访问后, 跳出循环
        if len(dataset_list) == 0:
            break

        # 随机选取一个点作为center
        center_index = random.randint(0, len(dataset_list)-1)

        cluster_centroid = dataset_list[center_index]  # 字典
        cluster_frequency = np.zeros(len(dataset))

        old_centroid = np.array(cluster_centroid["data"])
        while True:
            # temp_data里存放的是, 以当前center中心, radius范围内所有的点, 包括中心点
            temp_data = []
            for j in range(len(dataset)):
                v = dataset[j]["data"]
                if np.linalg.norm(np.array(v) - np.array(old_centroid)) <= radius:
                    temp_data.append(np.array(v))
                    cluster_frequency[j] += 1
                    dataset[j]["status"] = True

            new_centroid = np.average(temp_data, axis=0)

            if np.array_equal(new_centroid, old_centroid):
                break

            old_centroid = new_centroid

        has_same_cluster = False

        same_cluster = []
        scores = []
        for cluster in clusters:
            if np.linalg.norm(cluster["centroid"] - new_centroid) <= radius:
                has_same_cluster = True
                same_cluster.append(cluster)
                scores.append(np.linalg.norm(cluster["centroid"] - new_centroid))

        if len(same_cluster) != 0:
            same_cluster[scores.index(min(scores))]["frequency"] = \
                same_cluster[scores.index(min(scores))]["frequency"] + cluster_frequency

        if not has_same_cluster:
            clusters.append({
                "centroid": new_centroid,
                "frequency": cluster_frequency
            })

    print("clusters : ", len(clusters))
    clustering(dataset, clusters)
    show_clusters(clusters, radius)


def clustering(data, clusters):
    t = []
    for cluster in clusters:
        cluster["data"] = []
        t.append(cluster["frequency"])

    # 中心个数行, 数据个数列
    t = np.array(t)
    for i in range(len(data)):
        column_frequency = t[:, i]
        cluster_index = np.where(column_frequency == np.max(column_frequency))[0][0]
        clusters[cluster_index]["data"].append(data[i]["data"])


def show_clusters(clusters, radius):
    theta = np.linspace(0, 2 * np.pi, 800)
    for i in range(len(clusters)):
        cluster = clusters[i]
        data = np.array(cluster["data"])
        plt.scatter(data[:, 0], data[:, 1], color=STANDARD_COLORS[i], s=20)

        centroid = cluster["centroid"]
        plt.scatter(centroid[0], centroid[1], color=STANDARD_COLORS[i], marker='x', s=30)

        x, y = np.cos(theta) * radius + centroid[0], np.sin(theta) * radius + centroid[1]
        plt.plot(x, y, linewidth=1, color=STANDARD_COLORS[i])
    plt.savefig("result_1.png")
    plt.show()


X = create_the_dataset(200, -10, 10)
data_list = []
for each in X:
    data_list.append(each["data"])
dataset_ = np.array(data_list)
radius_ = estimate_bandwidth(dataset_, quantile=0.1, n_samples=200)
print("radius : ", radius_)

# 自己实现的mean-shift
mean_shift(X, radius_)

# sklearn中的mean-shift
meanshift = MeanShift(radius_)
result = meanshift.fit_predict(dataset_)

plt.scatter(dataset_[:, 0].tolist(), dataset_[:, 1].tolist(), c=result)
plt.savefig("result_2.png")
plt.show()

自己实现的mean-shift聚类结果:
在这里插入图片描述
skleran中的mean-shift聚类结果:
在这里插入图片描述

结语

如果您有修改意见或问题,欢迎留言或者通过邮箱和我联系。
手打很辛苦,如果我的文章对您有帮助,转载请注明出处。

Logo

汇聚全球AI编程工具,助力开发者即刻编程。

更多推荐