当前位置:网站首页>推荐模型之多任务模型:ESMM、MMOE
推荐模型之多任务模型:ESMM、MMOE
2022-06-24 19:19:00 【莱维贝贝、】
1. ESMM
ESMM的全称是Entire Space Multi-task Model (ESMM),是阿里巴巴算法团队提出的多任务训练方法。其在信息检索、推荐系统、在线广告投放系统的CTR、CVR预估中广泛使用。以电商推荐系统为例,最大化场景商品交易总额(GMV)是平台的重要目标之一,而GMV可以拆解为流量×点击率×转化率×客单价,因此转化率是优化目标的重要因子之一; 从用户体验的角度来说转换率可以用来平衡用户的点击偏好与购买偏好。
传统的CVR预估任务,其存在如下问题:
- 样本选择偏差: 构建的训练样本集的分布采样与实际数据分布存在偏差;
- 稀疏数据: 点击样本占曝光样本的比例很小

1.2 ESMM原理
ESMM模型利用用户行为序列数据在完整样本空间建模,避免了传统CVR模型经常遭遇的样本选择偏差和训练数据稀疏的问题,取得了显著的效果。另一方面,ESMM模型首次提出了利用学习CTR和CTCVR的辅助任务迂回学习CVR的思路
1.2.1 ESMM模型框架

CTR = 实际点击次数 / 展示量
CVR = 转化数 / 点击量
CTCVR = 转换数 / 曝光量。是预测item被点击,然后被转化的概率。
ESMM模型由两个子网络组成:
- 左边的子网络用来拟合pCVR
- 右边的子网络用来拟合pCTR,同时,两个子网络的输出相乘之后可以得到pCTCVR。因此,该网络结构共有三个子任务,分别用于输出pCTR、pCVR和pCTCVR。

其中x表示曝光,y表示点击,z表示转化
注意:
- 共享Embedding。 CVR-task和CTR-task使用相同的特征和特征embedding,即两者从Concatenate之后才学习各自独享的参数;
- 隐式学习pCVR。 这里pCVR 仅是网络中的一个variable,没有显示的监督信号。
CTCVR和CTR的label构造损失函数:
解决样本选择偏差: 在训练过程中,模型只需要预测pCTCVR和pCTR,即可更新参数,由于pCTCVR和pCTR的数据是基于完整样本空间提取的,故根据公式,可以解决pCVR的样本选择偏差。
解决数据稀疏: 使用共享的embedding层,使得CVR子任务也能够从只展示没点击的样本中学习,可以缓解训练数据稀疏的问题。
1.3 ESMM模型的优化
- 模型优化:论文中,子任务独立的Tower网络是纯MLP模型,可以根据自身特点设置不一样的模型,例如使用DeepFM、DIN等
- 学习优化:引入动态加权的学习机制,优化loss
- 特征优化:可构建更长的序列依赖模型,例如美团AITM信用卡业务,用户转换过程是曝光->点击->申请->核卡->激活
1.4 ESMM代码
import torch
import torch.nn.functional as F
from torch_rechub.basic.layers import MLP, EmbeddingLayer
from tqdm import tqdm
class ESMM(torch.nn.Module):
def __init__(self, user_features, item_features, cvr_params, ctr_params):
super().__init__()
self.user_features = user_features
self.item_features = item_features
self.embedding = EmbeddingLayer(user_features + item_features)
self.tower_dims = user_features[0].embed_dim + item_features[0].embed_dim
# 构建CVR和CTR的双塔
self.tower_cvr = MLP(self.tower_dims, **cvr_params)
self.tower_ctr = MLP(self.tower_dims, **ctr_params)
def forward(self, x):
embed_user_features = self.embedding(x, self.user_features,
squeeze_dim=False).sum(dim=1)
embed_item_features = self.embedding(x, self.item_features,
squeeze_dim=False).sum(dim=1)
input_tower = torch.cat((embed_user_features, embed_item_features), dim=1)
cvr_logit = self.tower_cvr(input_tower)
ctr_logit = self.tower_ctr(input_tower)
cvr_pred = torch.sigmoid(cvr_logit)
ctr_pred = torch.sigmoid(ctr_logit)
# 计算pCTCVR = pCTR * pCVR
ctcvr_pred = torch.mul(cvr_pred, cvr_pred)
ys = [cvr_pred, ctr_pred, ctcvr_pred]
return torch.cat(ys, dim=1)
2. MMOE
2.1 MMOE产生背景
- 多任务模型:在不同任务之间学习共性以及差异性,能够提高建模的质量以及效率。
- 传统的多任务模型(Shared Bottom),随着任务之间相关性越小,模型预测效果越差。

- Hard Parameter Sharing方法:底层是共享的隐藏层,学习各个任务的共同模式,上层用一些特定的全连接层学习特定任务模式。
- Soft Parameter Sharing方法:底层不使用共享的shared bottom,而是有多个tower,给不同的tower分配不同的权重。
- 任务序列依赖关系建模:这种适合于不同任务之间有一定的序列依赖关系。
2.2 MMOE模型框架
mmoe的思想就是,在底部的embedding层上设立多个expert网络,不同的task可以根据需求选择expert,通过expert输出的概率加权平均获得每个task的输入。
2.2.1 MOE模型
- 模型原理:基于多个
Expert汇总输出,通过门控网络机制得到每个Expert的权重,不同任务中的Expert所占权重不同。 - 特性:模型集成、注意力机制、multi-head机制
下图展示了一个有三个专家的两路数据并行MoE模型进行前向计算的方式.

2.2.2 MMOE模型

模型架构:基于OMOE模型,每个Expert任务都有一个门控网络,多个任务一个网络。
特性:1.避免任务冲突,根据不同的门控进行调整,选择出对当前任务有帮助的Expert组合;2. 建立任务之间的关系,参数共享灵活;3. 训练时模型能够快速收敛
2.3 代码
import torch
import torch.nn as nn
from torch_rechub.basic.layers import MLP, Embedding
class MMOE(torch.nn.Module):
def __init__(self, features, task_types, n_expert, expert_params, tower_params_list):
super().__init__()
self.features = features
self.task_types = task_types
# 任务数量
self.n_task = len(task_types)
self.n_expert = n_expert
self.embedding = EmbeddingLayer(features)
self.input_dims = sum([fea.embed_dim for fea in features])
# 每个Expert对应一个门控
self.experts = nn.ModuleList(
MLP(self.input_dims, output_layer=False, **expert_params) for i in range(self.n_expert))
self.gates = nn.ModuleList(
MLP(self.input_dims, output_layer=False, **{
"dims": [self.n_expert],
"activation": "softmax"
}) for i in range(self.n_task))
# 双塔
self.towers = nn.ModuleList(MLP(expert_params["dims"][-1], **tower_params_list[i]) for i in range(self.n_task))
self.predict_layers = nn.ModuleList(PredictionLayer(task_type) for task_type in task_types)
def forward(self, x):
embed_x = self.embedding(x, self.features, squeeze_dim=True)
expert_outs = [expert(embed_x).unsqueeze(1) for expert in self.experts]
expert_outs = torch.cat(expert_outs, dim=1)
gate_outs = [gate(embed_x).unsqueeze(-1) for gate in self.gates]
ys = []
for gate_out, tower, predict_layer in zip(gate_outs, self.towers, self.predict_layers):
expert_weight = torch.mul(gate_out, expert_outs)
expert_pooling = torch.sum(expert_weight, dim=1)
# 计算双塔
tower_out = tower(expert_pooling)
# logit -> proba
y = predict_layer(tower_out)
ys.append(y)
return torch.cat(ys, dim=1)
3 总结
- ESMM模型: 主要引入CTR和CTCVR的辅助任务,在完整样本空间建模进行训练,避免了传统CVR模型经常遭遇的样本选择偏差和训练数据稀疏的问题。并可根据自身特点设置两个塔的不同模型,子网络支持任意替换。
- MMOE模型:多个任务共享专家,每个门控机制控制专家对每个任务的点评,这样不仅可以改善传统模型存在的问题(任务之间差异性较大,导致模型训练不收敛),还可以训练快速收敛。
参考
https://blog.csdn.net/cuixian123456/article/details/118682085
https://blog.csdn.net/m0_37870649/article/details/87378906
https://zhuanlan.zhihu.com/p/454726579
https://www.jianshu.com/p/0f3e40bfd3ceutm_campaign=haruki&utm_content=note&utm_medium=seo_notes&utm_source=recommendation
https://zhuanlan.zhihu.com/p/263781995
边栏推荐
- Common data model (updating)
- Web automation: summary of special scenario processing methods
- Oauth2.0 introduction
- Web automation: web control interaction / multi window processing / Web page frame
- 虚拟货币7个月蒸发2万亿美元,“马斯克们”终结15万人暴富梦
- Three more days
- Tool composition in JMeter
- 浅谈MySql update会锁定哪些范围的数据
- Codeforces Round #720 (Div. 2)
- 装修首页自定义全屏视频播放效果gif动态图片制作视频教程播放代码操作设置全屏居中阿里巴巴国际站
猜你喜欢

Nifi quick installation (stand-alone / cluster)

ping: www.baidu.com: 未知的名称或服务

Please open online PDF carefully

Limit summary (under update)

JMeter implementation specifies concurrent loop testing

Address mapping of virtual memory paging mechanism

The first day of handwritten RPC -- review of some basic knowledge

Common data model (updating)

Summary of idea practical skills: how to rename a project or module to completely solve all the problems you encounter that do not work. It is suggested that the five-star collection be your daughter

Big factories go out to sea and lose "posture"
随机推荐
Limit summary (under update)
OSI and tcp/ip model
Reflect package
EditText 控制软键盘出现 搜索
Apple mobile phone can see some fun ways to install IPA package
装修首页自定义全屏视频播放效果gif动态图片制作视频教程播放代码操作设置全屏居中阿里巴巴国际站
Poj1061 frog dating (extended Euclid)
Subnet partition operation
Simple analysis of WordPress architecture
Notes_ Vlan
Static routing job
Am, FM, PM modulation technology
Foundations of Cryptography
Nifi fast authentication configuration
Please open online PDF carefully
Record a deletion bash_ Profile file
After idea installs these plug-ins, the code can be written to heaven. My little sister also has to arrange it
Golang reflection operation collation
Simpledateformat thread unsafe
使用 Go 编程语言 66 个陷阱:Golang 开发者的陷阱和常见错误指北