Bert-whitening vector dimensionality reduction and usage
2022-06-24 13:04:00 【loong_XL】
References: https://kexue.fm/archives/8069
https://kexue.fm/archives/9079
https://zhuanlan.zhihu.com/p/531476789
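Input: vv is a 3-D array made up of multiple vectors; vv[0] is the 2-D [num_samples, embedding_size] matrix used below.
Output: v_data1, reduced to 256 dimensions.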
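```python
import numpy as np

def compute_kernel_bias(vecs, n_components=256):
    """Compute the whitening kernel and bias.
    vecs.shape = [num_samples, embedding_size]
    Final transform: y = (x + bias).dot(kernel)
    """
    mu = vecs.mean(axis=0, keepdims=True)
    cov = np.cov(vecs.T)
    u, s, vh = np.linalg.svd(cov)
    W = np.dot(u, np.diag(1 / np.sqrt(s)))
    # Keep the first n_components columns (largest singular values first)
    return W[:, :n_components], -mu

def transform_and_normalize(vecs, kernel=None, bias=None):
    """Apply the whitening transform (if given), then L2-normalize each vector."""
    if not (kernel is None or bias is None):
        vecs = (vecs + bias).dot(kernel)
    return vecs / (vecs**2).sum(axis=1, keepdims=True)**0.5

# vv is the input described above (not defined in this snippet).
# vv[0] is a 2-D matrix of many vectors. Fitting on a matrix that holds a
# single vector fails, because np.cov divides by num_samples - 1 = 0.
# Fitting also needs more vectors than n_components (here, more than 256).
v_data = np.array(vv[0])
kernel, bias = compute_kernel_bias(v_data)
v_data1 = transform_and_normalize(v_data, kernel=kernel, bias=bias)
```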
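Why this works, in brief: a sketch of the whitening derivation behind the references above, written with row vectors to match the code. The symbols N, Σ, U, Λ and k are notation introduced here; the 1/(N-1) factor matches the default of np.cov.

```latex
\begin{aligned}
\mu &= \frac{1}{N}\sum_{i=1}^{N} x_i, \qquad
\Sigma = \frac{1}{N-1}\sum_{i=1}^{N} (x_i-\mu)^{\top}(x_i-\mu) = U\Lambda U^{\top} \\
W &= U\Lambda^{-1/2}, \qquad y_i = (x_i-\mu)\,W \qquad (\text{bias} = -\mu,\ \text{kernel} = W) \\
\operatorname{Cov}(y) &= W^{\top}\Sigma W = \Lambda^{-1/2}U^{\top}U\Lambda U^{\top}U\Lambda^{-1/2} = I \\
W_k &= U_{:,1:k}\,\Lambda_{1:k}^{-1/2} \;\Rightarrow\; \operatorname{Cov}(y) = I_k
\qquad (\text{needs } \lambda_k > 0, \text{ impossible unless } N-1 \ge k)
\end{aligned}
```

The final L2 normalization in transform_and_normalize is applied after this step and is not part of the whitening itself.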
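For a single vector online, reuse the kernel and bias computed above on the full set and simply call transform_and_normalize(v_data, kernel=kernel, bias=bias), where v_data is now that one vector as a [1, embedding_size] array.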
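Because the online path only needs kernel and bias, one natural setup is to fit them offline on a large batch, save them, and load them once in the serving process. A minimal sketch, assuming NumPy's .npz format is an acceptable storage choice; `save_whitening` and `load_whitening` are hypothetical helper names, the random corpus is a stand-in for real sentence vectors, and `compute_kernel_bias` here slices the SVD output before inverting (a small variation on the code above that yields the same kept columns):

```python
import os
import tempfile

import numpy as np


def compute_kernel_bias(vecs, n_components=256):
    """Whitening kernel/bias so that y = (x + bias).dot(kernel)."""
    mu = vecs.mean(axis=0, keepdims=True)
    u, s, _ = np.linalg.svd(np.cov(vecs.T))
    # Slice before inverting: only the kept singular values are divided by,
    # which avoids 1/sqrt(0) warnings from discarded near-zero components.
    kernel = u[:, :n_components] / np.sqrt(s[:n_components])
    return kernel, -mu


def save_whitening(path, kernel, bias):
    """Hypothetical helper: persist the fitted parameters as an .npz file."""
    np.savez(path, kernel=kernel, bias=bias)


def load_whitening(path):
    """Hypothetical helper: load parameters written by save_whitening."""
    with np.load(path) as data:
        return data["kernel"], data["bias"]


def transform_and_normalize(vecs, kernel, bias):
    """Whiten, reduce, then L2-normalize; accepts one vector or a batch."""
    vecs = np.atleast_2d(vecs)  # a single [embedding_size] vector becomes [1, embedding_size]
    vecs = (vecs + bias).dot(kernel)
    return vecs / np.linalg.norm(vecs, axis=1, keepdims=True)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    corpus = rng.random((2000, 768))  # stand-in for real sentence vectors
    kernel, bias = compute_kernel_bias(corpus, n_components=256)

    path = os.path.join(tempfile.mkdtemp(), "whitening.npz")
    save_whitening(path, kernel, bias)  # offline: fit once, save

    kernel, bias = load_whitening(path)  # online: load once at startup
    query = rng.random(768)  # one incoming vector
    print(transform_and_normalize(query, kernel, bias).shape)  # (1, 256)
```

Keeping fitting and serving separate means the expensive SVD runs once, and each online request costs only one matrix-vector product plus a normalization.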
```python
import numpy as np

# Demo data: 1000 random 768-d vectors. The original used 5 samples, but the
# covariance of 5 samples has rank at most 4, so keeping 64 components would
# divide by near-zero singular values. Fit on more samples than components.
data = np.random.rand(1000, 768)
print('data.shape =', data.shape)

def compute_kernel_bias(vecs):
    """Compute the whitening kernel and bias.
    vecs.shape = [num_samples, embedding_size]
    Final transform: y = (x + bias).dot(kernel)
    """
    mu = vecs.mean(axis=0, keepdims=True)
    cov = np.cov(vecs.T)
    u, s, vh = np.linalg.svd(cov)
    W = np.dot(u, np.diag(1 / np.sqrt(s)))
    return W, -mu

def transform_and_normalize(vecs, kernel=None, bias=None):
    """Apply the transform, then L2-normalize."""
    if not (kernel is None or bias is None):
        vecs = (vecs + bias).dot(kernel)
    return vecs / (vecs**2).sum(axis=1, keepdims=True)**0.5

kernel, bias = compute_kernel_bias(data)
kernel = kernel[:, :64]  # keep the first 64 components: 768 -> 64 dims
print('kernel.shape =', kernel.shape)  # (768, 64)
print('bias.shape =', bias.shape)      # (1, 768)
data = transform_and_normalize(data, kernel, bias)
print('data.shape =', data.shape)      # (1000, 64)
```
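The whitening only makes sense when the centered samples span at least n_components independent directions, which requires num_samples - 1 ≥ n_components. With 5 random 768-d vectors that rank is at most 4, so 60 of 64 kept columns would be scaled by near-zero singular values. A minimal diagnostic sketch; `usable_components` and `whitening_error` are hypothetical helper names, and the compact `compute_kernel_bias` is a variant of the function above, not the original:

```python
import numpy as np


def compute_kernel_bias(vecs, n_components):
    """Compact variant of the fitting step: kernel = U_k / sqrt(s_k), bias = -mu."""
    mu = vecs.mean(axis=0, keepdims=True)
    u, s, _ = np.linalg.svd(np.cov(vecs.T))
    return u[:, :n_components] / np.sqrt(s[:n_components]), -mu


def usable_components(vecs):
    """Numerical rank of the centered samples, i.e. the largest meaningful n_components."""
    centered = vecs - vecs.mean(axis=0, keepdims=True)
    return int(np.linalg.matrix_rank(centered))


def whitening_error(vecs, kernel, bias):
    """Largest deviation of the whitened covariance from the identity (before L2 normalization)."""
    y = (vecs + bias).dot(kernel)
    return float(np.abs(np.cov(y.T) - np.eye(kernel.shape[1])).max())


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    print(usable_components(rng.random((5, 768))))  # 4: at most num_samples - 1
    data = rng.random((1000, 768))
    kernel, bias = compute_kernel_bias(data, 64)
    print(whitening_error(data, kernel, bias))  # close to 0 when num_samples > n_components
```

Running usable_components on the fitting set before choosing n_components is a cheap way to catch an undersized corpus.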

Reducing a single vector online (reusing the kernel and bias fitted above):
```python
data1 = np.random.rand(1, 768)  # one incoming vector, shape [1, 768]
data1_1 = transform_and_normalize(data1, kernel, bias)  # shape [1, 64]
```
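The earlier note that fitting on a matrix holding a single vector fails follows from np.cov dividing by num_samples - 1: with one sample every covariance entry becomes NaN (with a RuntimeWarning). Fitting and online transformation therefore have different input requirements. A hedged sketch of guards for both paths; `check_fit_input` and `transform_one` are hypothetical helper names, and the sample-count threshold comes from the rank argument rather than from the original post:

```python
import numpy as np


def check_fit_input(vecs, n_components):
    """Reject inputs that cannot yield a valid whitening kernel."""
    vecs = np.asarray(vecs, dtype=np.float64)
    if vecs.ndim != 2:
        raise ValueError(f"expected [num_samples, embedding_size], got {vecs.shape}")
    num_samples, dim = vecs.shape
    if n_components > dim:
        raise ValueError(f"n_components={n_components} exceeds embedding_size={dim}")
    if num_samples - 1 < n_components:
        # np.cov divides by num_samples - 1, so one vector gives an all-NaN covariance
        raise ValueError(
            f"need at least n_components + 1 = {n_components + 1} samples, got {num_samples}"
        )
    return vecs


def transform_one(vec, kernel, bias):
    """Online path: whiten and L2-normalize one vector with a pre-fitted kernel/bias."""
    vec = np.asarray(vec, dtype=np.float64).reshape(1, -1)  # [1, embedding_size]
    if vec.shape[1] != kernel.shape[0]:
        raise ValueError(f"vector has {vec.shape[1]} dims, kernel expects {kernel.shape[0]}")
    y = (vec + bias).dot(kernel)
    return y / np.linalg.norm(y, axis=1, keepdims=True)
```

In the flow above, check_fit_input would wrap the training matrix before compute_kernel_bias, and transform_one would handle each incoming vector such as data1.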
