Multi-Task Models for Recommendation: ESMM and MMOE
2022-06-24 21:34:00 【Levi Bebe】
1. ESMM
ESMM is short for Entire Space Multi-task Model, a multi-task training approach proposed by Alibaba's algorithm team. It is widely used for CTR and CVR estimation in information retrieval, recommender systems, and online advertising. Take an e-commerce recommender system as an example: maximizing the gross merchandise volume (GMV) of a scenario is one of the platform's key goals, and GMV can be decomposed as traffic × CTR × CVR × average order value, so conversion rate (CVR) is an important factor in the optimization objective. From the user-experience perspective, CVR also helps balance users' click preferences against their purchase preferences.
1.1 Background
A conventional CVR estimation task has the following problems:
- Sample selection bias: the training set is built only from clicked impressions, so its sampling distribution deviates from the actual inference-time distribution over all impressions;
- Data sparsity: clicked samples make up only a small fraction of exposed samples, so CVR training data is scarce.

1.2 ESMM principle
The ESMM model uses user behavior sequence data to model over the entire impression sample space, avoiding the sample selection bias and sparse training data that traditional CVR models often encounter, and it has achieved notable results. ESMM also pioneered the idea of learning CVR indirectly via the auxiliary tasks of learning CTR and CTCVR.
1.2.1 ESMM Model framework

- CTR = clicks / impressions
- CVR = conversions / clicks
- CTCVR = conversions / impressions, i.e. the probability that an item, once shown, is clicked and then converted.
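To make the decomposition concrete, here is a minimal sketch with made-up counts showing that CTCVR factorizes as CTR × CVR:

```python
# Hypothetical counts, for illustration only.
impressions, clicks, conversions = 10_000, 500, 25

ctr = clicks / impressions          # 0.05
cvr = conversions / clicks          # 0.05 (conditioned on a click)
ctcvr = conversions / impressions   # 0.0025

# The chain rule that ESMM relies on: CTCVR = CTR * CVR
assert abs(ctcvr - ctr * cvr) < 1e-12
```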
The ESMM model consists of two sub-networks:
- the sub-network on the left fits pCVR;
- the sub-network on the right fits pCTR; multiplying the outputs of the two sub-networks yields pCTCVR. The structure therefore has three sub-tasks, outputting pCTR, pCVR, and pCTCVR.

where x denotes an impression (exposure), y a click, and z a conversion. The key decomposition is:

pCTCVR = p(y=1, z=1 | x) = p(y=1 | x) · p(z=1 | y=1, x) = pCTR · pCVR
Note:
- Shared embedding: the CVR task and the CTR task use the same features and feature embeddings; after the shared embedding/concatenation layer, each tower learns its own task-specific parameters.
- Implicit learning of pCVR: pCVR is only an internal variable of the network, with no explicit supervision signal.
The loss function is built from the CTR and CTCVR labels:

L(θ_cvr, θ_ctr) = Σ_i l(y_i, f(x_i; θ_ctr)) + Σ_i l(y_i & z_i, f(x_i; θ_ctr) · f(x_i; θ_cvr))

where l(·) is the cross-entropy loss and both sums run over all impression samples.
- Resolving sample selection bias: during training, the model only needs to predict pCTCVR and pCTR to update its parameters. Because both are estimated over the entire impression space, the decomposition above means the pCVR estimate is free of sample selection bias.
- Resolving data sparsity: the shared embedding layer lets the CVR sub-task learn from all impression samples rather than only clicked ones, which alleviates the sparsity of its training data.
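Concretely, only pCTR and pCTCVR are supervised. The following is a minimal sketch of the combined loss, assuming probability outputs and 0/1 float label tensors over the full impression space (the function name and tensor layout are illustrative, not from the paper's code):

```python
import torch.nn.functional as F

def esmm_loss(ctr_pred, ctcvr_pred, click, conversion):
    """click / conversion are 0-1 float label tensors over all impressions."""
    loss_ctr = F.binary_cross_entropy(ctr_pred, click)
    # The CTCVR label is 1 only if the impression was clicked AND converted.
    loss_ctcvr = F.binary_cross_entropy(ctcvr_pred, click * conversion)
    # pCVR gets no direct supervision; it is learned implicitly via pCTCVR.
    return loss_ctr + loss_ctcvr
```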
1.3 ESMM Model optimization
- Model optimization: in the paper, each sub-task's independent tower is a plain MLP; you can substitute models suited to your own setting, such as DeepFM or DIN.
- Learning optimization: introduce a dynamic loss-weighting mechanism to balance the tasks (see the sketch after this list).
- Feature optimization: longer sequence-dependency models can be built; for example, in Meituan's AITM credit-card business, the user conversion funnel is impression -> click -> application -> card approval -> activation.
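The paper does not prescribe a specific weighting scheme; as one illustrative option, the sketch below applies homoscedastic-uncertainty weighting (Kendall et al., 2018) to dynamically balance the two losses. The class and its use here are assumptions, not part of ESMM:

```python
import torch
import torch.nn as nn

class UncertaintyWeightedLoss(nn.Module):
    """Learns one log-variance per task and weights the losses accordingly."""

    def __init__(self, n_task=2):
        super().__init__()
        self.log_vars = nn.Parameter(torch.zeros(n_task))

    def forward(self, losses):
        total = 0.0
        for i, loss in enumerate(losses):
            precision = torch.exp(-self.log_vars[i])
            # Low-uncertainty tasks get larger weight; log_var acts as a regularizer.
            total = total + precision * loss + self.log_vars[i]
        return total
```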
1.4 ESMM Code
```python
import torch
from torch_rechub.basic.layers import MLP, EmbeddingLayer


class ESMM(torch.nn.Module):

    def __init__(self, user_features, item_features, cvr_params, ctr_params):
        super().__init__()
        self.user_features = user_features
        self.item_features = item_features
        # Embedding layer shared by both towers
        self.embedding = EmbeddingLayer(user_features + item_features)
        self.tower_dims = user_features[0].embed_dim + item_features[0].embed_dim
        # Build the CVR and CTR towers
        self.tower_cvr = MLP(self.tower_dims, **cvr_params)
        self.tower_ctr = MLP(self.tower_dims, **ctr_params)

    def forward(self, x):
        # Sum-pool the embeddings of the user and item feature fields
        embed_user_features = self.embedding(x, self.user_features,
                                             squeeze_dim=False).sum(dim=1)
        embed_item_features = self.embedding(x, self.item_features,
                                             squeeze_dim=False).sum(dim=1)
        input_tower = torch.cat((embed_user_features, embed_item_features), dim=1)
        cvr_logit = self.tower_cvr(input_tower)
        ctr_logit = self.tower_ctr(input_tower)
        cvr_pred = torch.sigmoid(cvr_logit)
        ctr_pred = torch.sigmoid(ctr_logit)
        # pCTCVR = pCTR * pCVR
        ctcvr_pred = torch.mul(ctr_pred, cvr_pred)
        ys = [cvr_pred, ctr_pred, ctcvr_pred]
        return torch.cat(ys, dim=1)
```
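A minimal smoke test, assuming torch_rechub's SparseFeature API; the feature names, vocabulary sizes, and MLP dims below are made up:

```python
from torch_rechub.basic.features import SparseFeature

user_features = [SparseFeature("user_id", vocab_size=1000, embed_dim=16)]
item_features = [SparseFeature("item_id", vocab_size=5000, embed_dim=16)]
model = ESMM(user_features, item_features,
             cvr_params={"dims": [64, 32]}, ctr_params={"dims": [64, 32]})

x = {"user_id": torch.randint(0, 1000, (8,)),
     "item_id": torch.randint(0, 5000, (8,))}
print(model(x).shape)  # (8, 3): pCVR, pCTR, pCTCVR for each sample
```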
2. MMOE
2.1 MMOE background
- Multi-task models learn the commonalities and differences between tasks, which can improve both modeling quality and efficiency.
- With the traditional multi-task model (Shared Bottom), the less related the tasks are, the worse the model's predictions become.

- Hard parameter sharing: the bottom layers are shared hidden layers that learn patterns common to all tasks, while task-specific fully connected layers on top learn each task's own patterns.
- Soft parameter sharing: instead of a single shared bottom, there are multiple towers, and each task assigns different weights to the different towers.
- Task-sequence dependency modeling: suitable when the tasks have a natural sequential dependency.
2.2 MMOE Model framework
The idea of MMOE is to place multiple expert networks on top of the bottom embedding layer. Each task selects the experts it needs, and the probability-weighted average of the expert outputs becomes the input to that task's tower.
2.2.1 MOE Model
- Model principle: the output is aggregated from multiple Experts, with a gating network assigning each Expert a weight; different tasks weight the Experts differently.
- Characteristics: related to model ensembling, the attention mechanism, and the multi-head mechanism.

The following figure shows how a two-way data-parallel MoE model with three experts performs the forward computation.

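Since the figure is not reproduced here, below is a minimal, self-contained sketch of the MoE forward pass described above; the layer sizes and structure are illustrative assumptions:

```python
import torch
import torch.nn as nn

class MoE(nn.Module):
    """Softmax gate produces one weight per expert; the output is the
    probability-weighted average of the expert outputs."""

    def __init__(self, input_dim, expert_dim, n_expert):
        super().__init__()
        self.experts = nn.ModuleList(
            nn.Sequential(nn.Linear(input_dim, expert_dim), nn.ReLU())
            for _ in range(n_expert))
        self.gate = nn.Linear(input_dim, n_expert)

    def forward(self, x):
        expert_outs = torch.stack([e(x) for e in self.experts], dim=1)  # (B, E, D)
        weights = torch.softmax(self.gate(x), dim=-1).unsqueeze(-1)     # (B, E, 1)
        return (weights * expert_outs).sum(dim=1)                       # (B, D)
```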
2.2.2 MMOE Model

Model architecture: based on the OMoE model, but each task has its own gating network while the Expert networks are shared: multiple tasks, one set of Experts.
Characteristics: 1. avoids task conflicts, since each gate adapts to its own task and selects the combination of Experts most helpful to it; 2. builds relationships between tasks through flexible parameter sharing; 3. the model converges quickly during training.
2.3 Code
```python
import torch
import torch.nn as nn
from torch_rechub.basic.layers import MLP, EmbeddingLayer, PredictionLayer


class MMOE(torch.nn.Module):

    def __init__(self, features, task_types, n_expert, expert_params, tower_params_list):
        super().__init__()
        self.features = features
        self.task_types = task_types
        # Number of tasks
        self.n_task = len(task_types)
        self.n_expert = n_expert
        self.embedding = EmbeddingLayer(features)
        self.input_dims = sum([fea.embed_dim for fea in features])
        # Expert networks, shared by all tasks
        self.experts = nn.ModuleList(
            MLP(self.input_dims, output_layer=False, **expert_params) for i in range(self.n_expert))
        # One gating network per task; softmax produces the expert weights
        self.gates = nn.ModuleList(
            MLP(self.input_dims, output_layer=False, **{
                "dims": [self.n_expert],
                "activation": "softmax"
            }) for i in range(self.n_task))
        # One tower per task
        self.towers = nn.ModuleList(MLP(expert_params["dims"][-1], **tower_params_list[i]) for i in range(self.n_task))
        self.predict_layers = nn.ModuleList(PredictionLayer(task_type) for task_type in task_types)

    def forward(self, x):
        embed_x = self.embedding(x, self.features, squeeze_dim=True)
        # Stack expert outputs into (batch, n_expert, expert_dim)
        expert_outs = [expert(embed_x).unsqueeze(1) for expert in self.experts]
        expert_outs = torch.cat(expert_outs, dim=1)
        # Per-task expert weights, each (batch, n_expert, 1)
        gate_outs = [gate(embed_x).unsqueeze(-1) for gate in self.gates]
        ys = []
        for gate_out, tower, predict_layer in zip(gate_outs, self.towers, self.predict_layers):
            # Weighted sum of expert outputs for this task
            expert_weight = torch.mul(gate_out, expert_outs)
            expert_pooling = torch.sum(expert_weight, dim=1)
            # Task-specific tower
            tower_out = tower(expert_pooling)
            # logit -> probability
            y = predict_layer(tower_out)
            ys.append(y)
        return torch.cat(ys, dim=1)
```
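A hypothetical usage example, parallel to the ESMM one above; the feature definitions, dims, and task types are assumptions:

```python
from torch_rechub.basic.features import SparseFeature

features = [SparseFeature("user_id", vocab_size=1000, embed_dim=16),
            SparseFeature("item_id", vocab_size=5000, embed_dim=16)]
model = MMOE(features, task_types=["classification", "classification"],
             n_expert=4, expert_params={"dims": [64, 32]},
             tower_params_list=[{"dims": [16]}, {"dims": [16]}])

x = {"user_id": torch.randint(0, 1000, (8,)),
     "item_id": torch.randint(0, 5000, (8,))}
print(model(x).shape)  # (8, 2): one probability per task
```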
3. Summary
- ESMM model: introduces the auxiliary CTR and CTCVR tasks and trains over the entire impression sample space, avoiding the sample selection bias and training-data sparsity that traditional CVR models often face. The two towers can be set to different models to fit your own scenario; the sub-networks support arbitrary replacement.
- MMOE model: multiple tasks share a set of experts, and each task's gating network controls how much each expert contributes to that task. This addresses the weakness of the traditional model (training may fail to converge when tasks differ greatly) while itself converging quickly during training.