Multi-task models for recommendation: ESMM and MMOE
2022-06-24 21:34:00 【Levi Bebe】
1. ESMM
1.1 ESMM background
ESMM, short for Entire Space Multi-task Model, is a multi-task training approach proposed by Alibaba's algorithm team. It is widely used for CTR and CVR estimation in information retrieval, recommender systems, and online advertising. Take an e-commerce recommender system as an example: maximizing the total gross merchandise volume (GMV) of a scenario is one of the platform's key goals, and GMV can be decomposed into traffic × click-through rate × conversion rate × average order value, so the conversion rate is an important factor in the optimization objective. From the user-experience perspective, the conversion rate also helps balance users' click preferences against their purchase preferences.
Conventional CVR estimation has the following problems:
- Sample selection bias: the model is trained on clicked impressions but makes predictions over all impressions, so the distribution of the constructed training set deviates from the actual data distribution.
- Data sparsity: clicked samples make up only a small fraction of impression samples.

1.2 ESMM principle
The ESMM model uses sequential user-behavior data to model over the entire impression sample space, which avoids the sample selection bias and sparse training data that conventional CVR models typically run into, and it achieved remarkable results. ESMM also pioneered the idea of learning CVR indirectly through the auxiliary tasks of CTR and CTCVR.
1.2.1 ESMM Model framework

- CTR = clicks / impressions
- CVR = conversions / clicks
- CTCVR = conversions / impressions, i.e., the probability that an item is clicked and then converted.
The ESMM model consists of two sub-networks:
- the left sub-network fits pCVR;
- the right sub-network fits pCTR; multiplying the outputs of the two sub-networks yields pCTCVR. The network structure therefore has three sub-tasks, outputting pCTR, pCVR, and pCTCVR.

where x denotes an impression, y a click, and z a conversion.
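With these symbols, the factorization realized by the two towers (as given in the ESMM paper) is:

$$pCTCVR = P(y=1, z=1 \mid x) = \underbrace{P(y=1 \mid x)}_{pCTR} \times \underbrace{P(z=1 \mid y=1, x)}_{pCVR}$$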
Note:
- Shared embedding: the CVR task and the CTR task use the same features and feature embeddings; above the concatenation layer, each tower learns its own task-specific parameters.
- Implicit learning of pCVR: pCVR is only an internal variable of the network and has no explicit supervision signal.
The loss function is constructed from the CTCVR and CTR labels:
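Reconstructed from the ESMM paper, the loss sums a CTR term and a CTCVR term over all N impressions, where l is the cross-entropy loss:

$$L(\theta_{cvr}, \theta_{ctr}) = \sum_{i=1}^{N} l\big(y_i,\ f(x_i; \theta_{ctr})\big) + \sum_{i=1}^{N} l\big(y_i \,\&\, z_i,\ f(x_i; \theta_{ctr}) \times f(x_i; \theta_{cvr})\big)$$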
Resolving sample selection bias: during training the model only needs to predict pCTCVR and pCTR to update its parameters. Because the pCTCVR and pCTR data are drawn from the entire impression sample space, the pCVR obtained from the formula above is free of sample selection bias as well.
Resolving data sparsity: the shared embedding layer lets the CVR sub-task also learn from impression samples that were never clicked, which alleviates the sparsity of CVR training data.
1.3 ESMM model optimization
- Model optimization: in the paper the task-specific towers are plain MLPs; different models can be chosen to match each task's characteristics, e.g., DeepFM or DIN.
- Learning optimization: introduce a dynamic loss-weighting mechanism to balance the sub-task losses.
- Feature optimization: longer sequential dependencies can be modeled. For example, in Meituan's AITM credit-card business, the user conversion funnel is impression -> click -> application -> card approval -> activation.
1.4 ESMM Code
import torch
from torch_rechub.basic.layers import MLP, EmbeddingLayer


class ESMM(torch.nn.Module):
    def __init__(self, user_features, item_features, cvr_params, ctr_params):
        super().__init__()
        self.user_features = user_features
        self.item_features = item_features
        self.embedding = EmbeddingLayer(user_features + item_features)
        # Assumes all user/item fields share one embed_dim, since they are sum-pooled below
        self.tower_dims = user_features[0].embed_dim + item_features[0].embed_dim
        # Build the CVR and CTR towers
        self.tower_cvr = MLP(self.tower_dims, **cvr_params)
        self.tower_ctr = MLP(self.tower_dims, **ctr_params)

    def forward(self, x):
        # Sum-pool the field embeddings: (batch, n_fields, embed_dim) -> (batch, embed_dim)
        embed_user_features = self.embedding(x, self.user_features,
                                             squeeze_dim=False).sum(dim=1)
        embed_item_features = self.embedding(x, self.item_features,
                                             squeeze_dim=False).sum(dim=1)
        # Both towers share the same concatenated input
        input_tower = torch.cat((embed_user_features, embed_item_features), dim=1)
        cvr_logit = self.tower_cvr(input_tower)
        ctr_logit = self.tower_ctr(input_tower)
        cvr_pred = torch.sigmoid(cvr_logit)
        ctr_pred = torch.sigmoid(ctr_logit)
        # pCTCVR = pCTR * pCVR
        ctcvr_pred = torch.mul(ctr_pred, cvr_pred)
        ys = [cvr_pred, ctr_pred, ctcvr_pred]
        return torch.cat(ys, dim=1)
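A minimal training-step sketch follows, assuming torch_rechub's SparseFeature for feature definitions; the feature names, vocabulary sizes, and random batch are made up for illustration. Only the CTR and CTCVR outputs receive labels, matching the implicit-pCVR point above.

import torch
from torch_rechub.basic.features import SparseFeature  # assumed import path

# Hypothetical feature definitions: one user field and one item field,
# both with the same embed_dim (required by the tower_dims computation above)
user_feas = [SparseFeature("user_id", vocab_size=1000, embed_dim=16)]
item_feas = [SparseFeature("item_id", vocab_size=5000, embed_dim=16)]
model = ESMM(user_feas, item_feas,
             cvr_params={"dims": [64, 32]}, ctr_params={"dims": [64, 32]})

# Fake batch: id tensors keyed by feature name, plus click (y) and conversion (z) labels
x = {"user_id": torch.randint(0, 1000, (8,)),
     "item_id": torch.randint(0, 5000, (8,))}
y = torch.randint(0, 2, (8, 1)).float()      # click label
z = y * torch.randint(0, 2, (8, 1)).float()  # conversion is only possible after a click

cvr_pred, ctr_pred, ctcvr_pred = model(x).split(1, dim=1)
bce = torch.nn.BCELoss()
# Supervise pCTR with y and pCTCVR with y&z; pCVR gets no direct label
loss = bce(ctr_pred, y) + bce(ctcvr_pred, z)
loss.backward()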
2. MMOE
2.1 MMOE background
- Multi-task model: learns the commonalities and differences between tasks, which can improve both modeling quality and efficiency.
- Traditional multi-task model (Shared-Bottom): the less related the tasks are, the worse the model's predictions become.

- Hard parameter sharing: the bottom layers are shared hidden layers that learn patterns common to all tasks, while task-specific fully connected layers on top learn per-task patterns.
- Soft parameter sharing: no shared bottom; instead there are multiple towers, and different weights are assigned to the different towers.
- Task-sequence dependency modeling: suitable when the tasks have a natural sequential dependency.
2.2 MMOE Model framework
The idea of MMOE is to place multiple expert networks on top of the bottom embedding layer. Each task can select the experts it needs, and a probability-weighted average of the expert outputs becomes that task's input.
2.2.1 MOE Model
- Model principle: aggregate the outputs of multiple Experts; a gating network determines the weight of each Expert, and the weights over the Experts differ from task to task.
- Characteristics: model ensembling, attention mechanism, multi-head mechanism.
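Formally, following the MoE formulation, with n experts f_i and a gating network g:

$$y = \sum_{i=1}^{n} g(x)_i\, f_i(x), \qquad g(x) = \mathrm{softmax}(W_g\, x)$$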
The figure below shows how an MoE model with three experts performs a forward pass under two-way data parallelism.

2.2.2 MMOE Model

Model architecture: builds on the OMOE model; each task has its own gating network over the shared Experts, so multiple tasks share one set of expert networks.
Characteristics: 1. avoids task conflicts: each task's gate adapts its weighting to select the Experts that help the current task; 2. models the relationships between tasks, with flexible parameter sharing; 3. the model converges quickly during training.
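Written out (reconstructed from the MMoE paper), with shared experts f_i, a per-task gate g^k, and a per-task tower h^k, the output for task k is:

$$y_k = h^k\big(f^k(x)\big), \qquad f^k(x) = \sum_{i=1}^{n} g^k(x)_i\, f_i(x), \qquad g^k(x) = \mathrm{softmax}\big(W_{gk}\, x\big)$$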
2.3 MMOE Code
import torch
import torch.nn as nn
from torch_rechub.basic.layers import MLP, EmbeddingLayer, PredictionLayer


class MMOE(torch.nn.Module):
    def __init__(self, features, task_types, n_expert, expert_params, tower_params_list):
        super().__init__()
        self.features = features
        self.task_types = task_types
        # Number of tasks
        self.n_task = len(task_types)
        self.n_expert = n_expert
        self.embedding = EmbeddingLayer(features)
        self.input_dims = sum([fea.embed_dim for fea in features])
        # Experts are shared by all tasks; each task gets its own gate below
        self.experts = nn.ModuleList(
            MLP(self.input_dims, output_layer=False, **expert_params) for i in range(self.n_expert))
        # One softmax gate per task, producing a weight for each expert
        self.gates = nn.ModuleList(
            MLP(self.input_dims, output_layer=False, **{
                "dims": [self.n_expert],
                "activation": "softmax"
            }) for i in range(self.n_task))
        # One tower per task
        self.towers = nn.ModuleList(
            MLP(expert_params["dims"][-1], **tower_params_list[i]) for i in range(self.n_task))
        self.predict_layers = nn.ModuleList(PredictionLayer(task_type) for task_type in task_types)

    def forward(self, x):
        embed_x = self.embedding(x, self.features, squeeze_dim=True)
        # Expert outputs: (batch, n_expert, expert_dim)
        expert_outs = [expert(embed_x).unsqueeze(1) for expert in self.experts]
        expert_outs = torch.cat(expert_outs, dim=1)
        # Gate outputs: per task, (batch, n_expert, 1) softmax weights over the experts
        gate_outs = [gate(embed_x).unsqueeze(-1) for gate in self.gates]
        ys = []
        for gate_out, tower, predict_layer in zip(gate_outs, self.towers, self.predict_layers):
            # Weighted sum of the expert outputs for this task
            expert_weight = torch.mul(gate_out, expert_outs)
            expert_pooling = torch.sum(expert_weight, dim=1)
            # Task-specific tower
            tower_out = tower(expert_pooling)
            # logit -> probability
            y = predict_layer(tower_out)
            ys.append(y)
        return torch.cat(ys, dim=1)
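A small instantiation sketch, under the same assumptions as the ESMM example (torch_rechub's SparseFeature; feature names and sizes are illustrative). The "classification" task type is assumed to make PredictionLayer apply a sigmoid.

import torch
from torch_rechub.basic.features import SparseFeature  # assumed import path

# Two binary tasks (e.g., click and conversion) sharing three experts
features = [SparseFeature("user_id", vocab_size=1000, embed_dim=16),
            SparseFeature("item_id", vocab_size=5000, embed_dim=16)]
model = MMOE(features,
             task_types=["classification", "classification"],
             n_expert=3,
             expert_params={"dims": [64, 32]},
             tower_params_list=[{"dims": [16]}, {"dims": [16]}])

x = {"user_id": torch.randint(0, 1000, (8,)),
     "item_id": torch.randint(0, 5000, (8,))}
y_click, y_conv = model(x).split(1, dim=1)  # one probability per task
print(y_click.shape, y_conv.shape)          # torch.Size([8, 1]) torch.Size([8, 1])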
3. Summary
- ESMM model: introduces the auxiliary CTR and CTCVR tasks and trains over the entire impression sample space, avoiding the sample selection bias and training-data sparsity that conventional CVR models suffer from. Different models can be chosen for the two towers according to each task's characteristics, and the sub-networks can be replaced freely.
- MMOE model: multiple tasks share a set of experts, and each task's gating mechanism controls each expert's contribution to that task. This alleviates the weakness of traditional models (training fails to converge when tasks differ greatly) while also converging quickly during training.