sentiment_analysis_cell in TensorFlow
2022-06-26 05:04:00 【Rain and dew touch the real king】
import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import tensorflow as tf
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers
tf.random.set_seed(22)
np.random.seed(22)
assert tf.__version__.startswith('2.')
batchsz =128
# keep only the most frequent words
total_words=10000
max_review_len=80
embeding_len= 100
(x_train,y_train),(x_test,y_test) = keras.datasets.imdb.load_data(num_words=total_words)
#x_train:[b,80]
#x_test:[b,80]
x_train = keras.preprocessing.sequence.pad_sequences(x_train,maxlen=max_review_len)
x_test=keras.preprocessing.sequence.pad_sequences(x_test,maxlen=max_review_len)
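To make the preprocessing concrete, here is a small NumPy sketch of what `load_data(num_words=total_words)` and `pad_sequences(..., maxlen=max_review_len)` do to each review, assuming the Keras defaults: word indices at or above `num_words` become the out-of-vocabulary index 2, and `pad_sequences` pads with zeros and truncates at the front (`padding='pre'`, `truncating='pre'`). The helpers `cap_vocab` and `pad_pre` are hypothetical, written only for illustration; they are not part of Keras.

```python
import numpy as np

OOV_INDEX = 2  # Keras IMDB default oov_char


def cap_vocab(seq, num_words):
    """Replace word indices >= num_words with the OOV index (mimics load_data)."""
    return [w if w < num_words else OOV_INDEX for w in seq]


def pad_pre(seqs, maxlen, value=0):
    """Left-pad short sequences and keep only the last maxlen items of long ones
    (mimics pad_sequences with padding='pre', truncating='pre')."""
    out = np.full((len(seqs), maxlen), value, dtype=np.int32)
    for i, seq in enumerate(seqs):
        tail = seq[-maxlen:]            # truncate from the front
        if len(tail):
            out[i, -len(tail):] = tail  # pad at the front
    return out


reviews = [[1, 14, 22, 16], [1, 9, 12000, 7, 5, 30, 40]]
capped = [cap_vocab(r, num_words=10000) for r in reviews]
padded = pad_pre(capped, maxlen=5)
print(padded)
```

With `maxlen=5`, the first review is left-padded with one zero, while the second keeps only its last five indices and the rare index 12000 has become 2. In the real pipeline every row ends up with exactly 80 indices, which is where the `[b,80]` shapes come from.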
db_train = tf.data.Dataset.from_tensor_slices((x_train,y_train))
db_train=db_train.shuffle(1000).batch(batchsz,drop_remainder=True)
# evaluate on the held-out test split, not the training data
db_test = tf.data.Dataset.from_tensor_slices((x_test,y_test))
db_test=db_test.batch(batchsz,drop_remainder=True)
print('x_train shape:',x_train.shape,tf.reduce_max(y_train),tf.reduce_min(y_train))
print('x_test shape:',x_test.shape)
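Both datasets are batched with `drop_remainder=True` because the `MyRNN` model defined below creates its initial states with the fixed shape `[batchsz, units]`, so a smaller final batch would not match. The plain-Python sketch below mimics `batch(..., drop_remainder=True)` on a list (the helper `batch_drop_remainder` is hypothetical) and counts how many of the 25,000 IMDB training reviews are kept with `batchsz = 128`.

```python
def batch_drop_remainder(items, batch_size):
    """Group items into full batches and discard a short final batch
    (mimics tf.data.Dataset.batch(batch_size, drop_remainder=True))."""
    n_full = len(items) // batch_size
    return [items[i * batch_size:(i + 1) * batch_size] for i in range(n_full)]


num_reviews = 25000  # size of the IMDB training split
batchsz = 128

batches = batch_drop_remainder(list(range(num_reviews)), batchsz)
kept = sum(len(b) for b in batches)
print(len(batches), "batches,", kept, "reviews kept,", num_reviews - kept, "dropped")
```

This gives 195 full batches (24,960 reviews), so 40 reviews per pass are left out. The alternative would be to build the zero states from the actual batch size inside `call`, at the cost of a slightly more complex model.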
class MyRNN(keras.Model):
    def __init__(self,units):
        super(MyRNN,self).__init__()
        # initial hidden states, [b, 64]; the fixed batch size is why
        # both datasets are batched with drop_remainder=True
        self.state0 = [tf.zeros([batchsz, units])]
        self.state1 = [tf.zeros([batchsz, units])]
        # transform text to embedding representation
        # [b,80]=>[b,80,100]
        self.embedding = layers.Embedding(total_words,embeding_len,
                                          input_length=max_review_len)
        # [b,80,100], h_dim: 64
        # two stacked SimpleRNN cells: cell0 -> cell1
        self.rnn_cell0 = layers.SimpleRNNCell(units, dropout=0.5)
        self.rnn_cell1 = layers.SimpleRNNCell(units, dropout=0.5)
        # fc, [b,80,100]=>[b,64]=>[b,1]
        self.outlayer = layers.Dense(1)

    def call(self,inputs,training=None):
        """
        net(x, training=True): training mode (dropout active)
        net(x, training=False): test mode (dropout disabled)
        :param inputs: [b,80] padded word indices
        :param training: passed to both RNN cells to switch dropout on or off
        :return: [b,1] probability that each review is positive
        """
        # [b,80]
        x = inputs
        # embedding: [b,80]=>[b,80,100]
        x = self.embedding(x)
        # rnn cell compute
        # [b,80,100]=>[b,64]
        state0 = self.state0
        state1 = self.state1
        for word in tf.unstack(x, axis=1):  # word: [b,100]
            # h_t = tanh(x_t @ W_xh + h_{t-1} @ W_hh + b)
            # out0: [b,64]
            out0, state0 = self.rnn_cell0(word, state0, training)
            # out1: [b,64]
            out1, state1 = self.rnn_cell1(out0, state1, training)
        # out: [b,64]=>[b,1]
        x = self.outlayer(out1)
        # p(y is pos|x)
        prob = tf.sigmoid(x)
        return prob
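For readers who want to see the arithmetic behind `call`, here is a NumPy sketch of the same forward pass with randomly initialised weights: an embedding lookup, two stacked SimpleRNN cells that apply h_t = tanh(x_t W_xh + h_{t-1} W_hh + b) at every time step, then a dense layer and a sigmoid. It is an illustration under simplifying assumptions (no dropout, small random weights, hypothetical names such as `rnn_step` and `forward`), not a reproduction of the trained Keras model.

```python
import numpy as np

rng = np.random.default_rng(22)

batch, seq_len, vocab, emb_dim, units = 4, 80, 10000, 100, 64

# Randomly initialised stand-ins for the learned Keras weights.
E = rng.normal(0, 0.05, (vocab, emb_dim))       # embedding table
W_xh0 = rng.normal(0, 0.05, (emb_dim, units))   # cell0 input kernel
W_hh0 = rng.normal(0, 0.05, (units, units))     # cell0 recurrent kernel
b0 = np.zeros(units)
W_xh1 = rng.normal(0, 0.05, (units, units))     # cell1 input kernel
W_hh1 = rng.normal(0, 0.05, (units, units))     # cell1 recurrent kernel
b1 = np.zeros(units)
W_out = rng.normal(0, 0.05, (units, 1))         # dense output layer
b_out = np.zeros(1)


def rnn_step(x_t, h_prev, W_xh, W_hh, bias):
    """One SimpleRNN step: h_t = tanh(x_t @ W_xh + h_prev @ W_hh + bias)."""
    return np.tanh(x_t @ W_xh + h_prev @ W_hh + bias)


def forward(word_ids):
    x = E[word_ids]                           # [b, 80] -> [b, 80, 100]
    h0 = np.zeros((word_ids.shape[0], units)) # initial states, [b, 64]
    h1 = np.zeros((word_ids.shape[0], units))
    for t in range(x.shape[1]):               # unroll over the 80 time steps
        h0 = rnn_step(x[:, t, :], h0, W_xh0, W_hh0, b0)  # cell0: [b, 64]
        h1 = rnn_step(h0, h1, W_xh1, W_hh1, b1)          # cell1: [b, 64]
    logits = h1 @ W_out + b_out               # [b, 64] -> [b, 1]
    return 1.0 / (1.0 + np.exp(-logits))      # p(y is positive | x)


word_ids = rng.integers(0, vocab, size=(batch, seq_len))
prob = forward(word_ids)
print(prob.shape)
```

The Keras model expresses the same loop with `tf.unstack(x, axis=1)`, and each `SimpleRNNCell` owns its own kernel, recurrent kernel and bias, just as each cell here has separate `W_xh`, `W_hh` and bias arrays.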
def main():
    units = 64
    epochs = 4
    model = MyRNN(units)
    model.compile(optimizer=keras.optimizers.Adam(0.001),
                  loss=tf.losses.BinaryCrossentropy(),
                  metrics=['accuracy'], experimental_run_tf_function=False)
    model.fit(db_train, epochs=epochs, validation_data=db_test)
    model.evaluate(db_test)


if __name__ == '__main__':
    main()
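`model.compile` pairs the sigmoid output of `call` with `BinaryCrossentropy()`, whose default `from_logits=False` expects probabilities, and with the `'accuracy'` metric, which for an output whose last dimension is 1 is binary accuracy at a 0.5 threshold. The NumPy sketch below computes both by hand for a toy batch. The clipping constant `EPS = 1e-7` is an assumption chosen to keep `log()` finite, and the helper names are hypothetical. It also shows the numerically stable logits form, the kind of computation you would rely on if `call` returned `x` without the sigmoid and the loss were created with `from_logits=True`.

```python
import numpy as np

EPS = 1e-7  # assumed clipping constant, keeps log() finite


def bce_from_probs(y_true, prob):
    """Mean binary cross-entropy when the model outputs probabilities."""
    p = np.clip(prob, EPS, 1.0 - EPS)
    return float(np.mean(-(y_true * np.log(p) + (1 - y_true) * np.log(1 - p))))


def bce_from_logits(y_true, logits):
    """Same loss from raw logits, in the stable form
    max(z, 0) - z * y + log(1 + exp(-|z|))."""
    z = logits
    return float(np.mean(np.maximum(z, 0) - z * y_true + np.log1p(np.exp(-np.abs(z)))))


def binary_accuracy(y_true, prob, threshold=0.5):
    """Fraction of reviews whose thresholded prediction matches the label."""
    return float(np.mean((prob > threshold).astype(y_true.dtype) == y_true))


y_true = np.array([1.0, 0.0, 1.0, 0.0])
logits = np.array([2.0, -1.0, -0.5, 0.3])
prob = 1.0 / (1.0 + np.exp(-logits))  # what tf.sigmoid(x) produces

print(bce_from_probs(y_true, prob), bce_from_logits(y_true, logits))
print(binary_accuracy(y_true, prob))
```

For ordinary values the two loss forms agree. They differ at the extremes: for a positive review with a logit of -1000, the logits form returns a loss of 1000, while the clipped probability form saturates at about 16.1 (that is, -log(1e-7)), so its gradient signal is capped.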