Forward_algorithm in BiLSTM-CRF for NER, explained
2022-06-28 00:41:00 【365JHWZGo】
If anything below is unclear, you may want to read my previous post explaining BiLSTM-CRF:
A step-by-step walkthrough of the CRF+BiLSTM code
Explanation
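As mentioned before, forward_algorithm is the function that computes the total score of all paths (in log space: the log-sum-exp of every path's score). Below, a concrete example walks through how it works step by step.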
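To pin down what "the total score of all paths" means, here is a minimal brute-force sketch in plain Python (one sequence, no batching, no PyTorch). It enumerates every tag path and takes the log-sum-exp of their scores. The function names are hypothetical, and the `transitions[next][prev]` layout is an assumption chosen to match `self.t_score[s, :]` in the post's code. Its cost grows as tags_size ** seq_len, which is why forward_algorithm exists.

```python
import itertools
import math

def path_score(emissions, transitions, path, start, stop):
    """Score of one tag path.

    emissions[i][s]     : emission score of tag s at position i
    transitions[nxt][p] : score of moving from tag p to tag nxt (row = next tag)
    start, stop         : indices of the <START> and <STOP> tags
    """
    score = transitions[path[0]][start] + emissions[0][path[0]]
    for i in range(1, len(path)):
        score += transitions[path[i]][path[i - 1]] + emissions[i][path[i]]
    return score + transitions[stop][path[-1]]

def brute_force_log_partition(emissions, transitions, start, stop):
    """log(sum over every path of exp(path score)), by explicit enumeration."""
    seq_len, num_tags = len(emissions), len(emissions[0])
    scores = [path_score(emissions, transitions, path, start, stop)
              for path in itertools.product(range(num_tags), repeat=seq_len)]
    m = max(scores)  # shift by the max so exp() cannot overflow
    return m + math.log(sum(math.exp(s - m) for s in scores))
```

For small inputs, this value is what forward_algorithm should reproduce, apart from the effect of the -10000 stand-in it uses for log(0).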
First, randomly initialize an emission matrix e_score of shape (batch_size, seq_len, tags_size).

Next, randomly initialize a transition matrix t_score of shape (tags_size, tags_size).

Create init_matrix and assign it to pre_matrix, shape (batch_size, 1, tags_size). To make the figure easier to follow, the matrix is drawn vertically.
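The setup steps above can be sketched in plain Python for a single sequence (the batch dimension is dropped). The Gaussian random values, the seed, the 5-tag set (B, I, O, <START>, <STOP>) and the helper names are assumptions for illustration; only the -10000 stand-in for log(0) comes from the post's code.

```python
import random

NEG = -10000.0  # stand-in for log(0), the same value the post passes to torch.full

def random_scores(seq_len, num_tags, seed=0):
    """Random emission (seq_len x num_tags) and transition (num_tags x num_tags) scores."""
    rng = random.Random(seed)
    emissions = [[rng.gauss(0.0, 1.0) for _ in range(num_tags)] for _ in range(seq_len)]
    transitions = [[rng.gauss(0.0, 1.0) for _ in range(num_tags)] for _ in range(num_tags)]
    return emissions, transitions

def init_alpha(num_tags, start):
    """Initial log-scores: 0 for <START>, NEG for every other tag."""
    alpha = [NEG] * num_tags
    alpha[start] = 0.0
    return alpha

# Hypothetical tag order: B, I, O, <START>, <STOP>
e_score, t_score = random_scores(seq_len=3, num_tags=5)
pre = init_alpha(num_tags=5, start=3)  # plays the role of init_matrix / pre_matrix
```

Setting only <START> to 0 means every path is forced to begin there: any path starting elsewhere carries a -10000 penalty, which vanishes after exponentiation.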

Only the computation for time step 0 and state 'B' is shown here.
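That single step (time step 0, state B) can be sketched as follows. The tag set and all numbers are hypothetical; the formula `pre + e_score + t_score` followed by a log-sum-exp over previous states mirrors the inner loop of the post's code.

```python
import math

# Hypothetical 5-tag set; only the positions of B and <START> matter here.
TAGS = ["B", "I", "O", "<START>", "<STOP>"]
B, START = TAGS.index("B"), TAGS.index("<START>")
n = len(TAGS)

pre = [-10000.0] * n                      # init_matrix for one sequence
pre[START] = 0.0
e_0 = [0.3, -0.1, 0.2, 0.0, 0.0]          # emission scores at time step 0
t = [[0.1 * (r - c) for c in range(n)] for r in range(n)]  # t[s][p]: score of p -> s

# next_matrix for state B: one candidate score per previous state p
next_B = [pre[p] + e_0[B] + t[B][p] for p in range(n)]
# log-sum-exp over previous states (the naive form is safe for these small values)
alpha_0_B = math.log(sum(math.exp(x) for x in next_B))
```

Because pre is -10000 everywhere except <START>, alpha_0_B comes out essentially equal to e_0[B] + t[B][START]: at step 0 only the path coming from <START> contributes.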

The following are the temporary variables inside log_sum_exp.
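The figure with these temporaries did not survive, and log_sum_exp itself comes from the previous post. A common way to write it, and an assumption about what the temporaries are, is the max-shift trick sketched below. (PyTorch also provides torch.logsumexp(input, dim) for the tensor case.)

```python
import math

def log_sum_exp(scores):
    """Numerically stable log(sum(exp(x) for x in scores))."""
    max_score = max(scores)                        # temporary 1: the largest score
    shifted = [x - max_score for x in scores]      # temporary 2: every value is <= 0
    summed = sum(math.exp(x) for x in shifted)     # temporary 3: always >= 1, no overflow
    return max_score + math.log(summed)            # add the max back
```

Without the shift, math.exp(1000.0) raises OverflowError, while log_sum_exp([1000.0, 1000.0]) returns 1000 + log 2 as expected.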


Code
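import torch

def forward_algorithm(self, e_matrix):
    # e_matrix: emission scores from the BiLSTM, (batch_size, seq_len, tags_size)
    # tags_size, START_TAG, STOP_TAG, self.s2i, self.t_score and self.log_sum_exp
    # are defined elsewhere (see the previous post)
    batch_size, seq_len, _ = e_matrix.shape  # read from the input instead of globals
    # init_matrix: log of the summed scores of all paths ending in each state before step 0
    init_matrix = torch.full((batch_size, 1, tags_size), -10000.0, device=e_matrix.device)
    init_matrix[:, 0, self.s2i[START_TAG]] = 0.
    # pre_matrix: log-sum-exp of all path scores ending in each state at the previous step
    pre_matrix = init_matrix
    # loop over time steps
    for i in range(seq_len):
        # log-sum-exp path scores for each state at the current step
        matrix_value = []
        # loop over the current state s
        for s in range(tags_size):
            # emission score of state s at step i, broadcast to (batch_size, 1, tags_size)
            e_score = e_matrix[:, i, s].view(batch_size, 1, -1).expand(batch_size, 1, tags_size)
            # transition scores from every previous state into s, (1, tags_size)
            t_score = self.t_score[s, :].view(1, -1)
            # candidate scores through each previous state, (batch_size, 1, tags_size)
            next_matrix = pre_matrix + e_score + t_score
            # self.log_sum_exp(next_matrix) -> (batch_size, 1)
            matrix_value.append(self.log_sum_exp(next_matrix))
        # stack the per-state results into the new pre_matrix, (batch_size, 1, tags_size)
        pre_matrix = torch.cat(matrix_value, dim=-1).view(batch_size, 1, -1)
    # final scores: accumulated scores + transition into <STOP>, (batch_size, 1, tags_size)
    terminal_var = pre_matrix + self.t_score[self.s2i[STOP_TAG], :]
    alpha = self.log_sum_exp(terminal_var)
    # (batch_size, 1)
    return alpha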
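For readers without PyTorch at hand, here is a single-sequence, pure-Python transcription of the same two loops (time steps outside, current states inside). It is a sketch for checking the logic, not the post's implementation; names and the `transitions[next][prev]` layout are assumptions matching `self.t_score[s, :]`.

```python
import math

NEG = -10000.0  # stand-in for log(0), as in the post

def log_sum_exp(xs):
    """Stable log(sum(exp(x))) via the max-shift trick."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

def forward_algorithm(emissions, transitions, start, stop):
    """Log of the summed exp-scores of all tag paths for ONE sequence.

    emissions[i][s]   : emission score of tag s at position i  (seq_len x num_tags)
    transitions[s][p] : score of moving from tag p to tag s    (num_tags x num_tags)
    start, stop       : indices of the <START> and <STOP> tags
    """
    num_tags = len(transitions)
    pre = [NEG] * num_tags
    pre[start] = 0.0
    for e_i in emissions:  # loop over time steps
        # for each current state s: log-sum-exp over every previous state p
        pre = [log_sum_exp([pre[p] + e_i[s] + transitions[s][p] for p in range(num_tags)])
               for s in range(num_tags)]
    # final transition into <STOP>, same row as self.t_score[self.s2i[STOP_TAG], :]
    terminal = [pre[p] + transitions[stop][p] for p in range(num_tags)]
    return log_sum_exp(terminal)
```

A quick sanity check: with all scores zero, every one of the tags_size ** seq_len paths scores 0, so the result should be seq_len * log(tags_size).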
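You may still be wondering why this yields the sum over all paths. The loop is simply a cheaper way to compute it: rather than enumerating all tags_size^seq_len paths, it keeps, for each state at each step, the log-sum-exp of all paths ending there and reuses it at the next step. Applying logsumexp many times does not by itself move the result away from the original value, because log(Σ e^{log(Σ e^{a}) + b}) = log(ΣΣ e^{a+b}). The nested value below therefore equals the direct one, up to floating-point rounding and the -10000 used in place of negative infinity at initialization.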
Ideal value (direct definition, where S_1 to S_N are the scores of all N possible tag paths):
$$score_{ideal} = \log\left(e^{S_1}+e^{S_2}+\dots+e^{S_N}\right)\tag{1}$$
Value computed by forward_algorithm (nested form):
$$
\begin{aligned}
score_{reality} &= \log\Big(\sum e^{pre+t}\Big)\\
&= \log\Big(\sum e^{\log(\sum e^{pre+t+es})+t}\Big)\\
&= \dots
\end{aligned}\tag{2}
$$
Here t stands for t_score, es for e_score, and pre for pre_matrix.
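A short derivation of why the nested form (2) matches the direct form (1), shown for one time step. The indices are notation added here, not from the original: p is a previous state, s the current state, S_{p,k} the score of the k-th path that ends in p at the previous step (so pre_p = log Σ_k e^{S_{p,k}}), t_{s,p} the transition score from p to s, and es_s the current emission score of s.

$$
\begin{aligned}
\log\Big(\sum_{p} e^{\,pre_p + t_{s,p} + es_s}\Big)
&= \log\Big(\sum_{p} e^{\log(\sum_{k} e^{S_{p,k}})}\, e^{\,t_{s,p} + es_s}\Big)\\
&= \log\Big(\sum_{p} \sum_{k} e^{\,S_{p,k} + t_{s,p} + es_s}\Big)
\end{aligned}
$$

Each exponent S_{p,k} + t_{s,p} + es_s is the score of one path extended into state s, so the result is exactly the log of the summed exp-scores of every path ending in s. Repeating this at every step and at the final transition into <STOP> expands (2) back into (1); in floating point the only differences are rounding and the -10000 used instead of -∞ for states other than <START>.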

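As shown in the figure above, the circle marked * covers the step from <START> to "我" ("I"): S1 is the combined score of all paths that go from every state at the previous step into B, and logsumexp(S1) is stored at that circle. Likewise, the other circle covers the first two steps up to "爱" ("love"): S2 is the combined score of all paths that transition from every state into B, and logsumexp(S2) is stored there.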
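The two circles can be checked numerically: storing one log-sum-exp per state at "我" and reusing it at "爱" gives the same number as enumerating every two-step path that ends in B. The 4-tag set and all scores below are hypothetical, chosen only for illustration.

```python
import itertools
import math

def log_sum_exp(xs):
    """Stable log(sum(exp(x))) via the max-shift trick."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

# Hypothetical tag set and scores, for illustration only.
TAGS = ["B", "I", "O", "<START>"]
B, START = 0, 3
n = len(TAGS)
init = [-10000.0] * n
init[START] = 0.0
e = [[0.5, -0.2, 0.1, 0.0],      # emission scores at "我" (step 0)
     [0.3, 0.4, -0.6, 0.0]]      # emission scores at "爱" (step 1)
t = [[0.2, -0.1, 0.3, 0.4],      # t[s][p]: score of moving from p to s
     [0.6, 0.1, -0.4, -0.3],
     [-0.2, 0.0, 0.1, 0.2],
     [-0.5, -0.5, -0.5, -0.5]]

# Recursive way: one log-sum-exp per state at "我" (first circle), reused at "爱".
alpha_0 = [log_sum_exp([init[p] + e[0][s] + t[s][p] for p in range(n)]) for s in range(n)]
recursive_B = log_sum_exp([alpha_0[p] + e[1][B] + t[B][p] for p in range(n)])

# Brute-force way: every (initial state, tag at "我") pair, ending in B at "爱".
brute_B = log_sum_exp([init[p] + t[y0][p] + e[0][y0] + t[B][y0] + e[1][B]
                       for p, y0 in itertools.product(range(n), repeat=2)])
```

recursive_B and brute_B agree to within floating-point rounding, which is the same identity that makes (2) equal (1).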
That's it. Did you get it?