当前位置:网站首页>LSTMCell in TensorFlow: a hands-on example
LSTMCell in TensorFlow: a hands-on example
2022-06-26 05:04:00 【Rain and dew touch the real king】
import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers
# Fail fast on the wrong TensorFlow major version. An explicit check is used
# instead of `assert`, because asserts are stripped under `python -O`.
if not tf.__version__.startswith('2.'):
    raise RuntimeError('this script requires TensorFlow 2.x, found ' + tf.__version__)

# Fixed seeds so that shuffling and weight initialisation are repeatable.
tf.random.set_seed(22)
np.random.seed(22)

batchsz = 128
# keep only the `total_words` most frequent words of the vocabulary
total_words = 10000
# every review is padded / truncated to this many tokens
max_review_len = 80
# dimensionality of each word embedding vector
embedding_len = 100

(x_train, y_train), (x_test, y_test) = keras.datasets.imdb.load_data(num_words=total_words)
# x_train: [b, 80]
# x_test:  [b, 80]
x_train = keras.preprocessing.sequence.pad_sequences(x_train, maxlen=max_review_len)
x_test = keras.preprocessing.sequence.pad_sequences(x_test, maxlen=max_review_len)

# drop_remainder=True makes every batch exactly `batchsz` rows wide.
db_train = tf.data.Dataset.from_tensor_slices((x_train, y_train))
db_train = db_train.shuffle(1000).batch(batchsz, drop_remainder=True)
db_test = tf.data.Dataset.from_tensor_slices((x_test, y_test))
db_test = db_test.batch(batchsz, drop_remainder=True)

# Report the padded shapes and the label range as plain numbers.
print('x_train shape:', x_train.shape, np.max(y_train), np.min(y_train))
print('x_test shape:', x_test.shape)
class MyRNN(keras.Model):
    """Binary sentiment classifier: embedding -> two stacked LSTMCells -> dense.

    The recurrence is unrolled by hand over the time axis, calling each
    ``LSTMCell`` once per token, and the final hidden state of the second
    cell is mapped to a probability with a sigmoid.
    """

    def __init__(self, units):
        """Create the layers.

        :param units: hidden size of both LSTM cells (the original author used 64).
        """
        super(MyRNN, self).__init__()
        # Kept so that call() can size the initial states from the real batch.
        self.units = units
        # transform text to embedding representation
        # [b, 80] => [b, 80, 100]
        # (`input_length` was dropped: it is only a shape hint and is
        # deprecated in newer Keras releases.)
        self.embedding = layers.Embedding(total_words, embedding_len)
        # two stacked LSTM cells; per time step [b, 100] => [b, units] => [b, units]
        self.rnn_cell0 = layers.LSTMCell(units, dropout=0.5)
        self.rnn_cell1 = layers.LSTMCell(units, dropout=0.5)
        # fc, [b, units] => [b, 1]
        self.outlayer = layers.Dense(1)

    def _zero_state(self, batch, dtype):
        """Return a fresh all-zero ``[h, c]`` LSTM state of shape ``[batch, units]``."""
        return [tf.zeros([batch, self.units], dtype=dtype),
                tf.zeros([batch, self.units], dtype=dtype)]

    def call(self, inputs, training=None):
        """Run the forward pass.

        net(x) / net(x, training=True): train mode (dropout active)
        net(x, training=False): inference mode

        :param inputs: [b, 80] padded word-index sequences
        :param training: forwarded to the LSTM cells to switch dropout on/off
        :return: [b, 1] probability that each review is positive
        """
        # embedding: [b, 80] => [b, 80, 100]
        x = self.embedding(inputs)

        # Size the initial states from the actual batch rather than the global
        # `batchsz`, so any batch size (e.g. predict on a few reviews) works.
        batch = tf.shape(x)[0]
        state0 = self._zero_state(batch, x.dtype)
        state1 = self._zero_state(batch, x.dtype)

        # keras.layers.RNN clears the cells' cached dropout masks before every
        # forward pass. The cells are driven manually here, so do the same;
        # otherwise a mask cached by an earlier call can be reused.
        for cell in (self.rnn_cell0, self.rnn_cell1):
            cell.reset_dropout_mask()
            cell.reset_recurrent_dropout_mask()

        # rnn cell compute: [b, 80, 100] => [b, units]
        for word in tf.unstack(x, axis=1):  # word: [b, 100]
            # out0: [b, units]
            out0, state0 = self.rnn_cell0(word, state0, training=training)
            # out1: [b, units]
            out1, state1 = self.rnn_cell1(out0, state1, training=training)

        # out: [b, units] => [b, 1]
        logits = self.outlayer(out1)
        # p(y is pos | x)
        return tf.sigmoid(logits)
def main():
    """Train MyRNN on the IMDB pipelines, evaluate it and report wall-clock time."""
    import time

    units = 64
    epochs = 4
    t0 = time.perf_counter()

    model = MyRNN(units)
    # `experimental_run_tf_function=False` was removed: newer tf.keras releases
    # ignore it and Keras 3's compile() rejects unknown keyword arguments.
    model.compile(optimizer=keras.optimizers.Adam(0.001),
                  loss=tf.losses.BinaryCrossentropy(),
                  metrics=['accuracy'])
    model.fit(db_train, epochs=epochs, validation_data=db_test)

    # With one metric configured, evaluate() returns [loss, accuracy].
    test_loss, test_acc = model.evaluate(db_test)
    print('test loss:', test_loss, 'test accuracy:', test_acc)

    t1 = time.perf_counter()
    # original author's run: 64.3 seconds, 83.4% accuracy
    print('total time cost:', t1 - t0)
# Script entry point: train and evaluate only when executed directly,
# never as a side effect of importing this module.
if __name__ == '__main__':
    main()
边栏推荐
- 6.1 - 6.2 公鑰密碼學簡介
- YOLOv5 hyperparameter settings and data augmentation analysis
- 0622-马棕榈跌9%
- Use fill and fill_between in Matplotlib to fill the area between functions
- 【Unity3D】刚体组件Rigidbody
- ThreadPoolExecutor implements file uploading and batch inserting data
- 微服务之间的Token传递之一@Feign的token传递
- 5. <Tag: stacks and general problems> Supplement: LeetCode 946, Validate Stack Sequences (the same as Offer 31: push and pop sequences of a stack)
- LeetCode 19. Remove the Nth node from the end of the linked list
- GD32F3x0 官方PWM驱动正频宽偏小(定时不准)的问题
猜你喜欢
随机推荐
6.1 - 6.2 公钥密码学简介
A ZABBIX self discovery script (shell Basics)
Genius Makers: mavericks, tech giants and AI | ten years of the rise of deep learning
Rsync common error messages (common errors on the window)
ModuleNotFoundError: No module named 'numpy'
2022.2.11
C# 39. string类型和byte[]类型相互转换(实测)
Cookie and session Basics
Comment enregistrer une image dans une applet Wechat
ROS 笔记(07)— 客户端 Client 和服务端 Server 的实现
Final review of brain and cognitive science
Introduction to the categorical data type (category): properties and methods of common APIs
Multipass Chinese document - share data with instances
-Discrete Mathematics - Analysis of final exercises
天才制造者:独行侠、科技巨头和AI|深度学习崛起十年
Guanghetong and anti international bring 5g R16 powerful performance to the AI edge computing platform based on NVIDIA Jetson Xavier nx
2022.2.15
A company crawling out of its grave
微信小程序保存图片的方法
为什么许多shopify独立站卖家都在用聊天机器人?一分钟读懂行业秘密!