FE01_OneHot: Scala Application
One-hot encoding is a common way to handle categorical variables. How is it applied in Scala (Spark ML)? And how are values that appear in the test set but not in the training set handled?
1 Data Construction
```scala
import org.apache.spark.ml.{Model, Pipeline, PipelineModel, PipelineStage}
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.{OneHotEncoderEstimator, StringIndexer}
import org.apache.spark.sql.functions._
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.{DataFrame, Row, SparkSession}

val builder = SparkSession
  .builder()
  .appName("LR")
  .config("spark.executor.heartbeatInterval", "60s")
  .config("spark.network.timeout", "120s")
  .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
  .config("spark.kryoserializer.buffer.max", "512m")
  .config("spark.dynamicAllocation.enabled", false)
  .config("spark.sql.inMemoryColumnarStorage.compressed", true)
  .config("spark.sql.inMemoryColumnarStorage.batchSize", 10000)
  .config("spark.sql.broadcastTimeout", 600)
  .config("spark.sql.autoBroadcastJoinThreshold", -1)
  .config("spark.sql.crossJoin.enabled", true)
  .master("local[*]")

val spark = builder.getOrCreate()
spark.sparkContext.setLogLevel("ERROR")

import spark.implicits._
```
```
import org.apache.spark.ml.{Model, Pipeline, PipelineModel, PipelineStage}
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.{OneHotEncoderEstimator, StringIndexer}
import org.apache.spark.sql.functions._
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
builder: org.apache.spark.sql.SparkSession.Builder = org.apache.spark.sql.SparkSession$Builder@…
spark: org.apache.spark.sql.SparkSession = org.apache.spark.sql.SparkSession@…
import spark.implicits._
```
```scala
println("------------- dfTrain -------------")
var dfTrain = Seq(
  (1, 5.1, "a", "hello", 0.2, 0),
  (2, 4.9, "b", null, 0.2, 1),
  (3, 4.7, "b", "hi", 0.2, 0),
  (4, 4.6, "c", "hello", 0.2, 1)
).toDF("id", "x1", "x2", "x3", "x4", "label")
dfTrain.show()

println("------------- dfTest -------------")
var dfTest = Seq(
  (1, 5.1, "a", "hello", 0.2, 0),
  (2, 4.9, "b", "no", 0.2, 1),
  (3, 4.7, "a", "yes", 0.2, 0),
  (4, 4.6, "d", "hello", 0.2, 1)
).toDF("id", "x1", "x2", "x3", "x4", "label")
// The test set is just a tweaked copy of the training set; it is only used for testing
dfTest.show()
```
```
------------- dfTrain -------------
+---+---+---+-----+---+-----+
| id| x1| x2|   x3| x4|label|
+---+---+---+-----+---+-----+
|  1|5.1|  a|hello|0.2|    0|
|  2|4.9|  b| null|0.2|    1|
|  3|4.7|  b|   hi|0.2|    0|
|  4|4.6|  c|hello|0.2|    1|
+---+---+---+-----+---+-----+
------------- dfTest -------------
+---+---+---+-----+---+-----+
| id| x1| x2|   x3| x4|label|
+---+---+---+-----+---+-----+
|  1|5.1|  a|hello|0.2|    0|
|  2|4.9|  b|   no|0.2|    1|
|  3|4.7|  a|  yes|0.2|    0|
|  4|4.6|  d|hello|0.2|    1|
+---+---+---+-----+---+-----+
dfTrain: org.apache.spark.sql.DataFrame = [id: int, x1: double ... 4 more fields]
dfTest: org.apache.spark.sql.DataFrame = [id: int, x1: double ... 4 more fields]
```
Note:
- dfTest contains values in its categorical features that never appear in dfTrain
- namely d in x2, and no / yes in x3
2 StringIndexer Transformation
StringIndexer encodes a string column into label indices ranging from 0 to the number of labels (here, the labels found in dfTrain). Indices are assigned by label frequency: more frequent labels are encoded first, so the most frequent label gets index 0. If the input column is numeric, it is first cast to string and then indexed.

One-hot encoding requires this StringIndexer transformation as a first step.
```scala
val columns = Array("x2", "x3")
val indexers: Array[PipelineStage] = columns.map { colName =>
  new StringIndexer()
    .setInputCol(colName)
    .setOutputCol(colName + "_indexed")
    .setHandleInvalid("keep")
}
new Pipeline().setStages(indexers).fit(dfTrain).transform(dfTrain).show()
```
```
+---+---+---+-----+---+-----+----------+----------+
| id| x1| x2|   x3| x4|label|x2_indexed|x3_indexed|
+---+---+---+-----+---+-----+----------+----------+
|  1|5.1|  a|hello|0.2|    0|       1.0|       0.0|
|  2|4.9|  b| null|0.2|    1|       0.0|       2.0|
|  3|4.7|  b|   hi|0.2|    0|       0.0|       1.0|
|  4|4.6|  c|hello|0.2|    1|       2.0|       0.0|
+---+---+---+-----+---+-----+----------+----------+
columns: Array[String] = Array(x2, x3)
indexers: Array[org.apache.spark.ml.PipelineStage] = Array(strIdx_2ab569f4651d, strIdx_4d7196a2a72b)
```
- Unlike the commonly used OneHotEncoderEstimator and VectorAssembler, StringIndexer can only process a single column
- which is why mapping over the column names, as above, is convenient
- Looking at the encoding of x2, the most frequent value b is indexed 0; in x3, hello is indexed 0 (the fitted model's label order can be inspected directly; see the sketch below)
- The missing (null) value in x3 is indexed 2.0, because of setHandleInvalid("keep")
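To verify the frequency-based ordering, you can read the labels off the fitted model. This inspection step is not in the original; a minimal sketch, assuming the dfTrain built above:

```scala
import org.apache.spark.ml.feature.StringIndexerModel

// Fit a single StringIndexer and inspect the learned label order:
// labels are sorted by descending frequency, so "b" (2 rows) comes
// before "a" and "c" (1 row each).
val x2Model: StringIndexerModel = new StringIndexer()
  .setInputCol("x2")
  .setOutputCol("x2_indexed")
  .setHandleInvalid("keep")
  .fit(dfTrain)

println(x2Model.labels.mkString(", "))  // prints: b, a, c
```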
3 OneHotEncoderEstimator
The indexed columns are now transformed with one-hot encoding; the result (dfTrain1) is shown below.
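The listing that produced this output did not survive; here is a minimal sketch of what it presumably looked like, reusing columns and indexers from section 2 (the name dfTrain1 comes from the output below, everything else is reconstruction):

```scala
// Index x2/x3, then one-hot encode the index columns.
val onehoter = new OneHotEncoderEstimator()
  .setInputCols(columns.map(_ + "_indexed"))
  .setOutputCols(columns.map(_ + "_onehot"))

val dfTrain1 = new Pipeline()
  .setStages(indexers ++ Array(onehoter))
  .fit(dfTrain)
  .transform(dfTrain)
dfTrain1.show()
```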
```
+---+---+---+-----+---+-----+----------+----------+-------------+-------------+
| id| x1| x2|   x3| x4|label|x2_indexed|x3_indexed|    x2_onehot|    x3_onehot|
+---+---+---+-----+---+-----+----------+----------+-------------+-------------+
|  1|5.1|  a|hello|0.2|    0|       1.0|       0.0|(3,[1],[1.0])|(2,[0],[1.0])|
|  2|4.9|  b| null|0.2|    1|       0.0|       2.0|(3,[0],[1.0])|    (2,[],[])|
|  3|4.7|  b|   hi|0.2|    0|       0.0|       1.0|(3,[0],[1.0])|(2,[1],[1.0])|
|  4|4.6|  c|hello|0.2|    1|       2.0|       0.0|(3,[2],[1.0])|(2,[0],[1.0])|
+---+---+---+-----+---+-----+----------+----------+-------------+-------------+
dfTrain1: org.apache.spark.sql.DataFrame = [id: int, x1: double ... 6 more fields]
```
- Viewed like this the data is not very intuitive yet; we are used to a row-and-column presentation
- In x2_onehot, the 3 says there are 3 values in total (a, b, c); the second part is the index that is set; the third is always 1.0
- In x3_onehot, id=2 is special: both lists are empty ([]), which means null is not counted as a category, so there are only 2 values (indices 0 and 1)

Converting the vectors to arrays makes this easier to read (the conversion code is sketched after these notes):

- x2_onehot expands into three columns, one per x2_indexed value [0, 1, 2]; a column is 1 when its index is hit and 0 otherwise
- For the x3 of id=2, the entries are indeed all 0, i.e. one-hot effectively ignores null values
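The array conversion itself is not shown in the original; a minimal sketch, assuming dfTrain1 from the sketch above:

```scala
import org.apache.spark.ml.linalg.Vector

// Expand each one-hot vector into a plain array so the 0/1 columns are
// visible row by row.
dfTrain1.select("x2_onehot", "x3_onehot").collect().foreach { row =>
  val x2 = row.getAs[Vector](0).toArray.mkString(",")
  val x3 = row.getAs[Vector](1).toArray.mkString(",")
  println(s"x2_onehot=[$x2]  x3_onehot=[$x3]")
}
```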
If the test set contains values that never appear in the training set, all of their one-hot columns are automatically set to 0, exactly the same treatment as null. Let's look at an example:
```scala
import org.apache.spark.ml.linalg.Vector // needed for asInstanceOf[Vector] below

val columns = Array("x2", "x3")
val indexers: Array[PipelineStage] = columns.map { colName =>
  new StringIndexer()
    .setInputCol(colName)
    .setOutputCol(colName + "_indexed")
    .setHandleInvalid("keep")
}
val onehoter = new OneHotEncoderEstimator()
  .setInputCols(columns.map(_ + "_indexed"))
  .setOutputCols(columns.map(x => x + "_onehot"))
val featureCol = Array("x1", "x4") ++ columns.map(x => x + "_onehot")
val assemble = new VectorAssembler()
  .setInputCols(featureCol)
  .setOutputCol("features")
val pipeline = new Pipeline().setStages(indexers ++ Array(onehoter, assemble))
val p1 = pipeline.fit(dfTrain)
p1.transform(dfTest).show()

// Convert the feature column to an array, then print it
println("--------------------------------")
p1.transform(dfTest)
  .select("features")
  .map(x => x(0).asInstanceOf[Vector].toArray)
  .take(4)
  .foreach(x => println(x.mkString(",")))
```
```
+---+---+---+-----+---+-----+----------+----------+-------------+-------------+--------------------+
| id| x1| x2|   x3| x4|label|x2_indexed|x3_indexed|    x2_onehot|    x3_onehot|            features|
+---+---+---+-----+---+-----+----------+----------+-------------+-------------+--------------------+
|  1|5.1|  a|hello|0.2|    0|       1.0|       0.0|(3,[1],[1.0])|(2,[0],[1.0])|[5.1,0.2,0.0,1.0,...|
|  2|4.9|  b|   no|0.2|    1|       0.0|       2.0|(3,[0],[1.0])|    (2,[],[])|(7,[0,1,2],[4.9,0...|
|  3|4.7|  a|  yes|0.2|    0|       1.0|       2.0|(3,[1],[1.0])|    (2,[],[])|(7,[0,1,3],[4.7,0...|
|  4|4.6|  d|hello|0.2|    1|       3.0|       0.0|    (3,[],[])|(2,[0],[1.0])|(7,[0,1,5],[4.6,0...|
+---+---+---+-----+---+-----+----------+----------+-------------+-------------+--------------------+
--------------------------------
5.1,0.2,0.0,1.0,0.0,1.0,0.0
4.9,0.2,1.0,0.0,0.0,0.0,0.0
4.7,0.2,0.0,1.0,0.0,0.0,0.0
4.6,0.2,0.0,0.0,0.0,1.0,0.0
columns: Array[String] = Array(x2, x3)
indexers: Array[org.apache.spark.ml.PipelineStage] = Array(strIdx_e2e4f73059df, strIdx_581f33be39be)
onehoter: org.apache.spark.ml.feature.OneHotEncoderEstimator = oneHotEncoder_f134aaddf69e
featureCol: Array[String] = Array(x1, x4, x2_onehot, x3_onehot)
assemble: org.apache.spark.ml.feature.VectorAssembler = vecAssembler_2ec94bc0ff67
pipeline: org.apache.spark.ml.Pipeline = pipeline_cdcfa4e895ce
p1: org.apache.spark.ml.PipelineModel = pipeline_cdcfa4e895ce
```
Look at the features column first:

- id=1 is stored densely, nothing special; id=2, 3, 4 use the sparse shorthand
- Take id=3: the 7 means 7 elements in total, of which positions [0, 1, 3] hold [4.7, 0.2, 1.0] and everything else is 0
- In x3_onehot for id=3, yes is not present in the training set, so every one-hot column is 0

The plain-array printout above shows concretely how the data is stored.
4 Pipeline
A Pipeline strings the whole processing flow together; here is a simple demo.
```scala
val columns = Array("x2", "x3")
val indexers: Array[PipelineStage] = columns.map { colName =>
  new StringIndexer()
    .setInputCol(colName)
    .setOutputCol(colName + "_indexed")
    .setHandleInvalid("keep")
}
val onehoter = new OneHotEncoderEstimator()
  .setInputCols(columns.map(_ + "_indexed"))
  .setOutputCols(columns.map(x => x + "_onehot"))
val featureCol = Array("x1", "x4") ++ columns.map(x => x + "_onehot")
val assemble = new VectorAssembler()
  .setInputCols(featureCol)
  .setOutputCol("features")
val lr = new LogisticRegression().setMaxIter(10).setRegParam(0.01)
val pipeline = new Pipeline().setStages(indexers ++ Array(onehoter, assemble, lr))
```
```
columns: Array[String] = Array(x2, x3)
indexers: Array[org.apache.spark.ml.PipelineStage] = Array(strIdx_29b5d6cde551, strIdx_1fa35b31e12b)
onehoter: org.apache.spark.ml.feature.OneHotEncoderEstimator = oneHotEncoder_00efb376d02c
featureCol: Array[String] = Array(x1, x4, x2_onehot, x3_onehot)
assemble: org.apache.spark.ml.feature.VectorAssembler = vecAssembler_78da110d20bf
lr: org.apache.spark.ml.classification.LogisticRegression = logreg_f316005bcb94
pipeline: org.apache.spark.ml.Pipeline = pipeline_e4e6f0fbeabe
```
```scala
// Fit the full pipeline on the training set (this line is missing from the
// original listing, but pipeModel is undefined without it)
val pipeModel = pipeline.fit(dfTrain)

println("-------------- Train -------------- ")
pipeModel.transform(dfTrain)
  .select("id", "features", "rawPrediction", "probability", "prediction")
  .show()

println("-------------- Test -------------- ")
pipeModel.transform(dfTest)
  .select("id", "features", "rawPrediction", "probability", "prediction")
  .show()
```
```
-------------- Train --------------
+---+--------------------+--------------------+--------------------+----------+
| id|            features|       rawPrediction|         probability|prediction|
+---+--------------------+--------------------+--------------------+----------+
|  1|[5.1,0.2,0.0,1.0,...|[3.14394279953138...|[0.95866938668285...|       0.0|
|  2|(7,[0,1,2],[4.9,0...|[-2.6533348312718...|[0.06578376621596...|       1.0|
|  3|[4.7,0.2,1.0,0.0,...|[2.86186371148445...|[0.94592870313351...|       0.0|
|  4|[4.6,0.2,0.0,0.0,...|[-3.5623891952210...|[0.02758825464682...|       1.0|
+---+--------------------+--------------------+--------------------+----------+
-------------- Test --------------
+---+--------------------+--------------------+--------------------+----------+
| id|            features|       rawPrediction|         probability|prediction|
+---+--------------------+--------------------+--------------------+----------+
|  1|[5.1,0.2,0.0,1.0,...|[3.14394279953138...|[0.95866938668285...|       0.0|
|  2|(7,[0,1,2],[4.9,0...|[-2.6533348312718...|[0.06578376621596...|       1.0|
|  3|(7,[0,1,3],[4.7,0...|[2.40266618214349...|[0.91703038740784...|       0.0|
|  4|(7,[0,1,5],[4.6,0...|[-0.7089322949842...|[0.32983480687831...|       1.0|
+---+--------------------+--------------------+--------------------+----------+
```
To sum up:
- With one-hot in Scala/Spark (using setHandleInvalid("keep")), values that appear in the test set but not in the training set need no special handling
- A Pipeline makes the whole flow more convenient (and the fitted model can be persisted; see the sketch below)
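A natural follow-up, not in the original: persist the fitted PipelineModel so that serving-time data is encoded with exactly the category mappings learned at training time. A minimal sketch, with a hypothetical path:

```scala
import org.apache.spark.ml.PipelineModel

// Save the fitted pipeline (indexers + encoder + assembler + model) and
// reload it; the reloaded model encodes unseen categories the same way.
pipeModel.write.overwrite().save("/tmp/fe01_onehot_pipeline") // hypothetical path
val reloaded = PipelineModel.load("/tmp/fe01_onehot_pipeline")
reloaded.transform(dfTest).select("id", "prediction").show()
```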
2020-04-02, at Jiulonghu, Jiangning District, Nanjing