当前位置:网站首页>推荐模型之多任务模型:ESMM、MMOE
推荐模型之多任务模型:ESMM、MMOE
2022-06-24 19:19:00 【莱维贝贝、】
1. ESMM
ESMM的全称是Entire Space Multi-task Model (ESMM),是阿里巴巴算法团队提出的多任务训练方法。其在信息检索、推荐系统、在线广告投放系统的CTR、CVR预估中广泛使用。以电商推荐系统为例,最大化场景商品交易总额(GMV)是平台的重要目标之一,而GMV可以拆解为流量×点击率×转化率×客单价,因此转化率是优化目标的重要因子之一; 从用户体验的角度来说转换率可以用来平衡用户的点击偏好与购买偏好。
传统的CVR预估任务,其存在如下问题:
- 样本选择偏差: 构建的训练样本集的分布采样与实际数据分布存在偏差;
- 稀疏数据: 点击样本占曝光样本的比例很小

1.2 ESMM原理
ESMM模型利用用户行为序列数据在完整样本空间建模,避免了传统CVR模型经常遭遇的样本选择偏差和训练数据稀疏的问题,取得了显著的效果。另一方面,ESMM模型首次提出了利用学习CTR和CTCVR的辅助任务迂回学习CVR的思路
1.2.1 ESMM模型框架

CTR = 实际点击次数 / 展示量
CVR = 转化数 / 点击量
CTCVR = 转换数 / 曝光量。是预测item被点击,然后被转化的概率。
ESMM模型由两个子网络组成:
- 左边的子网络用来拟合pCVR
- 右边的子网络用来拟合pCTR,同时,两个子网络的输出相乘之后可以得到pCTCVR。因此,该网络结构共有三个子任务,分别用于输出pCTR、pCVR和pCTCVR。

其中x表示曝光,y表示点击,z表示转化
注意:
- 共享Embedding。 CVR-task和CTR-task使用相同的特征和特征embedding,即两者从Concatenate之后才学习各自独享的参数;
- 隐式学习pCVR。 这里pCVR 仅是网络中的一个variable,没有显示的监督信号。
CTCVR和CTR的label构造损失函数:
解决样本选择偏差: 在训练过程中,模型只需要预测pCTCVR和pCTR,即可更新参数,由于pCTCVR和pCTR的数据是基于完整样本空间提取的,故根据公式,可以解决pCVR的样本选择偏差。
解决数据稀疏: 使用共享的embedding层,使得CVR子任务也能够从只展示没点击的样本中学习,可以缓解训练数据稀疏的问题。
1.3 ESMM模型的优化
- 模型优化:论文中,子任务独立的Tower网络是纯MLP模型,可以根据自身特点设置不一样的模型,例如使用DeepFM、DIN等
- 学习优化:引入动态加权的学习机制,优化loss
- 特征优化:可构建更长的序列依赖模型,例如美团AITM信用卡业务,用户转换过程是曝光->点击->申请->核卡->激活
1.4 ESMM代码
import torch
import torch.nn.functional as F
from torch_rechub.basic.layers import MLP, EmbeddingLayer
from tqdm import tqdm
class ESMM(torch.nn.Module):
def __init__(self, user_features, item_features, cvr_params, ctr_params):
super().__init__()
self.user_features = user_features
self.item_features = item_features
self.embedding = EmbeddingLayer(user_features + item_features)
self.tower_dims = user_features[0].embed_dim + item_features[0].embed_dim
# 构建CVR和CTR的双塔
self.tower_cvr = MLP(self.tower_dims, **cvr_params)
self.tower_ctr = MLP(self.tower_dims, **ctr_params)
def forward(self, x):
embed_user_features = self.embedding(x, self.user_features,
squeeze_dim=False).sum(dim=1)
embed_item_features = self.embedding(x, self.item_features,
squeeze_dim=False).sum(dim=1)
input_tower = torch.cat((embed_user_features, embed_item_features), dim=1)
cvr_logit = self.tower_cvr(input_tower)
ctr_logit = self.tower_ctr(input_tower)
cvr_pred = torch.sigmoid(cvr_logit)
ctr_pred = torch.sigmoid(ctr_logit)
# 计算pCTCVR = pCTR * pCVR
ctcvr_pred = torch.mul(cvr_pred, cvr_pred)
ys = [cvr_pred, ctr_pred, ctcvr_pred]
return torch.cat(ys, dim=1)
2. MMOE
2.1 MMOE产生背景
- 多任务模型:在不同任务之间学习共性以及差异性,能够提高建模的质量以及效率。
- 传统的多任务模型(Shared Bottom),随着任务之间相关性越小,模型预测效果越差。

- Hard Parameter Sharing方法:底层是共享的隐藏层,学习各个任务的共同模式,上层用一些特定的全连接层学习特定任务模式。
- Soft Parameter Sharing方法:底层不使用共享的shared bottom,而是有多个tower,给不同的tower分配不同的权重。
- 任务序列依赖关系建模:这种适合于不同任务之间有一定的序列依赖关系。
2.2 MMOE模型框架
mmoe的思想就是,在底部的embedding层上设立多个expert网络,不同的task可以根据需求选择expert,通过expert输出的概率加权平均获得每个task的输入。
2.2.1 MOE模型
- 模型原理:基于多个
Expert汇总输出,通过门控网络机制得到每个Expert的权重,不同任务中的Expert所占权重不同。 - 特性:模型集成、注意力机制、multi-head机制
下图展示了一个有三个专家的两路数据并行MoE模型进行前向计算的方式.

2.2.2 MMOE模型

模型架构:基于OMOE模型,每个Expert任务都有一个门控网络,多个任务一个网络。
特性:1.避免任务冲突,根据不同的门控进行调整,选择出对当前任务有帮助的Expert组合;2. 建立任务之间的关系,参数共享灵活;3. 训练时模型能够快速收敛
2.3 代码
import torch
import torch.nn as nn
from torch_rechub.basic.layers import MLP, Embedding
class MMOE(torch.nn.Module):
def __init__(self, features, task_types, n_expert, expert_params, tower_params_list):
super().__init__()
self.features = features
self.task_types = task_types
# 任务数量
self.n_task = len(task_types)
self.n_expert = n_expert
self.embedding = EmbeddingLayer(features)
self.input_dims = sum([fea.embed_dim for fea in features])
# 每个Expert对应一个门控
self.experts = nn.ModuleList(
MLP(self.input_dims, output_layer=False, **expert_params) for i in range(self.n_expert))
self.gates = nn.ModuleList(
MLP(self.input_dims, output_layer=False, **{
"dims": [self.n_expert],
"activation": "softmax"
}) for i in range(self.n_task))
# 双塔
self.towers = nn.ModuleList(MLP(expert_params["dims"][-1], **tower_params_list[i]) for i in range(self.n_task))
self.predict_layers = nn.ModuleList(PredictionLayer(task_type) for task_type in task_types)
def forward(self, x):
embed_x = self.embedding(x, self.features, squeeze_dim=True)
expert_outs = [expert(embed_x).unsqueeze(1) for expert in self.experts]
expert_outs = torch.cat(expert_outs, dim=1)
gate_outs = [gate(embed_x).unsqueeze(-1) for gate in self.gates]
ys = []
for gate_out, tower, predict_layer in zip(gate_outs, self.towers, self.predict_layers):
expert_weight = torch.mul(gate_out, expert_outs)
expert_pooling = torch.sum(expert_weight, dim=1)
# 计算双塔
tower_out = tower(expert_pooling)
# logit -> proba
y = predict_layer(tower_out)
ys.append(y)
return torch.cat(ys, dim=1)
3 总结
- ESMM模型: 主要引入CTR和CTCVR的辅助任务,在完整样本空间建模进行训练,避免了传统CVR模型经常遭遇的样本选择偏差和训练数据稀疏的问题。并可根据自身特点设置两个塔的不同模型,子网络支持任意替换。
- MMOE模型:多个任务共享专家,每个门控机制控制专家对每个任务的点评,这样不仅可以改善传统模型存在的问题(任务之间差异性较大,导致模型训练不收敛),还可以训练快速收敛。
参考
https://blog.csdn.net/cuixian123456/article/details/118682085
https://blog.csdn.net/m0_37870649/article/details/87378906
https://zhuanlan.zhihu.com/p/454726579
https://www.jianshu.com/p/0f3e40bfd3ceutm_campaign=haruki&utm_content=note&utm_medium=seo_notes&utm_source=recommendation
https://zhuanlan.zhihu.com/p/263781995
边栏推荐
- A/B测试助力游戏业务增长
- Reflect package
- An example illustrates restful API
- Nifi quick installation (stand-alone / cluster)
- Arkit与Character Creator动画曲线的对接
- JMeter installation plug-in, adding [email protected] -Perfmon metric collector listener steps
- Analysis of tcpdump packet capturing kernel code
- Typescript syntax
- 网络安全审查办公室对知网启动网络安全审查
- When to send the update windows message
猜你喜欢

After screwing the screws in the factory for two years, I earned more than 10000 yuan a month by "testing" and counterattacked

Procedural life: a few things you should know when entering the workplace

Alibaba cloud schedules tasks and automatically releases them

Address mapping of virtual memory paging mechanism

DHCP operation

memcached完全剖析–1. memcached的基础

OSI notes sorting

I feel that I am bald again when I help my children with their homework. I feel pity for my parents all over the world

Oauth2.0 introduction

Oauth1.0 introduction
随机推荐
Golang daily question
VirtualBox虚拟机安装Win10企业版
JUnit unit test
Static routing experiment
Open function
Different WordPress pages display different gadgets
JMeter parameterization
Variable setting in postman
Rename and delete files
Create a multithreaded thread class
Web automation: web control interaction / multi window processing / Web page frame
Football information query system based on C language course report + project source code + demo ppt+ project screenshot
[cloud native learning notes] learn about kubernetes configuration list yaml file
Network layer
After screwing the screws in the factory for two years, I earned more than 10000 yuan a month by "testing" and counterattacked
It was Tencent who jumped out of the job with 26k. It really wiped my ass with sandpaper. It gave me a hand
What does CTO (technical director) usually do?
memcached完全剖析–1. memcached的基础
memcached全面剖析–2. 理解memcached的内存存储
Poj1061 frog dating (extended Euclid)