How to Remember NumPy Array Broadcasting Rules: array broadcast, how broadcasting works, the broadcasting mechanism
2022-07-24 05:20:00 [ML--小小白]
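This post focuses on the written explanation; the code is just there for reference. If this approach helps you even a little, a like would be appreciated, and if you have a better method or spot a mistake in mine, please don't hesitate to point it out.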
First, note that an operation between an array and a scalar is really a broadcast followed by an element-wise operation:
import numpy as np
arr = np.arange(5)
arr
array([0, 1, 2, 3, 4])
arr * 4
array([ 0, 4, 8, 12, 16])
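To make the "broadcast first, then element-wise" view concrete, the sketch below performs the broadcast step by hand with `np.broadcast_to` and then multiplies element by element. This is only an illustration: NumPy does not actually build an expanded copy of the scalar when it evaluates `arr * 4`.

```python
import numpy as np

arr = np.arange(5)

# Step 1: broadcast the scalar 4 to arr's shape (a read-only view, no copy).
expanded = np.broadcast_to(4, arr.shape)
print(expanded)                              # [4 4 4 4 4]

# Step 2: element-wise multiplication with the expanded array.
explicit = arr * expanded
implicit = arr * 4
print(np.array_equal(explicit, implicit))    # True
```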
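The broadcasting rule: align the two shapes at their trailing (rightmost) dimensions; for each pair of axes, the lengths must either match or one of them must be 1. Broadcasting then happens along axes that are missing from one array, or along axes of length 1. For example, in (4, 3) + (3,), the second array's trailing axis length matches, so it is broadcast along the missing axis 0; in (4, 3) + (1, 3), axis 0 is broadcast from 1 to 4. Operations between an array and a scalar also use broadcasting: in (4, 3) + scalar, the scalar can be thought of as having shape (1,), so its trailing length is 1; it is broadcast to 3 along the trailing axis, and the missing axis is broadcast to 4. (Strictly, NumPy treats a scalar as a 0-d array with shape (), which gives the same result.)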
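The rule can be written out as a small function. The sketch below is a hypothetical helper, `broadcast_shape`, not part of NumPy: it pads the shorter shape with leading 1s, walks the axis pairs, and either picks the common length or raises. The last line cross-checks one case against NumPy's own `np.broadcast_shapes` (available since NumPy 1.20).

```python
import numpy as np


def broadcast_shape(shape_a, shape_b):
    """Hypothetical helper: result shape of broadcasting two shapes,
    comparing axes from the right as described in the rule above."""
    ndim = max(len(shape_a), len(shape_b))
    # A missing leading axis behaves like an axis of length 1.
    a = (1,) * (ndim - len(shape_a)) + tuple(shape_a)
    b = (1,) * (ndim - len(shape_b)) + tuple(shape_b)
    result = []
    for len_a, len_b in zip(a, b):
        if len_a == len_b or len_b == 1:
            result.append(len_a)      # equal lengths, or b is stretched
        elif len_a == 1:
            result.append(len_b)      # a is stretched
        else:
            raise ValueError(f"cannot broadcast {shape_a} with {shape_b}")
    return tuple(result)


print(broadcast_shape((4, 3), (3,)))    # (4, 3): missing axis 0 is added
print(broadcast_shape((4, 3), (1, 3)))  # (4, 3): axis 0 stretched from 1 to 4
print(broadcast_shape((4, 3), ()))      # (4, 3): a scalar has shape ()
# Cross-check with NumPy's own function (NumPy 1.20+).
print(np.broadcast_shapes((4, 3), (1, 3)))  # (4, 3)
```

Calling `broadcast_shape((4, 3), (4,))` raises `ValueError`, since the trailing pair is 3 vs 4.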
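Broadcasting can also happen in both arrays at once, e.g. (4, 4) + (4, 1, 4). Aligning from the right: the trailing axis lengths match (4 and 4); on the middle axis they differ (4 vs 1), but one of them is 1, so the 1 is broadcast to 4; finally, the first array has no leading axis, so that missing axis is broadcast to 4. The result has shape (4, 4, 4).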
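`np.broadcast_arrays` shows what each operand is expanded to before the element-wise addition. In this example both operands grow: one gains a leading axis, the other has its length-1 axis stretched.

```python
import numpy as np

a = np.ones((4, 4))
b = np.ones((4, 1, 4))

# np.broadcast_arrays returns views of both inputs expanded to the common shape.
a_b, b_b = np.broadcast_arrays(a, b)
print(a.shape, "->", a_b.shape)   # (4, 4) -> (4, 4, 4): missing leading axis added
print(b.shape, "->", b_b.shape)   # (4, 1, 4) -> (4, 4, 4): length-1 axis stretched
print((a + b).shape)              # (4, 4, 4)
```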
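Following this rule, suppose you want to compute (4, 3) + (4, 1), but the second array actually has shape (4,). Its trailing axis length is neither 1 nor equal to 3, so the shapes cannot be broadcast. You first have to add an axis, using reshape, [:, None], or np.newaxis.
So the key is to look at the trailing axis lengths of the two arrays: if they are neither equal nor 1, broadcasting is out of the question, and you need to add an axis as above or fall back to writing a loop. Keep in mind that the same check then continues leftward, axis by axis, so a matching trailing axis is necessary but not sufficient.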
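Because the check continues past the trailing axis, shapes can still fail even when their last lengths agree. The hypothetical `can_broadcast` helper below simply asks `np.broadcast_shapes` (NumPy 1.20+), which raises `ValueError` for incompatible shapes:

```python
import numpy as np


def can_broadcast(shape_a, shape_b):
    """Hypothetical helper: True if NumPy accepts the two shapes for broadcasting."""
    try:
        np.broadcast_shapes(shape_a, shape_b)
        return True
    except ValueError:
        return False


print(can_broadcast((2, 3, 4), (3, 4)))  # True: every aligned axis pair matches
print(can_broadcast((2, 3, 4), (5, 4)))  # False: trailing 4 == 4, but 3 vs 5 fails
print(can_broadcast((4, 3), (4,)))       # False: trailing 3 vs 4
print(can_broadcast((4, 3), (4, 1)))     # True: after adding an axis
```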
arr = np.random.randn(4, 3)
arr.mean(0)
array([ 0.27783846, 0.36009253, -0.1499029 ])
demeaned = arr - arr.mean(0)
demeaned
array([[-0.79969385, -1.6011334 , -0.00747013],
[-0.0381061 , 0.64865496, -0.97992594],
[ 1.13694786, 0.81091045, 0.73967573],
[-0.29914791, 0.14156799, 0.24772034]])
arr.shape, arr.mean(0).shape
((4, 3), (3,))
aaa = np.array([1])
aaa.shape
(1,)
ans = arr - aaa
arr - ans
array([[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]])
arr.shape, aaa.shape
((4, 3), (1,))
arr.shape
(4, 3)
arr.mean(1).shape
(4,)
arr - arr.mean(1)
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-18-8b8ada26fac0> in <module>
----> 1 arr - arr.mean(1)
ValueError: operands could not be broadcast together with shapes (4,3) (4,)
arr - arr.mean(1).reshape(-1, 1)
array([[ 0.11823437, -0.6009511 , 0.48271673],
[ 0.20018202, 0.96919716, -1.16937918],
[ 0.35626561, 0.11248227, -0.46874788],
[-0.21403229, 0.30893769, -0.0949054 ]])
arr - arr.mean(1)[:, None]
array([[ 0.11823437, -0.6009511 , 0.48271673],
[ 0.20018202, 0.96919716, -1.16937918],
[ 0.35626561, 0.11248227, -0.46874788],
[-0.21403229, 0.30893769, -0.0949054 ]])
arr - arr.mean(1)[:, np.newaxis]
array([[ 0.11823437, -0.6009511 , 0.48271673],
[ 0.20018202, 0.96919716, -1.16937918],
[ 0.35626561, 0.11248227, -0.46874788],
[-0.21403229, 0.30893769, -0.0949054 ]])
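A fourth option avoids adding the axis afterwards: passing `keepdims=True` to `mean` keeps the reduced axis with length 1, so the result already has shape (4, 1). The sketch below uses fresh random data (seeded with `np.random.default_rng`, not the values printed above) and checks that all four forms agree.

```python
import numpy as np

rng = np.random.default_rng(0)
arr = rng.standard_normal((4, 3))        # fresh random data, not the values above

by_reshape = arr - arr.mean(1).reshape(-1, 1)
by_none = arr - arr.mean(1)[:, None]
by_newaxis = arr - arr.mean(1)[:, np.newaxis]
# keepdims=True keeps the reduced axis with length 1, so the mean is already (4, 1)
by_keepdims = arr - arr.mean(1, keepdims=True)

print(arr.mean(1, keepdims=True).shape)  # (4, 1)
print(all(np.allclose(by_reshape, other)
          for other in (by_none, by_newaxis, by_keepdims)))  # True
```

Note that `np.newaxis` is simply an alias for `None`, which is why `[:, None]` and `[:, np.newaxis]` behave identically.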
A three-dimensional example, in which both operands are broadcast:
arr = np.ones((4, 4))
arr_3d = arr[:, np.newaxis, :]
arr.shape, arr_3d.shape
((4, 4), (4, 1, 4))
arr + arr_3d
array([[[2., 2., 2., 2.],
[2., 2., 2., 2.],
[2., 2., 2., 2.],
[2., 2., 2., 2.]],
[[2., 2., 2., 2.],
[2., 2., 2., 2.],
[2., 2., 2., 2.],
[2., 2., 2., 2.]],
[[2., 2., 2., 2.],
[2., 2., 2., 2.],
[2., 2., 2., 2.],
[2., 2., 2., 2.]],
[[2., 2., 2., 2.],
[2., 2., 2., 2.],
[2., 2., 2., 2.],
[2., 2., 2., 2.]]])
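With arrays of ones, the output hides which elements were combined. Repeating the example with distinct values makes the indexing visible: aligned from the right, the (4, 4) array behaves like (1, 4, 4), so `out[i, j, k]` is `a[j, k] + b[i, 0, k]`.

```python
import numpy as np

a = np.arange(16).reshape(4, 4)    # shape (4, 4)
b = 100 * a[:, np.newaxis, :]      # shape (4, 1, 4); scaled so the two sources are easy to tell apart
out = a + b                        # shape (4, 4, 4)

# a behaves like shape (1, 4, 4) and b like (4, 1, 4), so
# each element is out[i, j, k] == a[j, k] + b[i, 0, k].
print(out.shape)                   # (4, 4, 4)
print(out[2, 1, 3])                # 1107 = a[1, 3] + b[2, 0, 3] = 7 + 1100
print(all(out[i, j, k] == a[j, k] + b[i, 0, k]
          for i, j, k in np.ndindex(out.shape)))  # True
```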
A common pattern is to subtract (or divide by) a sum, variance, mean, or similar statistic computed along one axis:
arr = np.random.randn(3, 4, 5)
# Suppose we want to subtract the mean along axis 1
means = arr.mean(1)
means
array([[-0.95808939, -0.59395877, 0.44605451, 0.06325242, 0.14369531],
[ 0.2600657 , -0.92595688, -0.75528343, -0.2486933 , -0.02936524],
[-0.22052564, 0.14549496, -0.67660057, -0.10151047, 0.26275483]])
arr.shape, means.shape
((3, 4, 5), (3, 5))
demeaned = arr - means[:, np.newaxis, :]
# Check that the means along axis 1 are now (numerically) zero
np.allclose(demeaned.mean(1), 0)
True
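The new axis is what makes this work. Without it, the shapes align as (3, 4, 5) vs (3, 5): the trailing 5s match, but the next pair is 4 vs 3, so NumPy refuses. The sketch below (fresh seeded random data) shows that failure and confirms the broadcast result equals an explicit loop that subtracts `means[i, k]` from every `arr[i, j, k]`.

```python
import numpy as np

rng = np.random.default_rng(1)
arr = rng.standard_normal((3, 4, 5))
means = arr.mean(1)                          # shape (3, 5)

# Without the new axis: (3, 4, 5) vs (3, 5) fails on the 4 vs 3 pair.
try:
    arr - means
    plain_ok = True
except ValueError:
    plain_ok = False
print(plain_ok)                              # False

demeaned = arr - means[:, np.newaxis, :]     # (3, 4, 5) - (3, 1, 5)

# Equivalent explicit loop over every element.
looped = np.empty_like(arr)
for i, j, k in np.ndindex(arr.shape):
    looped[i, j, k] = arr[i, j, k] - means[i, k]
print(np.allclose(demeaned, looped))         # True
```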
# This can be wrapped in a function
def demean_axis(arr, axis=0):
    means = arr.mean(axis)
    # Put a new length-1 axis back where the reduced axis was
    indexer = [slice(None)] * arr.ndim
    indexer[axis] = np.newaxis
    # Index with a tuple: using a plain list for multidimensional indexing
    # is deprecated and no longer supported in newer NumPy versions
    return arr - means[tuple(indexer)]
arr = np.random.randn(3, 4, 5)
demeaned = demean_axis(arr, axis=1)
np.allclose(demeaned.mean(1), 0)
True
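The same pattern covers dividing by a statistic, as mentioned above. The hypothetical `standardize_axis` helper below subtracts the mean and divides by the standard deviation along one axis; it uses `keepdims=True` instead of the manual indexer, so both statistics keep a length-1 axis and broadcast back against `arr` directly. This is a sketch, not the author's function.

```python
import numpy as np


def standardize_axis(arr, axis=0):
    """Hypothetical helper: subtract the mean and divide by the (population)
    standard deviation along `axis`. keepdims=True leaves the reduced axis
    with length 1, so both statistics broadcast back against arr."""
    means = arr.mean(axis, keepdims=True)
    stds = arr.std(axis, keepdims=True)
    return (arr - means) / stds


rng = np.random.default_rng(2)
arr = rng.standard_normal((3, 4, 5))
z = standardize_axis(arr, axis=1)
print(z.shape)                      # (3, 4, 5)
print(np.allclose(z.mean(1), 0))    # True
print(np.allclose(z.std(1), 1))     # True
```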
Array assignment also relies on broadcasting: the right-hand side is broadcast to the shape of the target.
arr = np.zeros((4, 3))
col = np.array([1.28, 0, 33, 0.5])
arr[:] = col[:, np.newaxis]
arr
array([[ 1.28, 1.28, 1.28],
[ 0. , 0. , 0. ],
[33. , 33. , 33. ],
[ 0.5 , 0.5 , 0.5 ]])
arr[:2] = [[2], [3]]
arr
array([[ 2. , 2. , 2. ],
[ 3. , 3. , 3. ],
[33. , 33. , 33. ],
[ 0.5, 0.5, 0.5]])
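Assignment broadcasting only goes one way: the right-hand side is stretched to fit the target, never the reverse, and the usual trailing-axis check applies. The sketch below rebuilds the same `arr` and `col` shapes and shows that assigning `col` without the new axis fails (trailing 4 vs 3), while a row of length 3 is copied into every row.

```python
import numpy as np

arr = np.zeros((4, 3))
col = np.array([1.28, 0, 33, 0.5])

# col has shape (4,) and the target arr[:] has shape (4, 3): trailing 4 vs 3 fails.
try:
    arr[:] = col
    assigned = True
except ValueError as err:
    assigned = False
    print("ValueError:", err)

# A length-3 row matches the trailing axis and is broadcast to every row.
arr[:] = np.array([1.0, 2.0, 3.0])
print(arr[0], arr[3])   # [1. 2. 3.] [1. 2. 3.]
```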