当前位置:网站首页>推荐模型之多任务模型:ESMM、MMOE
推荐模型之多任务模型:ESMM、MMOE
2022-06-24 19:19:00 【莱维贝贝、】
1. ESMM
ESMM
的全称是Entire Space Multi-task Model (ESMM),是阿里巴巴算法团队提出的多任务训练方法。其在信息检索、推荐系统、在线广告投放系统的CTR、CVR预估中广泛使用。以电商推荐系统为例,最大化场景商品交易总额(GMV)是平台的重要目标之一,而GMV可以拆解为流量×点击率×转化率×客单价,因此转化率是优化目标的重要因子之一; 从用户体验的角度来说转换率可以用来平衡用户的点击偏好与购买偏好。
传统的CVR预估任务
,其存在如下问题:
- 样本选择偏差: 构建的训练样本集的分布采样与实际数据分布存在偏差;
- 稀疏数据: 点击样本占曝光样本的比例很小
1.2 ESMM原理
ESMM模型利用用户行为序列数据在完整样本空间建模,避免了传统CVR模型经常遭遇的样本选择偏差和训练数据稀疏的问题,取得了显著的效果。另一方面,ESMM模型首次提出了利用学习CTR和CTCVR的辅助任务迂回学习CVR的思路
1.2.1 ESMM模型框架
CTR = 实际点击次数 / 展示量
CVR = 转化数 / 点击量
CTCVR = 转换数 / 曝光量。是预测item被点击,然后被转化的概率。
ESMM模型由两个子网络组成:
- 左边的子网络用来拟合pCVR
- 右边的子网络用来拟合pCTR,同时,两个子网络的输出相乘之后可以得到pCTCVR。因此,该网络结构共有三个子任务,分别用于输出pCTR、pCVR和pCTCVR。
其中x表示曝光,y表示点击,z表示转化
注意:
- 共享Embedding。 CVR-task和CTR-task使用相同的特征和特征embedding,即两者从Concatenate之后才学习各自独享的参数;
- 隐式学习pCVR。 这里pCVR 仅是网络中的一个variable,没有显示的监督信号。
CTCVR和CTR的label构造损失函数:
解决样本选择偏差: 在训练过程中,模型只需要预测pCTCVR和pCTR,即可更新参数,由于pCTCVR和pCTR的数据是基于完整样本空间提取的,故根据公式,可以解决pCVR的样本选择偏差。
解决数据稀疏: 使用共享的embedding层,使得CVR子任务也能够从只展示没点击的样本中学习,可以缓解训练数据稀疏的问题。
1.3 ESMM模型的优化
- 模型优化:论文中,子任务独立的Tower网络是纯MLP模型,可以根据自身特点设置不一样的模型,例如使用DeepFM、DIN等
- 学习优化:引入动态加权的学习机制,优化loss
- 特征优化:可构建更长的序列依赖模型,例如美团AITM信用卡业务,用户转换过程是曝光->点击->申请->核卡->激活
1.4 ESMM代码
import torch
import torch.nn.functional as F
from torch_rechub.basic.layers import MLP, EmbeddingLayer
from tqdm import tqdm
class ESMM(torch.nn.Module):
def __init__(self, user_features, item_features, cvr_params, ctr_params):
super().__init__()
self.user_features = user_features
self.item_features = item_features
self.embedding = EmbeddingLayer(user_features + item_features)
self.tower_dims = user_features[0].embed_dim + item_features[0].embed_dim
# 构建CVR和CTR的双塔
self.tower_cvr = MLP(self.tower_dims, **cvr_params)
self.tower_ctr = MLP(self.tower_dims, **ctr_params)
def forward(self, x):
embed_user_features = self.embedding(x, self.user_features,
squeeze_dim=False).sum(dim=1)
embed_item_features = self.embedding(x, self.item_features,
squeeze_dim=False).sum(dim=1)
input_tower = torch.cat((embed_user_features, embed_item_features), dim=1)
cvr_logit = self.tower_cvr(input_tower)
ctr_logit = self.tower_ctr(input_tower)
cvr_pred = torch.sigmoid(cvr_logit)
ctr_pred = torch.sigmoid(ctr_logit)
# 计算pCTCVR = pCTR * pCVR
ctcvr_pred = torch.mul(cvr_pred, cvr_pred)
ys = [cvr_pred, ctr_pred, ctcvr_pred]
return torch.cat(ys, dim=1)
2. MMOE
2.1 MMOE产生背景
- 多任务模型:在不同任务之间学习共性以及差异性,能够提高建模的质量以及效率。
- 传统的多任务模型(Shared Bottom),随着任务之间相关性越小,模型预测效果越差。
- Hard Parameter Sharing方法:底层是共享的隐藏层,学习各个任务的共同模式,上层用一些特定的全连接层学习特定任务模式。
- Soft Parameter Sharing方法:底层不使用共享的shared bottom,而是有多个tower,给不同的tower分配不同的权重。
- 任务序列依赖关系建模:这种适合于不同任务之间有一定的序列依赖关系。
2.2 MMOE模型框架
mmoe的思想就是,在底部的embedding层上设立多个expert网络,不同的task可以根据需求选择expert,通过expert输出的概率加权平均获得每个task的输入。
2.2.1 MOE模型
- 模型原理:基于多个
Expert
汇总输出,通过门控网络机制得到每个Expert
的权重,不同任务中的Expert所占权重不同。 - 特性:模型集成、注意力机制、multi-head机制
下图展示了一个有三个专家的两路数据并行MoE模型进行前向计算的方式.
2.2.2 MMOE模型
模型架构:基于OMOE模型,每个Expert任务都有一个门控网络,多个任务一个网络。
特性:1.避免任务冲突,根据不同的门控进行调整,选择出对当前任务有帮助的Expert组合;2. 建立任务之间的关系,参数共享灵活;3. 训练时模型能够快速收敛
2.3 代码
import torch
import torch.nn as nn
from torch_rechub.basic.layers import MLP, Embedding
class MMOE(torch.nn.Module):
def __init__(self, features, task_types, n_expert, expert_params, tower_params_list):
super().__init__()
self.features = features
self.task_types = task_types
# 任务数量
self.n_task = len(task_types)
self.n_expert = n_expert
self.embedding = EmbeddingLayer(features)
self.input_dims = sum([fea.embed_dim for fea in features])
# 每个Expert对应一个门控
self.experts = nn.ModuleList(
MLP(self.input_dims, output_layer=False, **expert_params) for i in range(self.n_expert))
self.gates = nn.ModuleList(
MLP(self.input_dims, output_layer=False, **{
"dims": [self.n_expert],
"activation": "softmax"
}) for i in range(self.n_task))
# 双塔
self.towers = nn.ModuleList(MLP(expert_params["dims"][-1], **tower_params_list[i]) for i in range(self.n_task))
self.predict_layers = nn.ModuleList(PredictionLayer(task_type) for task_type in task_types)
def forward(self, x):
embed_x = self.embedding(x, self.features, squeeze_dim=True)
expert_outs = [expert(embed_x).unsqueeze(1) for expert in self.experts]
expert_outs = torch.cat(expert_outs, dim=1)
gate_outs = [gate(embed_x).unsqueeze(-1) for gate in self.gates]
ys = []
for gate_out, tower, predict_layer in zip(gate_outs, self.towers, self.predict_layers):
expert_weight = torch.mul(gate_out, expert_outs)
expert_pooling = torch.sum(expert_weight, dim=1)
# 计算双塔
tower_out = tower(expert_pooling)
# logit -> proba
y = predict_layer(tower_out)
ys.append(y)
return torch.cat(ys, dim=1)
3 总结
- ESMM模型: 主要引入CTR和CTCVR的辅助任务,在完整样本空间建模进行训练,避免了传统CVR模型经常遭遇的样本选择偏差和训练数据稀疏的问题。并可根据自身特点设置两个塔的不同模型,子网络支持任意替换。
- MMOE模型:多个任务共享专家,每个门控机制控制专家对每个任务的点评,这样不仅可以改善传统模型存在的问题(任务之间差异性较大,导致模型训练不收敛),还可以训练快速收敛。
参考
https://blog.csdn.net/cuixian123456/article/details/118682085
https://blog.csdn.net/m0_37870649/article/details/87378906
https://zhuanlan.zhihu.com/p/454726579
https://www.jianshu.com/p/0f3e40bfd3ceutm_campaign=haruki&utm_content=note&utm_medium=seo_notes&utm_source=recommendation
https://zhuanlan.zhihu.com/p/263781995
边栏推荐
- Open function
- Summary of idea practical skills: how to rename a project or module to completely solve all the problems you encounter that do not work. It is suggested that the five-star collection be your daughter
- Network security review office starts network security review on HowNet
- Football information query system based on C language course report + project source code + demo ppt+ project screenshot
- 大厂出海,败于“姿态”
- B站带货当学新东方
- go_ keyword
- Page replacement of virtual memory paging mechanism
- Failed to open after installing Charles without any prompt
- What are the problems with traditional IO? Why is zero copy introduced?
猜你喜欢
Basic database syntax learning
What does CTO (technical director) usually do?
Adding subscribers to a list using mailchimp's API V3
Read all text from stdin to a string
基于STM32的物联网下智能化养鱼鱼缸控制控制系统
Simple analysis of WordPress architecture
EditText 控制软键盘出现 搜索
Static routing experiment
A/b test helps the growth of game business
基于C语言实现的足球信息查询系统 课程报告+项目源码+演示PPT+项目截图
随机推荐
Sleep revolution - find the right length of rest
Return of missing persons
VIM usage
Requests requests for web page garbled code resolution
TCP Jprobe utilization problem location
memcached全面剖析–3. memcached的删除机制和发展方向
Simpledateformat thread unsafe
Learn to use a new technology quickly
Common data model (updating)
Basic database syntax learning
装修首页自定义全屏视频播放效果gif动态图片制作视频教程播放代码操作设置全屏居中阿里巴巴国际站
The first day of handwritten RPC -- review of some basic knowledge
Notes_ Vlan
Network security review office starts network security review on HowNet
Reflection - class object function - get method (case)
EditText 控制软键盘出现 搜索
188. 买卖股票的最佳时机 IV
Summary of idea practical skills: how to rename a project or module to completely solve all the problems you encounter that do not work. It is suggested that the five-star collection be your daughter
Shell script
Reflect package