python机器学习(网络搜索和交叉验证)

news/2024/10/16 7:10:24 标签: python, 机器学习, 人工智能

"""
网格搜索:
    指的是 GridSearchCV这个工具的功能, 可以帮助我们寻找最优的 超参数.

    超参数解释:
        在机器学习中, 我们把需要用户手动传入的参数称之为: 超参数.

交叉验证:
    指的是对数据集进行划分, 即: 把数据分成N份进行验证
        第1次: 第1份是 验证集(测试集), 其它是训练集, 得出: 模型的预估正确率
        第2次: 第2份是 验证集(测试集), 其它是训练集, 得出: 模型的预估正确率
        ......
        第n次: 第n份是 验证集(测试集), 其它是训练集, 得出: 模型的预估正确率
    最终计算所有正确率的平均值, 结合网格搜索, 找到最优参数组合.

回顾: 机器学习的建模流程.
    1. 加载数据.
    2. 数据预处理.
    3. 特征工程.
    4. 模型训练.
    5. 模型预测.
    6. 模型评估.
"""

# 导包
from sklearn.datasets import load_iris                  # 加载鸢尾花测试集的.
from sklearn.model_selection import train_test_split, GridSearchCV    # 分割训练集和测试集的, 网格搜索 => 找最优参数组合
from sklearn.preprocessing import StandardScaler        # 数据标准化的
from sklearn.neighbors import KNeighborsClassifier      # KNN算法 分类对象
from sklearn.metrics import accuracy_score              # 模型评估的, 计算模型预测的准确率

# 1. 加载数据.
iris_data = load_iris()
# 2. 数据预处理, 即: 划分 训练集, 测试集.
x_train, x_test, y_train, y_test = train_test_split(iris_data.data, iris_data.target, test_size=0.2, random_state=21)
# 3. 特征工程, 即: 特征的预处理 => 数据的标准化.
transfer = StandardScaler()
x_train = transfer.fit_transform(x_train)   # 训练 + 转换 => 适用于: 训练集.
x_test = transfer.transform(x_test)         # 直接转换 => 适用于: 测试集.

# 4. 模型训练.
# 4.1 创建 估计器对象.
estimator = KNeighborsClassifier()
# 4.2 定义网格搜索的参数, 即: 样本可能存在的参数组合值 => 超参数.
param_dict = {'n_neighbors': [1, 2, 3, 5, 7]}
# 4.3 创建网格搜索对象, 帮我们找到最优的参数组合.
# 参1: 估计器对象, 传入1个估计器对象, 网格搜索后, 会自动返回1个功能更加强大(最优参数)的 估计器对象.
# 参2: 网格搜索的参数, 传入1个字典, 键: 参数名, 值: 参数值列表.
# 参3: 交叉验证的次数, 指定值为: 4
estimator = GridSearchCV(estimator, param_dict, cv=5)
# 4.4 调用 估计器对象的 fit方法, 完成模型训练.
estimator.fit(x_train, y_train)

# 4.5 查看网格搜索后的参数
print(f'最优组合平均分: {estimator.best_score_}')
print(f'最优估计器对象: {estimator.best_estimator_}')  # 3
print(f'具体的验证过程: {estimator.cv_results_}')
print(f'最优的参数: {estimator.best_params_}')


# 5. 得到超参数最优值之后, 再次对模型进行训练.
estimator = KNeighborsClassifier(n_neighbors=3)
# 模型训练
estimator.fit(x_train, y_train)
# 模型评估
print(estimator.score(x_test, y_test))  # 0.9666666666666667


http://www.niftyadmin.cn/n/5707531.html

相关文章

【Python语言进阶(二)】

一、函数的使用方式 将函数视为“一等公民” 函数可以赋值给变量函数可以作为函数的参数函数可以作为函数的返回值 高阶函数的用法(filter、map以及它们的替代品) items1 list(map(lambda x: x ** 2, filter(lambda x: x % 2, range(1, 10)))) # filter…

《诺贝尔物理学奖一百年》一书

郭奕玲、沈慧君编著,上海科学普及出版社2002年出版,我当时买到一本。全书400多页,讲解了从1901年至2001年所有物理诺奖得主的生平以及代表工作。 (1)高中物理课本是很厚的一本,当年我参加高考,据…

Java 二分查找算法详解及通用实现模板案例示范

1. 引言 二分查找(Binary Search)是一种常见的搜索算法,专门用于在有序数组或列表中查找元素的位置。它通过每次将搜索空间缩小一半,从而极大地提高了查找效率。相比于线性查找算法,二分查找的时间复杂度为 O(log n)&…

PyTorch单机多卡训练(无废话)

目前大家基本都在使用DistributedDataParallel(简称DDP)用来训练,该方法主要用于分布式训练,但也可以用在单机多卡。 第一步:初始化分布式环境,主要用来帮助进程间通信 torch.distributed.init_process_g…

【五】架构设计之接口幂等概述

架构设计之接口幂等 概述 在进行架构设计的过程中我们时常需要考虑接口幂等的实现方案,本文将梳理接口幂等相关的知识点,并且通过一个示例来进行讲解说明接口幂等的实现方案,实现接口幂等的方式有很多,通过本文我们可以整体了解到…

### 更新数据库时出错。原因:java.sql.SQLException: No database selected

更新数据库时出错。原因:java.sql.SQLException: No database selected 问题:原因:解决办法: 问题: 在基于idea环境中学习搭建mybatis框架时,MySQL数据库执行插入语句遇到以下异常: com.intel…

Excel:vba实现身份信息填写

实现的效果是“点击一键填写性别和年龄”,表的呈现如下: 代码如下: Sub 判断性别年龄()Dim idCard As StringDim birthDate As StringDim nian As Integer, yue As Integer, ri As IntegerDim currentDate As DateDim age As IntegerDim ge…

会讲故事的I2C通信时序

前言: 相信各位在学习STM32时候的I2C通信肯定特别苦恼吧,这是什么通信时序,为什么起始终止发送SCL和SDA要那样,即使是深刻学习理解了一遍,时间长了之后也容易忘记,因为记的都是概念性的东西,枯燥…