
TensorFlow einsum (Einstein summation) function

2022-07-24 06:57:00 Live up to your youth

Function signature

tf.einsum(
    equation, *inputs, **kwargs
)

Function description

tf.einsum implements tensor operations such as dot products, outer products, transposes and matrix multiplication. Einsum is a single notation that expresses all of these, and it also describes more complex tensor operations compactly; you can think of it as a small domain-specific language. Once you understand einsum, you no longer need to remember, or keep looking up, a specific library function for each operation, and you can write more compact code.

For example, multiplying two matrices A and B gives the matrix C:

$$C_{ik} = \sum_{j} A_{ij} B_{jk}$$

The corresponding einsum expression is:

ij,jk->ik

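The parameter equation is the einsum expression, given as a string. The parameter *inputs holds the input tensors; each tensor's rank must match the number of letters in its comma-separated term of equation, and a letter that occurs more than once must refer to axes of the same size. Letters that appear in the inputs but not after -> are summed over; the letters after -> name the axes of the output, in order.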
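To make the summation rule concrete, here is a minimal pure-Python sketch of how an explicit equation such as ij,jk->ik can be evaluated. The names toy_einsum, shape_of and element are hypothetical and are not part of TensorFlow; the sketch assumes rectangular nested lists, requires the -> part, and does not handle broadcasting or the ... ellipsis.

```python
from itertools import product


def shape_of(x):
    """Shape of a rectangular nested list, e.g. [[1, 2], [3, 4]] -> (2, 2)."""
    dims = []
    while isinstance(x, list):
        dims.append(len(x))
        x = x[0]
    return tuple(dims)


def element(x, idx):
    """Read x[i0][i1]... for a sequence of indices."""
    for i in idx:
        x = x[i]
    return x


def toy_einsum(equation, *operands):
    """Hypothetical pure-Python einsum for explicit equations such as 'ij,jk->ik'."""
    lhs, out = equation.replace(" ", "").split("->")
    terms = lhs.split(",")
    if len(terms) != len(operands):
        raise ValueError("need one subscript term per operand")

    # Bind every index letter to a dimension size; a repeated letter must
    # refer to dimensions of the same size.
    sizes = {}
    for term, op in zip(terms, operands):
        shape = shape_of(op)
        if len(term) != len(shape):
            raise ValueError(f"term {term!r} does not match shape {shape}")
        for letter, n in zip(term, shape):
            if sizes.setdefault(letter, n) != n:
                raise ValueError(f"size mismatch for index {letter!r}")
    if any(letter not in sizes for letter in out):
        raise ValueError("every output index must appear in an input")

    # Letters that do not appear in the output are summed over.
    summed = [letter for letter in sizes if letter not in out]

    def entry(out_idx):
        env = dict(zip(out, out_idx))
        total = 0
        for sum_idx in product(*(range(sizes[s]) for s in summed)):
            env.update(zip(summed, sum_idx))
            prod = 1
            for term, op in zip(terms, operands):
                prod *= element(op, [env[letter] for letter in term])
            total += prod
        return total

    # Build the nested output list one output axis at a time.
    def build(prefix):
        if len(prefix) == len(out):
            return entry(prefix)
        return [build(prefix + (i,)) for i in range(sizes[out[len(prefix)]])]

    return build(())


print(toy_einsum("ij,jk->ik", [[1, 2], [3, 4]], [[5, 6], [7, 8]]))  # [[19, 22], [43, 50]]
print(toy_einsum("i,j->ij", [1, 2], [3, 4, 5]))  # outer product: [[3, 4, 5], [6, 8, 10]]
```

The outer product i,j->ij keeps both letters in the output, so nothing is summed and every pair of elements is multiplied.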

Usage examples

1. Vector dot product: c = sum_i a[i]*b[i]

>>> a = tf.constant([1, 2, 3, 4])
>>> b = tf.constant([2, 2, 2, 2])
>>> tf.einsum("i,i->", a, b)
<tf.Tensor: shape=(), dtype=int32, numpy=20>
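Because i appears in both inputs but not after ->, einsum multiplies matching elements and sums over i. The plain-Python loop below spells this out; dot_loop is a hypothetical helper used only for illustration. In TensorFlow the same value can also be computed with tf.reduce_sum(a * b).

```python
def dot_loop(a, b):
    """Loop form of tf.einsum("i,i->", a, b): c = sum_i a[i] * b[i]."""
    if len(a) != len(b):
        raise ValueError("index i must have the same size in both inputs")
    c = 0
    for i in range(len(a)):  # i is summed because it is absent from the output
        c += a[i] * b[i]
    return c


print(dot_loop([1, 2, 3, 4], [2, 2, 2, 2]))  # 20
```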

2. Matrix multiplication: c[i, k] = sum_j a[i, j]*b[j, k]

>>> a = tf.ones([2, 2])
>>> b = tf.ones([2, 2])
>>> tf.einsum("ij,jk->ik", a, b)
<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[2., 2.],
       [2., 2.]], dtype=float32)>
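Here j is shared by both inputs and missing from the output, so it is summed, while i and k index the result. The triple loop below makes that explicit; matmul_loop is a hypothetical name for illustration only. For 2-D tensors, tf.matmul(a, b) returns the same result as this einsum call.

```python
def matmul_loop(a, b):
    """Loop form of tf.einsum("ij,jk->ik", a, b) for nested lists."""
    n_i, n_j, n_k = len(a), len(b), len(b[0])
    if len(a[0]) != n_j:
        raise ValueError("index j must have the same size in a and b")
    c = [[0] * n_k for _ in range(n_i)]
    for i in range(n_i):          # kept: appears in the output
        for k in range(n_k):      # kept: appears in the output
            for j in range(n_j):  # summed: missing from the output
                c[i][k] += a[i][j] * b[j][k]
    return c


ones = [[1.0, 1.0], [1.0, 1.0]]
print(matmul_loop(ones, ones))  # [[2.0, 2.0], [2.0, 2.0]]
```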

3. Matrix transpose: b[j, i] = a[i, j]

>>> a = tf.constant([[1, 2], [3, 4]])
>>> a
<tf.Tensor: shape=(2, 2), dtype=int32, numpy=
array([[1, 2],
       [3, 4]])>
>>> tf.einsum("ij->ji", a)
<tf.Tensor: shape=(2, 2), dtype=int32, numpy=
array([[1, 3],
       [2, 4]])>
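With a single input and no letter dropped, nothing is summed; the output subscripts ji only reorder the axes, so element a[i][j] lands at position [j][i]. The comprehension below is an illustrative equivalent (transpose_loop is a hypothetical name); for a 2-D tensor, tf.transpose(a) gives the same result.

```python
def transpose_loop(a):
    """Loop form of tf.einsum("ij->ji", a): b[j][i] = a[i][j]."""
    n_i, n_j = len(a), len(a[0])
    # Outer comprehension walks j (first output axis), inner walks i.
    return [[a[i][j] for i in range(n_i)] for j in range(n_j)]


print(transpose_loop([[1, 2], [3, 4]]))  # [[1, 3], [2, 4]]
```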

4. Diagonal elements of a matrix: b[i] = a[i, i]

>>> a = tf.linalg.band_part(tf.ones((3, 3)), 0, 0)
>>> a
<tf.Tensor: shape=(3, 3), dtype=float32, numpy=
array([[1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.]], dtype=float32)>
>>> tf.einsum("ii->i", a)
<tf.Tensor: shape=(3,), dtype=float32, numpy=array([1., 1., 1.], dtype=float32)>
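Repeating i inside one input term ties the two axes together, so only the elements a[i][i] are read; because i is kept in the output, nothing is summed. A loop sketch follows (diag_loop is a hypothetical name); in TensorFlow, tf.linalg.diag_part(a) also returns the diagonal of a matrix.

```python
def diag_loop(a):
    """Loop form of tf.einsum("ii->i", a) for a square nested list."""
    n = len(a)
    if any(len(row) != n for row in a):
        raise ValueError("both axes share index i, so the matrix must be square")
    return [a[i][i] for i in range(n)]  # i is kept in the output, so nothing is summed


identity = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
print(diag_loop(identity))  # [1.0, 1.0, 1.0]
```

Summing the returned list gives the trace; in einsum terms that corresponds to also dropping i from the output.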

Copyright notice
This article was written by [Live up to your youth]; please include a link to the original when reposting. Thanks.
https://yzsam.com/2022/205/202207240550104843.html