当前位置:网站首页>协同过滤进化版本NeuralCF及tensorflow2实现
协同过滤进化版本NeuralCF及tensorflow2实现
2022-06-26 21:27:00 【浪漫的数据分析】
目标:
掌握NeuralCF比传统基于矩阵分解的协同过滤算法的改进点,以及算法的优点和缺点。
内容:
上篇学习了最经典的推荐算法:协同过滤,并基于矩阵分解得到了用户和物品的embeding向量。通过点积可以得到两者的相似度,可进行排序推荐。但传统协同过滤通过直接利用非常稀疏的共现矩阵进行预测的,所以模型的泛化能力非常弱,遇到历史行为非常少的用户,就没法产生准确的推荐结果了。矩阵分解是利用非常简单的内积方式来处理用户向量和物品向量的交叉问题的,所以,它的拟合能力也比较弱。
- 改进点
1、 能不能利用深度学习来改进协同过滤算法呢?包括计算embeding向量,和最后计算物品与用户相似度的点积。
2、 新加坡国立的研究者就使用深度学习网络来改进了传统的协同过滤算法,取名 NeuralCF(神经网络协同过滤)
算法思想:
对比几种算法思想
- 1、矩阵分解算法的原理

就是把共线矩阵分解成两个小矩阵相乘,小矩阵就是embeding向量。 - 2、传统的点积求相似度

- 3、 NeuralCF基本思想

改进点,就是用MLP替代原来的点积操作。 - 4、改进版本-双塔模型

- 用户侧的Layer的输出就当做用户侧embeding。
- 物品侧的Layer的输出就当做物品侧embeding。
- 优点:可以缓存物品、用户侧embeding,在线上推荐时。直接用物品、用户侧embeding计算点积得到相似度。
- 5、改进版本2-双塔模型+MLP
点积操作还是过于简单,不便于发现。采用MLP替换点积操作。
- 6、改进版本6-双塔模型+多特征组合+MLP
embeding只用了用户的id或者共线矩阵产生,忽略了物品和用户的其他固有属性,使用的特征过少,因此,可以加入更多特征一起输入到用户侧和物品侧的多层神经网络。这样可以充分利用特征。
模型代码:
GitHub地址:github源码
例如:
1、 NeuralCF基本模型
# neural cf model arch two. only embedding in each tower, then MLP as the interaction layers
def neural_cf_model_1(feature_inputs, item_feature_columns, user_feature_columns, hidden_units):
item_tower = tf.keras.layers.DenseFeatures(item_feature_columns)(feature_inputs)
user_tower = tf.keras.layers.DenseFeatures(user_feature_columns)(feature_inputs)
interact_layer = tf.keras.layers.concatenate([item_tower, user_tower])
for num_nodes in hidden_units:
interact_layer = tf.keras.layers.Dense(num_nodes, activation='relu')(interact_layer)
output_layer = tf.keras.layers.Dense(1, activation='sigmoid')(interact_layer)
neural_cf_model = tf.keras.Model(feature_inputs, output_layer)
return neural_cf_model
2、改进版本-双塔模型
# neural cf model arch one. embedding+MLP in each tower, then dot product layer as the output
def neural_cf_model_2(feature_inputs, item_feature_columns, user_feature_columns, hidden_units):
item_tower = tf.keras.layers.DenseFeatures(item_feature_columns)(feature_inputs)
for num_nodes in hidden_units:
item_tower = tf.keras.layers.Dense(num_nodes, activation='relu')(item_tower)
user_tower = tf.keras.layers.DenseFeatures(user_feature_columns)(feature_inputs)
for num_nodes in hidden_units:
user_tower = tf.keras.layers.Dense(num_nodes, activation='relu')(user_tower)
output = tf.keras.layers.Dot(axes=1)([item_tower, user_tower])
output = tf.keras.layers.Dense(1, activation='sigmoid')(output)
# output = tf.keras.layers.Dense(1)(output)
neural_cf_model = tf.keras.Model(feature_inputs, output)
return neural_cf_model

从结果可以看出,accuracy不是很高,模型欠拟合较严重。
3、 改进版本2-双塔模型+MLP
# neural cf model arch one. embedding+MLP in each tower, then MLP layer as the output
def neural_cf_model_3(feature_inputs, item_feature_columns, user_feature_columns, hidden_units):
item_tower = tf.keras.layers.DenseFeatures(item_feature_columns)(feature_inputs)
for num_nodes in hidden_units:
item_tower = tf.keras.layers.Dense(num_nodes, activation='relu')(item_tower)
user_tower = tf.keras.layers.DenseFeatures(user_feature_columns)(feature_inputs)
for num_nodes in hidden_units:
user_tower = tf.keras.layers.Dense(num_nodes, activation='relu')(user_tower)
output = tf.keras.layers.concatenate([item_tower, user_tower])
# output = tf.keras.layers.Dot(axes=1)([item_tower, user_tower])
for num_nodes in hidden_units:
output = tf.keras.layers.Dense(num_nodes,activation='relu')(output)
output = tf.keras.layers.Dense(1, activation='sigmoid')(output)
# output = tf.keras.layers.Dense(1)(output)
neural_cf_model = tf.keras.Model(feature_inputs, output)
return neural_cf_model

从运行结果看,这个模型的loss减小,准确度有提升。
Test Loss 0.19877538084983826, Test Accuracy 0.6881847977638245, Test ROC AUC 0.7592607140541077, Test PR AUC 0.7094590663909912
4、 改进版本6-双塔模型+多特征组合+MLP
终极版本:
# neural cf model arch one. embedding+MLP in each tower, then MLP layer as the output
def neural_cf_model_4(feature_inputs, item_feature_columns, user_feature_columns, hidden_units):
item_tower = tf.keras.layers.DenseFeatures(item_feature_columns)(feature_inputs)
item_tower = tf.keras.layers.concatenate([item_tower,iterm_f])
for num_nodes in hidden_units:
item_tower = tf.keras.layers.Dense(num_nodes, activation='relu')(item_tower)
user_tower = tf.keras.layers.DenseFeatures(user_feature_columns)(feature_inputs)
user_tower = tf.keras.layers.concatenate([user_tower,user_f])
for num_nodes in hidden_units:
user_tower = tf.keras.layers.Dense(num_nodes, activation='relu')(user_tower)
output = tf.keras.layers.concatenate([item_tower, user_tower])
# output = tf.keras.layers.Dot(axes=1)([item_tower, user_tower])
for num_nodes in hidden_units:
output = tf.keras.layers.Dense(num_nodes,activation='relu')(output)
output = tf.keras.layers.Dense(1, activation='sigmoid')(output)
# output = tf.keras.layers.Dense(1)(output)
neural_cf_model = tf.keras.Model(feature_inputs, output)
return neural_cf_model
最终运行结果:
Test Loss 0.6841861605644226, Test Accuracy 0.6669825315475464, Test ROC AUC 0.715860903263092, Test PR AUC 0.6257403492927551
效果和第三种相差不大,但是当数据量多的时候,理论上,第4种效果最好。
边栏推荐
- 后台查找,如何查找网站后台
- Gamefi active users, transaction volume, financing amount and new projects continue to decline. Can axie and stepn get rid of the death spiral? Where is the chain tour?
- 【 protobuf 】 quelques puits causés par la mise à niveau de protobuf
- Operator介绍
- 龙芯中科科创板上市:市值357亿 成国产CPU第一股
- DAST 黑盒漏洞扫描器 第五篇:漏洞扫描引擎与服务能力
- 【贝叶斯分类4】贝叶斯网
- Y48. Chapter III kubernetes from introduction to mastery -- pod status and probe (21)
- What are the accounting elements
- 在哪家证券公司开户最方便最安全可靠
猜你喜欢

VB.net类库(进阶版——1)

What are the accounting elements

leetcode刷题:字符串05(剑指 Offer 58 - II. 左旋转字符串)

Kdd2022 𞓜 unified session recommendation system based on knowledge enhancement prompt learning

【protobuf 】protobuf 升级后带来的一些坑

众多碎石3d材质贴图素材一键即可获取

The importance of using fonts correctly in DataWindow

Redis + Guava 本地缓存 API 组合,性能炸裂!

网易云信正式加入中国医学装备协会智慧医院分会,为全国智慧医院建设加速...

Looking back at the moon
随机推荐
0 basic C language (0)
Mongodb implements creating and deleting databases, creating and deleting tables (sets), and adding, deleting, modifying, and querying data
Stringutils judge whether the string is empty
Muke 8. Service fault tolerance Sentinel
[protobuf] some pits brought by protobuf upgrade
Android mediacodec hard coded H264 file (four), ByteDance Android interview
BN(Batch Normalization) 的理论理解以及在tf.keras中的实际应用和总结
JWT operation tool class sharing
Sword finger offer II 098 Number of paths / Sword finger offer II 099 Sum of minimum paths
C: Reverse linked list
About appium trample pit: encountered internal error running command: error: cannot verify the signature of (solved)
基于Qt实现的“合成大西瓜”小游戏
Twenty five of offer - all paths with a certain value in the binary tree
VB.net类库(进阶版——1)
Is there any risk in opening a mobile stock registration account? Is it safe?
【贝叶斯分类3】半朴素贝叶斯分类器
Two methods of QT to realize timer
宝藏又小众的覆盖物PBR多通道贴图素材网站分享
KDD2022 | 基于知识增强提示学习的统一会话推荐系统
VB.net类库,获取屏幕内鼠标下的颜色(进阶——3)