TensorFlow: Classification with a ResNet50 Network

2022-07-01  万州客

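This sentiment classification task could also be done with a GRU model, or with VGG or DPCNN; the code is largely the same. Different network architectures take different amounts of time to train and give different results. Here, ResNet50 is applied to text by treating each review's character-embedding matrix as an image.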
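To illustrate why swapping networks changes so little code, here is a minimal sketch assuming TF 2.x `tf.keras`: the embedding input and the two-class softmax head stay fixed, and only the encoder between them changes. `build_classifier`, `VOCAB_SIZE`, and the `'cnn'` branch are hypothetical names for illustration; the CNN branch is a plain 1-D convolution stand-in, not an implementation of DPCNN or VGG.

```python
import numpy as np
import tensorflow as tf

MAX_LEN = 80       # sequence length used in the post
VOCAB_SIZE = 3508  # hypothetical vocabulary size, for illustration only


def build_classifier(encoder: str) -> tf.keras.Model:
    """Same embedding input and softmax head; only the encoder in between changes."""
    tokens = tf.keras.Input(shape=(MAX_LEN,), dtype='int32')
    x = tf.keras.layers.Embedding(input_dim=VOCAB_SIZE, output_dim=128)(tokens)
    if encoder == 'gru':
        # Bidirectional GRU over the character sequence -> (batch, 256)
        x = tf.keras.layers.Bidirectional(tf.keras.layers.GRU(128))(x)
    elif encoder == 'cnn':
        # Simple 1-D CNN stand-in (NOT DPCNN): convolution + global max pooling
        x = tf.keras.layers.Conv1D(128, 3, padding='same', activation='relu')(x)
        x = tf.keras.layers.GlobalMaxPooling1D()(x)
    else:
        raise ValueError(f'unknown encoder: {encoder}')
    outputs = tf.keras.layers.Dense(2, activation='softmax')(x)
    model = tf.keras.Model(tokens, outputs)
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    return model


if __name__ == '__main__':
    dummy = np.random.randint(0, VOCAB_SIZE, size=(4, MAX_LEN))
    for name in ('gru', 'cnn'):
        print(name, build_classifier(name).predict(dummy, verbose=0).shape)
```

Because every variant shares the same input and output signature, the `model.fit(...)` call does not change when the encoder does.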

1. Code

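```python
import numpy as np
import tensorflow as tf

MAX_LEN = 80  # every review is truncated or padded to 80 characters

labels = []
vocab = set()
context = []

# Each line of ChnSentiCorp.txt is "label,text"; split only on the first comma
# so that commas inside the review text are kept.
with open('ChnSentiCorp.txt', mode='r', encoding='UTF-8') as emotion_file:
    for line in emotion_file:
        line = line.strip()
        if not line:
            continue
        label, text = line.split(',', 1)
        labels.append(int(label))
        context.append(text)
        vocab.update(text)  # collect every distinct character

# Character-level vocabulary. Id 0 is reserved for padding, so real
# characters start at 1 and never collide with padded positions.
vocab_list = sorted(vocab)
char_to_id = {char: i + 1 for i, char in enumerate(vocab_list)}

token_list = []
for text in context:
    token = [char_to_id[char] for char in text]
    token = token[:MAX_LEN] + [0] * (MAX_LEN - len(token))
    token_list.append(token)
token_list = np.array(token_list, dtype=np.int32)
labels = np.array(labels)

# ResNet50 without its classification head, trained from scratch (weights=None).
resnet_layer = tf.keras.applications.ResNet50(include_top=False, weights=None)

input_token = tf.keras.Input(shape=(MAX_LEN,), dtype='int32')
# input_dim must cover every token id: the vocabulary plus padding id 0.
embedding = tf.keras.layers.Embedding(input_dim=len(vocab_list) + 1, output_dim=128)(input_token)
# Alternative: a bidirectional GRU encoder. Using it means dropping the
# tile / ResNet50 / Flatten lines below.
# embedding = tf.keras.layers.Bidirectional(tf.keras.layers.GRU(128))(embedding)

# Treat the (80, 128) embedding matrix as an image: add a channel axis and copy
# it 3 times, since this ResNet50 instance expects 3-channel input -> (80, 128, 3).
embedding = tf.tile(tf.expand_dims(embedding, axis=-1), [1, 1, 1, 3])
embedding = resnet_layer(embedding)
embedding = tf.keras.layers.Flatten()(embedding)

output = tf.keras.layers.Dense(2, activation=tf.nn.softmax)(embedding)
model = tf.keras.Model(input_token, output)

model.compile(optimizer='adam', loss=tf.keras.losses.sparse_categorical_crossentropy, metrics=['accuracy'])
model.fit(token_list, labels, epochs=10, verbose=2)

# Side check, unrelated to the classifier: a 2x2 convolution with stride 2 and
# padding='SAME' maps a 5x5 input to ceil(5 / 2) = 3, so this prints (1, 3, 3, 1).
x = tf.random.normal([1, 5, 5, 1])
conv = tf.keras.layers.Conv2D(1, 2, strides=[2, 2], padding='SAME')(x)
print(conv.shape)
```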
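The preprocessing step (a sorted character vocabulary, then truncating or padding each review to 80 ids) can be isolated into two small functions that are easy to test on their own. This is a sketch; `build_vocab`, `encode`, and `PAD_ID` are illustrative names, not part of the original script. Reserving id 0 for padding means a padded position never shares an id with a real character, and the embedding table then needs `len(char_to_id) + 1` rows.

```python
from typing import Dict, List

PAD_ID = 0  # id reserved for padding; real characters start at 1


def build_vocab(texts: List[str]) -> Dict[str, int]:
    """Map every distinct character to an id >= 1, in sorted order."""
    chars = sorted(set(''.join(texts)))
    return {char: i + 1 for i, char in enumerate(chars)}


def encode(text: str, char_to_id: Dict[str, int], max_len: int = 80) -> List[int]:
    """Character ids truncated to max_len, then right-padded with PAD_ID."""
    ids = [char_to_id[char] for char in text][:max_len]
    return ids + [PAD_ID] * (max_len - len(ids))


texts = ['很好', '不好看']
char_to_id = build_vocab(texts)
print(char_to_id)
print(encode('不好', char_to_id, max_len=5))
```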

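The `tf.tile(tf.expand_dims(...))` line is what lets an image network consume text: each review's 80 × 128 embedding matrix becomes a single-channel image, which is then copied into three identical channels because the ResNet50 instance built without an `input_shape` expects 3-channel input. NumPy's `expand_dims` and `tile` behave the same way on arrays, so the shapes can be checked without building the model; the batch size of 2 below is arbitrary.

```python
import numpy as np

batch, seq_len, emb_dim = 2, 80, 128
embedded = np.random.rand(batch, seq_len, emb_dim).astype(np.float32)

# Same operation as in the script: add a trailing channel axis, then repeat it 3 times.
image = np.tile(np.expand_dims(embedded, axis=-1), (1, 1, 1, 3))

print(image.shape)  # (2, 80, 128, 3)
# The three "colour channels" are identical copies of the embedding matrix.
print(np.array_equal(image[..., 0], image[..., 2]))  # True
```

Since the channels carry no extra information, the tiling only satisfies the backbone's input shape; it does not add features.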
2. Output

```text
C:\Users\ccc\AppData\Local\Programs\Python\Python38\python.exe D:/tmp/tele_churn/tf_test.py
2022-07-01 17:46:55.021709: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX AVX2
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
Epoch 1/10
243/243 - 423s - loss: 2.2180 - accuracy: 0.6021 - 423s/epoch - 2s/step
Epoch 2/10
243/243 - 421s - loss: 0.9366 - accuracy: 0.6394 - 421s/epoch - 2s/step
Epoch 3/10
243/243 - 425s - loss: 0.5442 - accuracy: 0.7800 - 425s/epoch - 2s/step
Epoch 4/10
243/243 - 426s - loss: 0.3573 - accuracy: 0.8612 - 426s/epoch - 2s/step
Epoch 5/10
243/243 - 436s - loss: 0.2643 - accuracy: 0.9077 - 436s/epoch - 2s/step
Epoch 6/10
243/243 - 469s - loss: 0.1155 - accuracy: 0.9556 - 469s/epoch - 2s/step
Epoch 7/10
243/243 - 459s - loss: 0.1040 - accuracy: 0.9648 - 459s/epoch - 2s/step
Epoch 8/10
243/243 - 478s - loss: 0.1056 - accuracy: 0.9642 - 478s/epoch - 2s/step
Epoch 9/10
243/243 - 513s - loss: 0.0929 - accuracy: 0.9695 - 513s/epoch - 2s/step
Epoch 10/10
243/243 - 8216s - loss: 0.0798 - accuracy: 0.9789 - 8216s/epoch - 34s/step
(1, 3, 3, 1)

Process finished with exit code 0
```
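The last result line, `(1, 3, 3, 1)`, comes from the side check at the end of the script: with `padding='SAME'`, the output size is ceil(input / stride), so a 5 × 5 input at stride 2 becomes 3 × 3. The same kind of arithmetic shows what ResNet50 produces for an 80 × 128 embedding "image". The sketch below is plain Python; the layer sequence in the hypothetical `resnet50_feature_size` helper is based on the `tf.keras.applications` ResNet50 layout and should be read as an assumption, not as Keras's own API.

```python
import math


def same_out(n: int, stride: int) -> int:
    """Output length of a padding='same' conv/pool: ceil(n / stride)."""
    return math.ceil(n / stride)


def valid_out(n: int, kernel: int, stride: int) -> int:
    """Output length of a padding='valid' conv/pool: floor((n - kernel) / stride) + 1."""
    return (n - kernel) // stride + 1


def resnet50_feature_size(n: int) -> int:
    """Trace one spatial dimension through an assumed Keras ResNet50 layout:
    3-pixel zero pad + 7x7/2 conv, 1-pixel pad + 3x3/2 max pool,
    then a stride-2 1x1 conv at the start of stages 3, 4 and 5."""
    n = valid_out(n + 6, 7, 2)  # conv1: pad 3 on each side, 7x7 stride 2
    n = valid_out(n + 2, 3, 2)  # pool1: pad 1 on each side, 3x3 max pool stride 2
    for _ in range(3):          # stages 3-5 each start with a stride-2 1x1 conv
        n = valid_out(n, 1, 2)
    return n


print(same_out(5, 2))                                          # side check in the script
print(resnet50_feature_size(80), resnet50_feature_size(128))   # feature map for 80 x 128 input
```

Under that assumption the backbone outputs a 3 × 4 × 2048 feature map, so `Flatten` hands 3 × 4 × 2048 = 24,576 features to the final `Dense(2)` layer.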
