TensorFlow 编程

2020-03-22  本文已影响0人  啊啊啊啊啊1231

加载(load)Inception-v3 模型,再把图片输入模型。下面第一段代码只构建到 Mixed_7a 层并从 checkpoint 恢复权重。

# Code modified from https://github.com/tensorflow/cleverhans/blob/master/examples/nips17_adversarial_competition/sample_defenses/base_inception_model/defense.py

import tensorflow.contrib.slim.nets as nets

#from tensorflow.contrib.slim.nets.inception import inception_v3, inception_v3_arg_scope

import numpy as np

import tensorflow as tf

slim = tf.contrib.slim

import pdb

# Input geometry for Inception-v3: 299x299 images with 3 channels.
height = 299

width = 299

channels = 3

# Not used below; presumably the class count of the ImageNet checkpoint.
# Kept so any code referring to it keeps working.
num_classes = 1001

# Path of the Inception-v3 checkpoint restored below.
checkpoint_path = "./inception_v3.ckpt"

# Batch of input images. Values fed here must already be preprocessed
# (NOTE(review): presumably Inception-style scaling to [-1, 1] — confirm).
X = tf.placeholder(tf.float32, shape=[None, height, width, channels])

# Label placeholder, presumably for a 182-way downstream task; unused here.
y = tf.placeholder(tf.float32, shape=[None, 182])

with slim.arg_scope(nets.inception.inception_v3_arg_scope()):
    # inception_v3_base, unlike inception_v3(), never sets `is_training`, and
    # slim.batch_norm defaults to is_training=True. Without this scope batch
    # norm would normalize with the statistics of each fed batch instead of
    # the restored moving averages, so the extracted features would depend
    # on what else is in the batch. This mirrors the scope that
    # inception_v3(is_training=False) applies internally.
    with slim.arg_scope([slim.batch_norm, slim.dropout], is_training=False):
        # Build the network only up to Mixed_7a. The returned tensor is that
        # feature map, not classification logits; the name `logits` is kept
        # for compatibility with code that refers to it.
        logits, end_points = nets.inception.inception_v3_base(
            X, final_endpoint="Mixed_7a")

    # All variables in the graph at this point, i.e. the truncated
    # InceptionV3 trunk built above.
    variables_to_restore = slim.get_variables_to_restore()

with tf.Session() as sess:
    # Initialize everything first so that any variable missing from the
    # restore list still has a value; restore() then overwrites the trunk.
    sess.run(tf.global_variables_initializer())

    saver = tf.train.Saver(variables_to_restore)

    saver.restore(sess, checkpoint_path)

    print("Done")

Extract image features using Inception-ResNet-v2

import tensorflow as tf

slim = tf.contrib.slim

from inception_resnet_v2 import *

import pdb

from PIL import Image

import numpy as np

# Input geometry for Inception-ResNet-v2: 299x299 images with 3 channels.
height = 299

width = 299

channels = 3

num_classes = 1001

input_tensor = tf.placeholder(tf.float32, shape=[None, height, width, channels])

checkpoint_file = 'inception_resnet_v2_2016_08_30.ckpt'

sample_images = ['16548706_m.jpg']

# Alternative sample set: ['dog.jpg', 'panda.jpg']

# Load the model.
sess = tf.Session()

arg_scope = inception_resnet_v2_arg_scope()

with slim.arg_scope(arg_scope):
    # Build the graph in inference mode (is_training=False).
    logits, end_points = inception_resnet_v2(input_tensor, is_training=False)

saver = tf.train.Saver()

saver.restore(sess, checkpoint_file)

# `sess` is left open on purpose so the graph can still be queried after this
# section (the original script also never closed it). Call sess.close() when
# you are finished.


def _load_image(path):
    """Load an image file as a float32 batch of shape (1, height, width, channels).

    The image is converted to RGB, because grayscale, RGBA and palette files
    would otherwise fail the reshape to 3 channels. It is then resized and
    rescaled from [0, 255] to [-1, 1], the input range used by TF-slim's
    Inception preprocessing, instead of feeding raw 0-255 pixel values.
    """
    with Image.open(path) as img:
        # PIL sizes are (width, height).
        rgb = img.convert('RGB').resize((width, height))
        pixels = np.asarray(rgb, dtype=np.float32)
    pixels = pixels / 127.5 - 1.0
    return pixels.reshape(-1, height, width, channels)


# Mixed_7a feature maps keyed by image path. `predict_values` by itself
# would only hold the result for the last image.
features = {}

for image in sample_images:
    im = _load_image(image)
    # To get class predictions instead of features, fetch
    # end_points['Predictions'] in place of end_points['Mixed_7a'].
    predict_values, logit_values = sess.run(
        [end_points['Mixed_7a'], logits], feed_dict={input_tensor: im})
    features[image] = predict_values

PyTorch 与 TensorFlow 1.x 在调试上的一个主要差别:

PyTorch 是动态图(eager)执行,可以直接用 pdb 断点查看张量/变量的值。

而 TensorFlow 1.x 使用静态计算图:tf.Tensor 只是图中的符号节点,必须在会话中执行 sess.run(tensor) 才能得到其中的值(TensorFlow 2.x 默认开启 eager execution,可以直接查看)。

推荐系统(recommendation)框架中常用的几个 TensorFlow 函数:

# Reshape user embeddings so they broadcast against neighbor_relations
# (NOTE(review): assumes neighbor_relations is [batch, ?, ?, dim] — confirm).
user_embeddings = tf.reshape(user_embeddings, [self.batch_size, 1, 1, self.dim])

# Score each relation as the mean over the last axis of the element-wise product.
user_relation_scores = tf.reduce_mean(user_embeddings * neighbor_relations, axis=-1)

# Normalize scores over the last axis. `axis` replaces the deprecated `dim`
# argument of tf.nn.softmax.
user_relation_scores_normalized = tf.nn.softmax(user_relation_scores, axis=-1)

# Append a trailing length-1 axis so the weights broadcast in later products.
user_relation_scores_normalized = tf.expand_dims(user_relation_scores_normalized, axis=-1)

# Trainable bias vector of length self.dim, initialized to zeros.
self.bias = tf.get_variable(shape=[self.dim], initializer=tf.zeros_initializer(), name='bias')

# Trainable [dim, dim] weight matrix with Xavier (Glorot) initialization.
self.weights = tf.get_variable(shape=[self.dim, self.dim], initializer=tf.contrib.layers.xavier_initializer(), name='weights')

# Dropout: keep_prob is the probability of keeping a unit, so self.dropout
# is treated as the drop rate.
output = tf.nn.dropout(output, keep_prob=1-self.dropout)

# Affine transform: output @ weights + bias.
output = tf.matmul(output, self.weights) + self.bias

# Concatenate self vectors and aggregated neighbor vectors along the last axis.
output = tf.concat([self_vectors, neighbors_agg], axis=-1)

output = tf.nn.dropout(output, keep_prob=1-self.dropout)

# Gather rows of the context embedding table for the given entity ids.
embedded_contexts = tf.nn.embedding_lookup(self.context_embeddings, entities)

# Trainable entity embedding table initialized from the array `entity_embs`
# (presumably pre-trained embeddings — confirm with the caller).
self.entity_embeddings = tf.Variable(entity_embs, dtype=np.float32, name='entity')

# Project entity embeddings to args.entity_dim with a tanh dense layer whose
# kernel is L2-regularized with weight args.l2_weight.
self.entity_embeddings = tf.layers.dense( self.entity_embeddings, units=args.entity_dim, activation=tf.nn.tanh, name='transformed_entity', kernel_regularizer=tf.contrib.layers.l2_regularizer(args.l2_weight))

# Float label placeholder: a 1-D tensor with unspecified batch size.
self.labels = tf.placeholder( dtype=tf.float32, shape=[None], name='labels')

上一篇 下一篇

猜你喜欢

热点阅读