非线性模型
大家好,我是风风。上一次我们看了机器学习中三种基于线性模型的特征筛选方法,今天我们来聊一聊非线性模型多重共线性的问题。
线性模型应该是我们学过的最简单的机器学习模型了,简单的内容例如我们初中学过的一元线性回归,较难的例如我们在生信文章中经常看到的各种score,都是属于线性模型的范畴。
当我们做多了线性模型,或者线性模型的结果不好解释,亦或者线性模型的结果不符合我们预期的时候,我们往往会产生一种疑问:数据之间的关系就一定是线性的吗?不一定吧!数据之间的关系应该可以是线性相关,也可以是非线性相关才对。
有了这个疑问,势必会产生第二个问题:如果数据是非线性相关,那应该用什么样的模型呢?
等到掌握了非线性相关的模型之后,随之而来的也就是最后一个终极问题:非线性相关模型的模型价值难道就一定比线性相关模型的模型价值高吗?有没什么方法或者指标可以比较这些模型之间的优劣呢?
 带着这些问题,我们来一起看看今天的内容:
非线性模型
理想状态下,我们一般希望结局变量和自变量之间的关系可以是线性相关关系,这样在模型利用的时候更加方便快捷,就像武侠小说里的招式一样——大道至简,越简单的招式往往破绽就越少。但是实际上很多时候,因变量和自变量之间的真实关系更是非线性相关的关系,因此我们就需要有非线性模型。常用的非线性模型有3种:多项式回归、样条回归和广义相加模型GAM。接下来,我们将使用同一份数据构建不同模型,并对比一下这三种模型和线性模型之间的优劣,以选择最适合分析数据的模型用于文章发表。不同模型之间的简单比较可以使用RMSE值和R2值。RMSE值表示模型的预测误差,也就是观察结果值(实际结果)与预测结果值之间的平均差值;而R2值则表示观察结果值和预测结果值之间的相关性的平方,一般我们认为:RMSE值越低、R2值越高的模型,是更好的模型。
线性回归模型
既然是要比较线性模型和非线性模型之间的优劣,我们首先要对数据构建线性回归模型以作为比较对象,线性模型在前面的推文中有专门讲过,大家可以返回看前面的推文,这里进行简单构建:
rm(list = ls())# 安装R包# BiocManager::install("tidyverse")# BiocManager::install("caret")# BiocManager::install("splines")# BiocManager::install("mgcv")# BiocManager::install("MASS")# 加载R包library(tidyverse)## -- Attaching packages --------------------------------------- tidyverse 1.3.1 --## v ggplot2 3.3.3 v purrr 0.3.4## v tibble 3.1.2 v dplyr 1.0.6## v tidyr 1.1.3 v stringr 1.4.0## v readr 1.4.0 v forcats 0.5.1## -- Conflicts ------------------------------------------ tidyverse_conflicts() --## x dplyr::filter() masks stats::filter()## x dplyr::lag() masks stats::lag()library(caret)## 载入需要的程辑包:lattice## ## 载入程辑包:'caret'## The following object is masked from 'package:purrr':## ## liftlibrary(splines)library(mgcv)## 载入需要的程辑包:nlme## ## 载入程辑包:'nlme'## The following object is masked from 'package:dplyr':## ## collapse## This is mgcv 1.8-35. For overview type 'help("mgcv-package")'.library(MASS)## ## 载入程辑包:'MASS'## The following object is masked from 'package:dplyr':## ## select# 我们将使用MASS包中自带的波士顿数据集进行模型构建,为了比较不同模型之间的优劣,所有的模型和分析都使用这一份数据data("Boston", package = "MASS")str(Boston)## 'data.frame': 506 obs. of 14 variables:## $ crim : num 0.00632 0.02731 0.02729 0.03237 0.06905 ...## $ zn : num 18 0 0 0 0 0 12.5 12.5 12.5 12.5 ...## $ indus : num 2.31 7.07 7.07 2.18 2.18 2.18 7.87 7.87 7.87 7.87 ...## $ chas : int 0 0 0 0 0 0 0 0 0 0 ...## $ nox : num 0.538 0.469 0.469 0.458 0.458 0.458 0.524 0.524 0.524 0.524 ...## $ rm : num 6.58 6.42 7.18 7 7.15 ...## $ age : num 65.2 78.9 61.1 45.8 54.2 58.7 66.6 96.1 100 85.9 ...## $ dis : num 4.09 4.97 4.97 6.06 6.06 ...## $ rad : int 1 2 2 3 3 3 5 5 5 5 ...## $ tax : num 296 242 242 222 222 222 311 311 311 311 ...## $ ptratio: num 15.3 17.8 17.8 18.7 18.7 18.7 15.2 15.2 15.2 15.2 ...## $ black : num 397 397 393 395 397 ...## $ lstat : num 4.98 9.14 4.03 2.94 5.33 ...## $ medv : num 24 21.6 34.7 33.4 36.2 28.7 22.9 27.1 16.5 18.9 ...head(Boston)## crim zn indus chas nox rm age dis rad tax ptratio black lstat## 1 0.00632 18 2.31 0 0.538 6.575 65.2 4.0900 1 296 15.3 396.90 4.98## 2 0.02731 0 7.07 0 0.469 6.421 78.9 4.9671 2 242 17.8 396.90 9.14## 3 0.02729 0 7.07 0 0.469 7.185 61.1 4.9671 2 242 17.8 392.83 4.03## 4 0.03237 0 2.18 0 0.458 6.998 45.8 6.0622 3 222 18.7 394.63 2.94## 5 0.06905 0 2.18 0 0.458 7.147 54.2 6.0622 3 222 18.7 396.90 5.33## 6 0.02985 0 2.18 0 0.458 6.430 58.7 6.0622 3 222 18.7 394.12 5.21## medv## 1 24.0## 2 21.6## 3 34.7## 4 33.4## 5 36.2## 6 28.7# 将随机将数据分成训练集(80%用于构建预测模型)和验证集(20%用于评估模型)set.seed(123456) # 设置随机种子,方便重复caret包随机切分数据的结果training_samples <- Boston$medv %>% # 按照medv值进行随机切割 createDataPartition(p = 0.8, # 切割比例为0.8 times = 1, list = FALSE)train_data <- Boston[training_samples, ] # 训练集数据test_data <- Boston[- training_samples, ] # 验证集数据# 我们将使用模型探索medv和lstat两个变量之间的关系,在建模之前,先简单探索一下两个变量之间的关系(训练集)ggplot(data = train_data, aes(x = lstat, y = medv) ) + geom_point(shape = 21, fill = "steelblue", colour = "black", size = 4) + stat_smooth(fill = "gray70", color = "red") + ggtitle("Correlation between lstat and medv") + theme(plot.title = element_text(hjust = 0.50, color="black", size = 15))## `geom_smooth()` using method = 'loess' and formula 'y ~ x'
# 从散点图的结果可以发现lstat和medv两个变量之间存在非线性关系# 接下来我们先构建线性模型,按照我们前面讲过的线性模型的公式,以medv为因变量,lstat为自变量,那么标准线性回归模型方程应该是: medv = b0 + b1×lstat,接下来我们计算线性回归模型:# 使用训练集构建模型model <- lm(medv ~ lstat, data = train_data)model## ## Call:## lm(formula = medv ~ lstat, data = train_data)## ## Coefficients:## (Intercept) lstat ## 34.3155 -0.9381# 使用验证集验证模型predictions <- model %>% predict(test_data)# 计算RMSE值和R2值value1 <- data.frame( RMSE = RMSE(predictions, test_data$medv), R2 = R2(predictions, test_data$medv))value1## RMSE R2## 1 6.681979 0.5384077# 模型结果可视化ggplot(data = train_data, aes(x = lstat, y = medv) ) + geom_point(shape = 21, fill = "steelblue", colour = "black", size = 4) + stat_smooth(method = lm, formula = y ~ x, fill = "gray70", color = "red") + ggtitle("Correlation between lstat and medv") + theme(plot.title = element_text(hjust = 0.50, color="black", size = 15))
# 有了线性回归模型的结果,接下来我们构建非线性回归模型
多项式回归
多项式回归,听起来好像很高大上,但是如果举个例子的话,你马上就会觉得不过如此了,还记得以前学过的一元二次方程ax²+bx+c=0吗?这个就是最简单的多项式回归方程,原来自变量是x,现在自变量是x的多次项,就是多项式回归,对应到这里的例子,那么方程就应该是medv = b0 + b1 × lstat + b2 × lstat2,接下来我们看看实现过程:
# 多项式回归方程# 如果把x2看成是自变量,那么多项式回归其实也可以看成是线性回归,这里也是使用lm,这里提供代码的两种写法# 方法一lm(medv ~ lstat + I(lstat^2), data = train_data)## ## Call:## lm(formula = medv ~ lstat + I(lstat^2), data = train_data)## ## Coefficients:## (Intercept) lstat I(lstat^2) ## 42.38921 -2.29038 0.04284# 方法二lm(medv ~ poly(lstat, 2, raw = TRUE), data = train_data) %>% summary()## ## Call:## lm(formula = medv ~ poly(lstat, 2, raw = TRUE), data = train_data)## ## Residuals:## Min 1Q Median 3Q Max ## -15.101 -3.610 -0.311 2.277 24.572 ## ## Coefficients:## Estimate Std. Error t value Pr(>|t|) ## (Intercept) 42.389206 0.940410 45.08 <2e-16 ***## poly(lstat, 2, raw = TRUE)1 -2.290376 0.133691 -17.13 <2e-16 ***## poly(lstat, 2, raw = TRUE)2 0.042835 0.004064 10.54 <2e-16 ***## ---## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1## ## Residual standard error: 5.407 on 404 degrees of freedom## Multiple R-squared: 0.6444, Adjusted R-squared: 0.6427 ## F-statistic: 366.1 on 2 and 404 DF, p-value: < 2.2e-16# 上面例子用的是2次项,如果是多次项的话,只需要修改一下lm(medv ~ poly(lstat, 6, raw = TRUE), data = train_data) %>% summary()## ## Call:## lm(formula = medv ~ poly(lstat, 6, raw = TRUE), data = train_data)## ## Residuals:## Min 1Q Median 3Q Max ## -13.8787 -3.0730 -0.5208 2.0010 26.2327 ## ## Coefficients:## Estimate Std. Error t value Pr(>|t|) ## (Intercept) 7.227e+01 5.861e+00 12.331 < 2e-16 ***## poly(lstat, 6, raw = TRUE)1 -1.531e+01 3.121e+00 -4.907 1.35e-06 ***## poly(lstat, 6, raw = TRUE)2 1.994e+00 6.032e-01 3.306 0.00103 ** ## poly(lstat, 6, raw = TRUE)3 -1.367e-01 5.511e-02 -2.481 0.01351 * ## poly(lstat, 6, raw = TRUE)4 4.894e-03 2.557e-03 1.914 0.05638 . ## poly(lstat, 6, raw = TRUE)5 -8.648e-05 5.809e-05 -1.489 0.13734 ## poly(lstat, 6, raw = TRUE)6 5.948e-07 5.109e-07 1.164 0.24508 ## ---## Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1## ## Residual standard error: 5.111 on 400 degrees of freedom## Multiple R-squared: 0.6854, Adjusted R-squared: 0.6807 ## F-statistic: 145.3 on 6 and 400 DF, p-value: < 2.2e-16# 由上面的6次项分析结果的p值可以看出,超过三阶的多项式项结果并不显著。那么,只需创建第3个多项式回归模型即可,如下:# 使用训练集构建模型model <- lm(medv ~ poly(lstat, 3, raw = TRUE), data = train_data)model## ## Call:## lm(formula = medv ~ poly(lstat, 3, raw = TRUE), data = train_data)## ## Coefficients:## (Intercept) poly(lstat, 3, raw = TRUE)1 ## 47.522040 -3.654410 ## poly(lstat, 3, raw = TRUE)2 poly(lstat, 3, raw = TRUE)3 ## 0.136304 -0.001775# 使用验证集验证模型predictions <- model %>% predict(test_data)# 计算RMSE值和R2值value2 <- data.frame( RMSE = RMSE(predictions, test_data$medv), R2 = R2(predictions, test_data$medv))value2## RMSE R2## 1 5.793502 0.6621416# 模型结果可视化ggplot(data = train_data, aes(x = lstat, y = medv) ) + geom_point(shape = 21, fill = "steelblue", colour = "black", size = 4) + stat_smooth(method = lm, formula = y ~ poly(x, 5, raw = TRUE), fill = "gray70", color = "red") + ggtitle("Correlation between lstat and medv") + theme(plot.title = element_text(hjust = 0.50, color="black", size = 15))
样条回归分析
什么是样条回归分析呢?
多项式回归只能够获得非线性关系中一定数量的曲率,也就是相关性,而另一种非线性回归模型的构建方法就是样条回归,样条回归提供了一种在固定点之间平滑插值的方法,称为结点,然后在结点之间计算多项式回归。换句话说,样条回归也可以看作是一系列的多项式回归串在一起。接下来我们看看样条回归在R中的实现过程:
# 计算节点knots <- quantile(train_data$lstat, p = c(0.25, 0.5, 0.75))knots## 25% 50% 75% ## 6.965 11.320 17.095# 使用训练集构建模型knots <- quantile(train_data$lstat, p = c(0.25, 0.5, 0.75))model <- lm (medv ~ bs(lstat, knots = knots), data = train_data)model## ## Call:## lm(formula = medv ~ bs(lstat, knots = knots), data = train_data)## ## Coefficients:## (Intercept) bs(lstat, knots = knots)1 ## 49.57 -13.19 ## bs(lstat, knots = knots)2 bs(lstat, knots = knots)3 ## -26.03 -26.91 ## bs(lstat, knots = knots)4 bs(lstat, knots = knots)5 ## -40.29 -37.03 ## bs(lstat, knots = knots)6 ## -37.50# 使用验证集验证模型predictions <- model %>% predict(test_data)# 计算RMSE值和R2值value3 <- data.frame( RMSE = RMSE(predictions, test_data$medv), R2 = R2(predictions, test_data$medv))# 模型结果可视化ggplot(data = train_data, aes(x = lstat, y = medv) ) + geom_point(shape = 21, fill = "steelblue", colour = "black", size = 4) + stat_smooth(method = lm, formula = y ~ splines::bs(x, df = 3), fill = "gray70", color = "red") + ggtitle("Correlation between lstat and medv") + theme(plot.title = element_text(hjust = 0.50, color="black", size = 15))
# 由可视化结果可以看出,样条回归拟合的曲线相对多项式回归来说比较平滑
广义相加模型(Generalized additive models)
前面我们说了,一旦我们发现数据之间呈现的关系是非线性关系,样条回归需要计算这种非线性关系中的节点,然而到底多少个节点最合适这个我们是不知道的,那有没有一种可以自动拟合样条回归的方法或者说模型呢?当然有,那就是广义相加模型GAM。我们使用mgcv包构建广义相加模型:
# 使用训练集构建模型model <- gam(medv ~ s(lstat), data = train_data)model## ## Family: gaussian ## Link function: identity ## ## Formula:## medv ~ s(lstat)## ## Estimated degrees of freedom:## 7.24 total = 8.24 ## ## GCV score: 26.5689# 使用验证集验证模型predictions <- model %>% predict(test_data)# 计算RMSE值和R2值value4 <- data.frame( RMSE = RMSE(predictions, test_data$medv), R2 = R2(predictions, test_data$medv))# 模型结果可视化ggplot(data = train_data, aes(x = lstat, y = medv) ) + geom_point(shape = 21, fill = "steelblue", colour = "black", size = 4) + stat_smooth(method = gam, formula = y ~ s(x), fill = "gray70", color = "red") + ggtitle("Correlation between lstat and medv") + theme(plot.title = element_text(hjust = 0.50, color="black", size = 15))
# 最后,我们再比较下4个模型之间的优劣,前面我们已经分别计算了每个模型的RMSE值和R2值value1; value2; value3; value4## RMSE R2## 1 6.681979 0.5384077## RMSE R2## 1 5.793502 0.6621416## RMSE R2## 1 5.628846 0.6879317## RMSE R2## 1 5.632479 0.6870953# 从结果看,3个非线性模型的结果都比线性模型的结果好,这也符合我们一开始初步探索的数据特征
好了,本期的内容到此结束!在本期中,我们一起探讨了非线性模型的三种方法,包括:多项式回归、样条回归和广义相加模型GAM,并分别计算了每个模型的RMSE值和R2值,来比较各个模型之间的优劣性。其实我们在构建模型的时候,很少能够一次性构建出符合我们预期的模型,也可能构建出来的模型结果并不好,我们可以多尝试几种不同的模型,比较各种模型的优劣,看看哪种模型最适合我们的数据,以此来阐述数据之间的特征并发表论文
主页回复“feng 58”获取代码和数据,我们下期见!
单细胞分析专栏传送门
碎碎念专栏传送门(完结)
风之美图系列传送门(完结)
END

撰文丨风   风
排版丨四金兄
主编丨小雪球
欢迎大家关注解螺旋生信频道-挑圈联靠公号~

继续阅读
阅读原文