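4.1. Introduction
In the previous experiments you studied linear regression in detail. With simple and multiple linear regression, we can make regression predictions on data that follows a linear trend. In practice, however, we often meet data whose distribution is not so "linear", such as stock market fluctuations or traffic flow. Data with this kind of nonlinear distribution calls for the method introduced in this experiment.
4.2. Key Points
- Introduction to polynomial regression
- Polynomial regression basics
- Polynomial regression prediction
4.3. Introduction to Polynomial Regression
In linear regression, we fit the data with a first-degree equation in the independent variable \(x\). In nonlinear regression, we instead model a nonlinear relationship between the dependent and independent variables. Intuitively, the fitted straight line becomes a "curve".
For nonlinear regression problems, the simplest and most common approach is the "polynomial regression" covered in this experiment. Polynomials are a concept first met in secondary school; the Wikipedia definition is quoted here:
A polynomial is a basic concept in algebra: an algebraic expression built from variables, called indeterminates, and constants, called coefficients, using finitely many additions, subtractions, multiplications, and exponentiations with natural-number powers. A polynomial is a kind of integral expression. A polynomial in one indeterminate is called a univariate polynomial; for example, \(x^2-3x+4\) is a univariate polynomial. A polynomial in more than one indeterminate is called a multivariate polynomial; for example, \(x^3-2xyz^2+2yz+1\) is a polynomial in three variables.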
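4.4. Polynomial Regression Basics
First, we use a set of sample data to get to know polynomial regression.
# Load the sample data
x = [4, 8, 12, 25, 32, 43, 58, 63, 69, 79]
y = [20, 33, 50, 56, 42, 31, 33, 46, 65, 75]
The sample data contains 10 points, given by their horizontal and vertical coordinates. Next, we plot the data with Matplotlib to see its trend.
from matplotlib import pyplot as plt
%matplotlib inline
plt.scatter(x, y)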

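4.5. Fitting a Degree-2 Polynomial
Next, we fit the scatter data above with a polynomial. A standard univariate polynomial of higher degree has the form:
\[ y(x, w) = w_0 + w_1x + w_2x^2 +...+w_mx^m = \sum\limits_{j=0}^{m}w_jx^j \tag{1} \]
Here \(m\) is the degree of the polynomial, \(x^j\) is the \(j\)-th power of \(x\), and \(w\) denotes the coefficients of the polynomial.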
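The sum in formula (1) can be evaluated directly. The sketch below is a minimal illustration, assuming the example coefficients \(w = (4, -3, 1)\), that is, the polynomial \(x^2-3x+4\) from the definition quoted earlier, and three arbitrary evaluation points; neither is part of the experiment's data. It also checks the hand-written sum against `numpy.polynomial.polynomial.polyval`, which takes the coefficients in the same ascending order \(w_0, w_1, \dots\).

```python
import numpy as np
from numpy.polynomial import polynomial as P

# Illustrative coefficients in ascending order: w_0, w_1, w_2 for 4 - 3x + x^2.
w = np.array([4.0, -3.0, 1.0])
xs = np.array([0.0, 2.0, 3.0])  # arbitrary evaluation points

# Direct evaluation of the sum in formula (1): sum_j w_j * x^j.
direct = sum(w_j * xs**j for j, w_j in enumerate(w))

# NumPy's polyval uses the same ascending coefficient order.
via_numpy = P.polyval(xs, w)

print(direct)
print(via_numpy)
```

Both calls should return the values 4, 2 and 4 for the three points, since \(0^2-0+4=4\), \(2^2-6+4=2\) and \(3^2-9+4=4\).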
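When we fit the scatter points with the polynomial above, two things must be determined: the polynomial coefficients \(w\) and the polynomial degree \(m\). These are also the two basic elements of a polynomial.
If we set the degree \(m\) by hand, only the values of the coefficients \(w\) remain to be determined. For example, with \(m=2\) the polynomial becomes:
\[ y(x, w) = w_0 + w_1x + w_2x^2= \sum\limits_{j=0}^{2}w_jx^j \tag{2} \]
Determining the values of \(w\) brings us back to what we learned about linear regression: the model is nonlinear in \(x\) but linear in \(w\), so the coefficients can be found by least squares.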
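Because formula (2) is linear in the coefficients, the fit can also be written as an ordinary linear least-squares problem. The following is a sketch of that view, not the method this experiment uses: it assumes the sample data from section 4.4 and builds a design matrix with columns \(1, x, x^2\) using `np.vander`, then solves it with `np.linalg.lstsq`. The names `A` and `w` are introduced here only for illustration.

```python
import numpy as np

# Sample data from section 4.4.
x = np.array([4, 8, 12, 25, 32, 43, 58, 63, 69, 79], dtype=float)
y = np.array([20, 33, 50, 56, 42, 31, 33, 46, 65, 75], dtype=float)

# Design matrix with columns 1, x, x^2 (ascending powers, matching w_0, w_1, w_2).
A = np.vander(x, N=3, increasing=True)

# Solve min_w ||A w - y||^2 as a linear least-squares problem.
w, residuals, rank, _ = np.linalg.lstsq(A, y, rcond=None)
print("w0, w1, w2 =", w)
```

Since the objective is the same sum of squared residuals, the coefficients should agree with those from an iterative least-squares fit of the same model, up to numerical tolerance.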
First, we define two functions: the polynomial used for fitting, and the error function.
def func(p, x):
    # Degree-2 polynomial, following formula (2)
    w0, w1, w2 = p
    f = w0 + w1 * x + w2 * x * x
    return f
def err_func(p, x, y):
    # Residual between the fitted values and the observed values
    ret = func(p, x) - y
    return ret
Next, we use scipy.optimize.leastsq to perform the least-squares fit.
import numpy as np
from scipy.optimize import leastsq
# Random initial parameters
p_init = np.random.randn(3)
# Least-squares fit
parameters = leastsq(err_func, p_init, args=(np.array(x), np.array(y)))
print("Fitted parameters:", parameters[0])
Fitted parameters: [ 3.76893106e+01 -2.60474058e-01 8.00077975e-03]
Next, we plot the fitted result.
# Plot the original data points
plt.scatter(x, y, label="Original data")
# Plot the fitted curve
x_plot = np.linspace(0, 80, 100)
y_plot = func(parameters[0], x_plot)
plt.plot(x_plot, y_plot, 'r-', label="Degree-2 polynomial fit")
plt.legend()
plt.title("Degree-2 polynomial regression fit")
plt.xlabel("x")
plt.ylabel("y")

4.6. Fitting a Degree-N Polynomial
Next, we implement a general degree-N polynomial fitting function.
def poly_func(p, x, degree):
    """
    General polynomial function
    p: polynomial coefficients
    x: independent variable
    degree: degree of the polynomial
    """
    # Use a float accumulator: for an integer x, np.zeros_like(x) would be an
    # integer array, and the in-place += below would raise UFuncTypeError.
    f = np.zeros_like(x, dtype=float)
    for i in range(degree + 1):
        f += p[i] * (x ** i)
    return f
def poly_err_func(p, x, y, degree):
    """
    Polynomial error function
    """
    ret = poly_func(p, x, degree) - y
    return ret
Now we fit the data with a degree-3 polynomial.
# Degree-3 polynomial fit
degree = 3
p_init = np.random.randn(degree + 1)
parameters_3 = leastsq(poly_err_func, p_init, args=(np.array(x), np.array(y), degree))
print("Degree-3 polynomial parameters:", parameters_3[0])
# Plot the degree-3 polynomial fit
plt.scatter(x, y, label="Original data")
x_plot = np.linspace(0, 80, 100)
y_plot_3 = poly_func(parameters_3[0], x_plot, degree)
plt.plot(x_plot, y_plot_3, 'g-', label="Degree-3 polynomial fit")
plt.legend()
plt.title("Degree-3 polynomial regression fit")
plt.xlabel("x")
plt.ylabel("y")
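One detail that is easy to trip over when accumulating polynomial terms in place is the dtype of the accumulator. The snippet below is a small, self-contained illustration using made-up integer inputs (`x_int` is not the experiment's data): `np.zeros_like` inherits the integer dtype of its argument, so adding a float result into it in place is rejected by NumPy's casting rules, while a float accumulator works.

```python
import numpy as np

x_int = np.array([4, 8, 12])                    # integer dtype, as np.array gives for a list of ints
acc_int = np.zeros_like(x_int)                  # inherits the integer dtype of x_int
acc_float = np.zeros_like(x_int, dtype=float)   # explicit float accumulator

try:
    acc_int += 0.5 * x_int                      # float result cannot be stored in an integer array
    failed = False
except TypeError as exc:                        # NumPy's UFuncTypeError derives from TypeError
    failed = True
    print("in-place add failed:", exc)

acc_float += 0.5 * x_int                        # works: the accumulator is float64
print(acc_float)
```

Converting the inputs with `np.array(x, dtype=float)` before calling the fitting function would be another way to avoid the problem.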
4.7. Polynomial Fitting with scikit-learn
Next, we use PolynomialFeatures and LinearRegression from scikit-learn to implement polynomial regression.
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import make_pipeline
# Degree-2 polynomial fit with scikit-learn
model_2 = make_pipeline(PolynomialFeatures(degree=2, include_bias=False), LinearRegression())
model_2.fit(np.array(x).reshape(-1, 1), np.array(y))
# Predict
x_plot_sklearn = np.linspace(0, 80, 100).reshape(-1, 1)
y_pred_2 = model_2.predict(x_plot_sklearn)
# Plot the result
plt.scatter(x, y, label="Original data")
plt.plot(x_plot_sklearn, y_pred_2, 'r-', label="sklearn degree-2 polynomial fit")
plt.legend()
plt.title("Degree-2 polynomial regression with scikit-learn")
plt.xlabel("x")
plt.ylabel("y")
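To make the role of PolynomialFeatures in the pipeline more concrete, the sketch below applies it to a tiny made-up input (the two samples `2.0` and `3.0` are illustrative only). It shows the columns generated for `degree=2` with and without the bias column; the pipeline above uses `include_bias=False`, presumably because LinearRegression fits its own intercept by default.

```python
import numpy as np
from sklearn.preprocessing import PolynomialFeatures

X = np.array([[2.0], [3.0]])  # two illustrative samples, one feature

# Without the bias column: x, x^2.
no_bias = PolynomialFeatures(degree=2, include_bias=False).fit_transform(X)

# With the bias column: 1, x, x^2.
with_bias = PolynomialFeatures(degree=2, include_bias=True).fit_transform(X)

print(no_bias)
print(with_bias)
```

The linear model then only has to find one weight per generated column, which is exactly the coefficient vector \(w\) of the polynomial.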
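4.8. Prediction with Polynomial Regression
Now we use the trained model to make predictions. Note that the new values lie beyond the largest x in the sample data (79), so these predictions are extrapolations.
# Predict new data points
new_x = [85, 90, 95]
new_x_array = np.array(new_x).reshape(-1, 1)
predictions = model_2.predict(new_x_array)
print("Predictions:")
for x_val, y_pred in zip(new_x, predictions):
    print(f"x = {x_val}, predicted y = {y_pred:.2f}")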
4.9. Linear Regression vs. Degree-2 Polynomial Regression
Next, we compare the results of linear regression and degree-2 polynomial regression.
from sklearn.metrics import mean_squared_error, r2_score
# Linear regression
linear_model = LinearRegression()
linear_model.fit(np.array(x).reshape(-1, 1), np.array(y))
y_pred_linear = linear_model.predict(np.array(x).reshape(-1, 1))
# Degree-2 polynomial regression
y_pred_poly = model_2.predict(np.array(x).reshape(-1, 1))
# Compute the evaluation metrics
mse_linear = mean_squared_error(y, y_pred_linear)
mse_poly = mean_squared_error(y, y_pred_poly)
r2_linear = r2_score(y, y_pred_linear)
r2_poly = r2_score(y, y_pred_poly)
print("Linear regression MSE:", mse_linear)
print("Degree-2 polynomial regression MSE:", mse_poly)
print("Linear regression R²:", r2_linear)
print("Degree-2 polynomial regression R²:", r2_poly)
# Visual comparison
plt.scatter(x, y, label="Original data", alpha=0.7)
# Linear regression fit line
x_plot = np.linspace(0, 80, 100).reshape(-1, 1)
y_linear_plot = linear_model.predict(x_plot)
plt.plot(x_plot, y_linear_plot, 'b-', label=f"Linear regression (R²={r2_linear:.3f})")
# Degree-2 polynomial regression fit curve
y_poly_plot = model_2.predict(x_plot)
plt.plot(x_plot, y_poly_plot, 'r-', label=f"Degree-2 polynomial regression (R²={r2_poly:.3f})")
plt.legend()
plt.title("Linear regression vs. degree-2 polynomial regression")
plt.xlabel("x")
plt.ylabel("y")
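For readers who want to see what the two metrics compute, here is a minimal NumPy sketch of the standard definitions, MSE as the mean squared residual and R² as one minus the ratio of the residual sum of squares to the total sum of squares. The arrays `y_true` and `y_hat` are hypothetical numbers chosen for easy arithmetic, not results from this experiment.

```python
import numpy as np

# Hypothetical observed values and predictions, for illustration only.
y_true = np.array([3.0, 5.0, 7.0, 9.0])
y_hat = np.array([2.0, 5.0, 8.0, 9.0])

# MSE: mean of the squared residuals.
mse = np.mean((y_true - y_hat) ** 2)

# R^2: 1 - (residual sum of squares) / (total sum of squares around the mean).
ss_res = np.sum((y_true - y_hat) ** 2)
ss_tot = np.sum((y_true - y_true.mean()) ** 2)
r2 = 1.0 - ss_res / ss_tot

print(mse, r2)
```

For these numbers the residuals are 1, 0, -1 and 0, so the MSE is 0.5; the total sum of squares is 20, so R² is 1 - 2/20 = 0.9.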
4.10. Prediction with Higher-Degree Polynomial Regression
Now we try fitting with polynomials of higher degree.
# Try polynomials of different degrees
degrees = [1, 2, 3, 4, 5]
colors = ['blue', 'red', 'green', 'orange', 'purple']
plt.figure(figsize=(12, 8))
plt.scatter(x, y, label="Original data", s=100, alpha=0.7)
x_plot = np.linspace(0, 80, 100).reshape(-1, 1)
for i, degree in enumerate(degrees):
    model = make_pipeline(PolynomialFeatures(degree=degree, include_bias=False), LinearRegression())
    model.fit(np.array(x).reshape(-1, 1), np.array(y))
    y_pred = model.predict(x_plot)
    # Compute the R² score on the training data
    y_pred_train = model.predict(np.array(x).reshape(-1, 1))
    r2 = r2_score(y, y_pred_train)
    plt.plot(x_plot, y_pred, color=colors[i], label=f"Degree {degree} (R²={r2:.3f})")
plt.legend()
plt.title("Polynomial regression at different degrees")
plt.xlabel("x")
plt.ylabel("y")
plt.grid(True, alpha=0.3)
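The R² values in the legend are computed on the same points the models were fitted to. As a rough sketch of why that number alone cannot guide the choice of degree, the code below refits the sample data from section 4.4 with `np.polyfit` (a stand-in for the scikit-learn pipeline, chosen here only to keep the example short) and records the training R² for degrees 1 to 5. Because each higher-degree model contains the lower-degree ones as special cases, the training R² should not decrease as the degree grows.

```python
import numpy as np

# Sample data from section 4.4.
x = np.array([4, 8, 12, 25, 32, 43, 58, 63, 69, 79], dtype=float)
y = np.array([20, 33, 50, 56, 42, 31, 33, 46, 65, 75], dtype=float)


def train_r2(degree):
    """R^2 of a least-squares polynomial, scored on the data it was fitted to."""
    pred = np.polyval(np.polyfit(x, y, degree), x)
    ss_res = np.sum((y - pred) ** 2)
    ss_tot = np.sum((y - y.mean()) ** 2)
    return 1.0 - ss_res / ss_tot


r2_by_degree = [train_r2(d) for d in range(1, 6)]
for d, r2 in zip(range(1, 6), r2_by_degree):
    print(f"degree {d}: training R^2 = {r2:.3f}")
```

A score that can only go up with model complexity says little about how the model behaves on unseen data, which is why a held-out or cross-validated error is the more informative quantity.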
4.11. Choosing the Degree for Polynomial Regression Prediction
Next, we use cross-validation to choose the best polynomial degree.
from sklearn.model_selection import cross_val_score
# Use cross-validation to choose the best degree
degrees = range(1, 8)
cv_scores = []
for degree in degrees:
    model = make_pipeline(PolynomialFeatures(degree=degree, include_bias=False), LinearRegression())
    scores = cross_val_score(model, np.array(x).reshape(-1, 1), np.array(y), cv=5, scoring='neg_mean_squared_error')
    # The scorer returns the negative MSE, so flip the sign
    cv_scores.append(-scores.mean())
# Plot the cross-validation results
plt.figure(figsize=(10, 6))
plt.plot(degrees, cv_scores, 'bo-', linewidth=2, markersize=8)
plt.xlabel('Polynomial degree')
plt.ylabel('Cross-validation MSE')
plt.title('Cross-validation MSE by polynomial degree')
plt.grid(True, alpha=0.3)
# Find the best degree
best_degree = degrees[np.argmin(cv_scores)]
print(f"Best polynomial degree: {best_degree}")
print(f"Cross-validation MSE at that degree: {min(cv_scores):.4f}")
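To show what k-fold cross-validation does under the hood, here is a hand-rolled sketch in plain NumPy. It is not equivalent to the scikit-learn call above: it is an illustration under its own assumptions, namely that the folds are shuffled with a fixed seed (an integer `cv` in scikit-learn uses contiguous, unshuffled folds for regression), that `np.polyfit` replaces the pipeline, and that only degrees 1 to 4 are tried. The helper `kfold_mse` is a hypothetical name introduced for this example.

```python
import numpy as np

# Sample data from section 4.4.
x = np.array([4, 8, 12, 25, 32, 43, 58, 63, 69, 79], dtype=float)
y = np.array([20, 33, 50, 56, 42, 31, 33, 46, 65, 75], dtype=float)


def kfold_mse(x, y, degree, k=5, seed=0):
    """Mean validation MSE of a polynomial fit over k shuffled folds."""
    rng = np.random.default_rng(seed)
    folds = np.array_split(rng.permutation(len(x)), k)
    fold_mse = []
    for val_idx in folds:
        train_mask = np.ones(len(x), dtype=bool)
        train_mask[val_idx] = False
        # Fit on the training folds only, then score on the held-out fold.
        coeffs = np.polyfit(x[train_mask], y[train_mask], degree)
        pred = np.polyval(coeffs, x[val_idx])
        fold_mse.append(np.mean((pred - y[val_idx]) ** 2))
    return float(np.mean(fold_mse))


scores = {d: kfold_mse(x, y, d) for d in range(1, 5)}
best = min(scores, key=scores.get)
print(scores)
print("degree with lowest CV MSE:", best)
```

With only 10 points, each fold holds 2 samples, so the estimate is noisy and the selected degree can change with the seed or the fold layout.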
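4.12. Summary
In this experiment we covered the basic concepts of polynomial regression and how to implement it:
Polynomial regression basics: polynomial regression extends linear regression by introducing higher-order terms to fit nonlinear data.
Implementation:
- Manual fitting with scipy.optimize.leastsq
- PolynomialFeatures and LinearRegression from scikit-learn
Model evaluation: assess model performance with metrics such as MSE and R².
Degree selection: choose the polynomial degree by cross-validation to avoid overfitting.
Use cases: polynomial regression suits data with nonlinear relationships, such as stock market fluctuations and traffic flow.
Polynomial regression is an important method in machine learning that offers a simple and effective way to handle nonlinear data. In practice, the polynomial degree should be chosen to suit the data, with model selection done through methods such as cross-validation.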
4. Polynomial Regression: Implementation and Application
# Load the sample data
x = [4, 8, 12, 25, 32, 43, 58, 63, 69, 79]
y = [20, 33, 50, 56, 42, 31, 33, 46, 65, 75]
import seaborn as sns
from matplotlib import pyplot as plt
%matplotlib inline
# custom_params = {"figure.figsize": (6, 4),
# "font.sans-serif":"Arial Unicode MS",
# 'axes.unicode_minus': False}
# sns.set_theme(style="ticks", font_scale=0.7, rc=custom_params)
sns.set_style("ticks")
plt.scatter(x, y)
import pandas as pd
url = 'https://raw.githubusercontent.com/GeostatsGuy/GeoDataSets/master/1D_Porosity.csv'
df = pd.read_csv(url)
df.head()
# Define the values of x and y
x = df["Depth"]
y = df["Nporosity"]
# Plot
plt.plot(x, y, "r")
plt.scatter(x, y)
# First split the DataFrame into a training set and a test set
train_df = df[: int(len(df) * 0.7)]
test_df = df[int(len(df) * 0.7) :]
# Define the independent and dependent variables used for training and testing
X_train = train_df["Depth"].values
y_train = train_df["Nporosity"].values
X_test = test_df["Depth"].values
y_test = test_df["Nporosity"].values
from sklearn.linear_model import LinearRegression
# Build the linear regression model
model = LinearRegression()
model.fit(X_train.reshape(len(X_train), 1), y_train.reshape(len(y_train), 1))
results = model.predict(X_test.reshape(len(X_test), 1))
results  # Predictions of the linear regression model on the test set
array([[0.26222222],
       [0.29235085],
       [0.32247947],
       [0.3526081 ],
       [0.38273673],
       [0.41286535],
       [0.44299398],
       [0.47312261],
       [0.50325123],
       [0.53337986],
       [0.56350848],
       [0.59363711]])
from sklearn.metrics import mean_absolute_error
from sklearn.metrics import mean_squared_error
print("Linear regression mean absolute error: ", mean_absolute_error(y_test, results.flatten()))
print("Linear regression mean squared error: ", mean_squared_error(y_test, results.flatten()))
Linear regression mean absolute error: 0.668851943076081
Linear regression mean squared error: 0.7262482710699257
from sklearn.preprocessing import PolynomialFeatures
# Degree-2 polynomial feature matrices
poly_features_2 = PolynomialFeatures(degree=2, include_bias=False)
poly_X_train_2 = poly_features_2.fit_transform(X_train.reshape(len(X_train), 1))
# Reuse the transformer fitted on the training set for the test set
poly_X_test_2 = poly_features_2.transform(X_test.reshape(len(X_test), 1))
# Train the degree-2 polynomial regression model and predict
model = LinearRegression()
model.fit(poly_X_train_2, y_train.reshape(len(X_train), 1))  # Train the model
results_2 = model.predict(poly_X_test_2)  # Predictions
results_2.flatten()  # Show the flattened predictions
array([-1.41682234, -1.73408225, -2.07450139, -2.43807977, -2.82481738,
       -3.23471422, -3.6677703 , -4.12398562, -4.60336017, -5.10589396,
       -5.63158699, -6.18043924])
print("Degree-2 polynomial regression mean absolute error: ", mean_absolute_error(y_test, results_2.flatten()))
print("Degree-2 polynomial mean squared error: ", mean_squared_error(y_test, results_2.flatten()))
Degree-2 polynomial regression mean absolute error: 4.068004469987231
Degree-2 polynomial mean squared error: 20.936635430648945
from sklearn.pipeline import make_pipeline
X_train = X_train.reshape(len(X_train), 1)
X_test = X_test.reshape(len(X_test), 1)
y_train = y_train.reshape(len(y_train), 1)
for m in [3, 4, 5]:
    model = make_pipeline(PolynomialFeatures(m, include_bias=False), LinearRegression())
    model.fit(X_train, y_train)
    pre_y = model.predict(X_test)
    print("Degree-{} polynomial regression mean absolute error: ".format(m), mean_absolute_error(y_test, pre_y.flatten()))
    print("Degree-{} polynomial mean squared error: ".format(m), mean_squared_error(y_test, pre_y.flatten()))
    print("---")
Degree-3 polynomial regression mean absolute error: 5.129257302352646
Degree-3 polynomial mean squared error: 33.531553828627324
---
Degree-4 polynomial regression mean absolute error: 2.3082420473069822
Degree-4 polynomial mean squared error: 6.639506434094571
---
Degree-5 polynomial regression mean absolute error: 21.562146337862032
Degree-5 polynomial mean squared error: 951.7450539295969
---
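One plausible reason the test errors above grow so quickly is the way the data was split: the first 70% of the rows are used for training and the last 30% for testing. If the rows are ordered by Depth (an assumption; the text does not say so), every test point lies outside the range the model was fitted on, and polynomials tend to diverge when extrapolated. The sketch below reproduces the effect on the small sample data from section 4.4 rather than on the porosity data, using `np.polyfit` as a stand-in for the pipeline.

```python
import numpy as np

# Sample data from section 4.4 (not the porosity dataset).
x = np.array([4, 8, 12, 25, 32, 43, 58, 63, 69, 79], dtype=float)
y = np.array([20, 33, 50, 56, 42, 31, 33, 46, 65, 75], dtype=float)

# Sequential 70/30 split: all test points lie to the right of the training range.
n_train = len(x) * 7 // 10
x_train, y_train = x[:n_train], y[:n_train]
x_test, y_test = x[n_train:], y[n_train:]

# Degree-2 fit on the training part only.
coeffs = np.polyfit(x_train, y_train, 2)
train_mse = np.mean((np.polyval(coeffs, x_train) - y_train) ** 2)
test_mse = np.mean((np.polyval(coeffs, x_test) - y_test) ** 2)

print("train MSE:", train_mse)
print("test MSE (extrapolation):", test_mse)
```

On this data the extrapolated test error is far larger than the training error. A shuffled split would measure interpolation instead, which answers a different question, so which split is appropriate depends on how the model will be used.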
4.11. Choosing the Degree for Polynomial Regression Prediction
# Compute the MSE of degree-m polynomial regression predictions and plot it
mse = []  # MSE for each maximum degree
m = 1  # Initial value of m
m_max = 10  # Maximum degree
while m <= m_max:
    model = make_pipeline(PolynomialFeatures(m, include_bias=False), LinearRegression())
    model.fit(X_train, y_train)  # Train the model
    pre_y = model.predict(X_test)  # Test the model
    mse.append(mean_squared_error(y_test, pre_y.flatten()))  # Compute the MSE
    m = m + 1
print("MSE results: ", mse)
# Plot
plt.plot([i for i in range(1, m_max + 1)], mse, "r")
plt.scatter([i for i in range(1, m_max + 1)], mse)
# Add the title and axis labels
plt.title("MSE of m degree of polynomial regression")
plt.xlabel("m")
plt.ylabel("MSE")
MSE results: [0.7262482710699257, 20.936635430648945, 33.531553828627324, 6.639506434094571, 951.7450539295969, 5951.756389947962, 164169.2648002583, 1124239.210057049, 8501435.234869642, 125938566.68847854]
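The printed list can be turned into a degree choice with a single `argmin`. The snippet below is a minimal sketch that copies the ten test-set MSE values from the output above; the variable names `degrees` and `best_degree` are introduced here for illustration.

```python
import numpy as np

# Test-set MSE values for degrees 1..10, copied from the printed result above.
mse = [0.7262482710699257, 20.936635430648945, 33.531553828627324,
       6.639506434094571, 951.7450539295969, 5951.756389947962,
       164169.2648002583, 1124239.210057049, 8501435.234869642,
       125938566.68847854]

degrees = np.arange(1, len(mse) + 1)
best_degree = int(degrees[np.argmin(mse)])
print("degree with lowest test MSE:", best_degree)  # 1

# The values span many orders of magnitude, so log10 is easier to compare or plot.
log_mse = np.log10(mse)
print(np.round(log_mse, 2))
```

For this particular split the plain linear model (degree 1) has the lowest test MSE, and the error grows by several orders of magnitude at high degrees, so a logarithmic y-axis may make the plot easier to read.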
