TensorFlowOnSpark 接口函数用法
2019-01-07 本文已影响26人
PPT讲解: https://www.matroid.com/scaledml/2017/andy.pdf
样例代码: https://github.com/yahoo/TensorFlowOnSpark/blob/master/examples/mnist/streaming/mnist_spark.py
# Copyright 2017 Yahoo Inc.
# Licensed under the terms of the Apache 2.0 license.
# Please see LICENSE file in the project root for terms.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from pyspark.context import SparkContext
from pyspark.conf import SparkConf
from pyspark.streaming import StreamingContext
import argparse
import numpy
from datetime import datetime
from tensorflowonspark import TFCluster
import mnist_dist
# Driver-side setup: Spark / Spark Streaming contexts and command-line arguments.
sc = SparkContext(conf=SparkConf().setAppName("mnist_streaming"))
# Batch interval of the streaming context, in seconds.
ssc = StreamingContext(sc, 60)

# Default the TF cluster size to the configured number of Spark executors,
# falling back to 1 when spark.executor.instances is not set.
# getConf() is the public accessor; sc._conf is a private attribute.
executors = sc.getConf().get("spark.executor.instances")
num_executors = int(executors) if executors is not None else 1
num_ps = 1  # number of TensorFlow parameter-server nodes passed to TFCluster.run


def _str2bool(value):
    """Convert a command-line string such as "true"/"False"/"1"/"0" to a bool.

    argparse has no boolean type: without a converter, ``--rdma False`` would
    produce the non-empty (and therefore truthy) string "False".

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean.
    """
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got {0!r}".format(value))


parser = argparse.ArgumentParser()
parser.add_argument("--batch_size", help="number of records per batch", type=int, default=100)
parser.add_argument("--epochs", help="number of epochs", type=int, default=1)
# NOTE(review): the default "stream" is not in `choices`; argparse does not
# validate defaults, so "stream" is only reachable by omitting --format.
parser.add_argument("--format", help="example format: (csv|csv2|pickle|tfr)", choices=["csv", "csv2", "pickle", "tfr"], default="stream")
parser.add_argument("--images", help="HDFS path to MNIST images in parallelized format")
parser.add_argument("--model", help="HDFS path to save/load model during train/inference", default="mnist_model")
parser.add_argument("--cluster_size", help="number of nodes in the cluster", type=int, default=num_executors)
parser.add_argument("--output", help="HDFS path to save test/inference output", default="predictions")
parser.add_argument("--steps", help="maximum number of steps", type=int, default=1000)
parser.add_argument("--tensorboard", help="launch tensorboard process", action="store_true")
parser.add_argument("--mode", help="train|inference", default="train")
parser.add_argument("--rdma", help="use rdma connection", type=_str2bool, default=False)
args = parser.parse_args()
print("args:", args)
print("{0} ===== Start".format(datetime.now().isoformat()))
def parse(ln, num_classes=10):
    """Parse one ``"label|p0,p1,..."`` text record into ``(image, one_hot_label)``.

    Args:
        ln: a single input line: an integer class label, a ``|`` separator,
            then comma-separated integer pixel values.
        num_classes: length of the one-hot label vector (10 for MNIST digits).

    Returns:
        A tuple ``(image, label)`` where ``image`` is a list of ints and
        ``label`` is a float ``numpy`` array of length ``num_classes`` holding
        1.0 at the record's label index and 0.0 elsewhere.

    Raises:
        ValueError: if the line does not contain exactly one ``|`` or a field
            is not an integer.
        IndexError: if the label is outside ``[0, num_classes)``.
    """
    lbl, img = ln.split('|')
    image = [int(x) for x in img.split(',')]
    index = int(lbl)
    if not 0 <= index < num_classes:
        # numpy would silently wrap a negative index; reject it explicitly.
        raise IndexError("label {0} out of range [0, {1})".format(index, num_classes))
    label = numpy.zeros(num_classes)
    label[index] = 1.0
    return (image, label)
# Input pipeline: a DStream over text files appearing under args.images,
# with each line parsed into an (image, one-hot label) pair.
stream = ssc.textFileStream(args.images)
imageRDD = stream.map(parse)

# Launch the TensorFlow cluster on the Spark executors; data is pushed to the
# TF nodes from Spark (InputMode.SPARK).
cluster = TFCluster.run(sc, mnist_dist.map_fun, args, args.cluster_size, num_ps, args.tensorboard, TFCluster.InputMode.SPARK)
if args.mode == "train":
    cluster.train(imageRDD)
else:
    labelRDD = cluster.inference(imageRDD)
    # Write each micro-batch's inference output under args.output.
    labelRDD.saveAsTextFiles(args.output)

# Output operations must be registered before the streaming context starts;
# without start() no micro-batch is ever processed.
ssc.start()
# Passing ssc lets TFCluster.shutdown wait on the streaming job before
# stopping the TF nodes (see TFCluster.shutdown docs).
cluster.shutdown(ssc)

print("{0} ===== Stop".format(datetime.now().isoformat()))

# Condensed view of the driver: start the TF cluster on the Spark executors,
# then either train on, or run inference over, the streamed images.
from tensorflowonspark import TFCluster
import mnist_dist  # local module providing map_fun, the function each TF node runs

sc = ...    # placeholder: the driver's SparkContext
args = ...  # placeholder: the parsed command-line arguments
# num_ps and imageRDD come from the driver setup and input pipeline (elided here).
cluster = TFCluster.run(sc, mnist_dist.map_fun, args, args.cluster_size, num_ps, args.tensorboard, TFCluster.InputMode.SPARK)
if args.mode == "train":
    cluster.train(imageRDD)
else:
    labelRDD = cluster.inference(imageRDD)
    labelRDD.saveAsTextFiles(args.output)
# Setup lines elided in this excerpt (shown in the full driver script).
# They were marked with "//", which is not a Python comment and is a SyntaxError.
# sc = SparkContext(conf=SparkConf().setAppName("mnist_streaming"))
# ssc = StreamingContext(sc, 60)
# executors = sc._conf.get("spark.executor.instances")
# num_executors = int(executors) if executors is not None else 1
# num_ps = 1


def _str2bool(value):
    """Convert a command-line string such as "true"/"False"/"1"/"0" to a bool.

    argparse has no boolean type: without a converter, ``--rdma False`` would
    produce the non-empty (and therefore truthy) string "False".

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean.
    """
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got {0!r}".format(value))


parser = argparse.ArgumentParser()
parser.add_argument("--batch_size", help="number of records per batch", type=int, default=100)
parser.add_argument("--epochs", help="number of epochs", type=int, default=1)
# NOTE(review): the default "stream" is not in `choices`; argparse does not
# validate defaults, so "stream" is only reachable by omitting --format.
parser.add_argument("--format", help="example format: (csv|csv2|pickle|tfr)", choices=["csv", "csv2", "pickle", "tfr"], default="stream")
parser.add_argument("--images", help="HDFS path to MNIST images in parallelized format")
parser.add_argument("--model", help="HDFS path to save/load model during train/inference", default="mnist_model")
# NOTE: num_executors is defined by the elided setup lines above; restore them
# before running this excerpt on its own.
parser.add_argument("--cluster_size", help="number of nodes in the cluster", type=int, default=num_executors)
parser.add_argument("--output", help="HDFS path to save test/inference output", default="predictions")
parser.add_argument("--steps", help="maximum number of steps", type=int, default=1000)
parser.add_argument("--tensorboard", help="launch tensorboard process", action="store_true")
parser.add_argument("--mode", help="train|inference", default="train")
parser.add_argument("--rdma", help="use rdma connection", type=_str2bool, default=False)
args = parser.parse_args()
def parse(ln, num_classes=10):
    """Parse one ``"label|p0,p1,..."`` text record into ``(image, one_hot_label)``.

    Args:
        ln: a single input line: an integer class label, a ``|`` separator,
            then comma-separated integer pixel values.
        num_classes: length of the one-hot label vector (10 for MNIST digits).

    Returns:
        A tuple ``(image, label)`` where ``image`` is a list of ints and
        ``label`` is a float ``numpy`` array of length ``num_classes`` holding
        1.0 at the record's label index and 0.0 elsewhere.

    Raises:
        ValueError: if the line does not contain exactly one ``|`` or a field
            is not an integer.
        IndexError: if the label is outside ``[0, num_classes)``.
    """
    lbl, img = ln.split('|')
    image = [int(x) for x in img.split(',')]
    index = int(lbl)
    if not 0 <= index < num_classes:
        # numpy would silently wrap a negative index; reject it explicitly.
        raise IndexError("label {0} out of range [0, {1})".format(index, num_classes))
    label = numpy.zeros(num_classes)
    label[index] = 1.0
    return (image, label)
# Input pipeline: a DStream over text files appearing under args.images,
# with each line parsed into an (image, one-hot label) pair.
stream = ssc.textFileStream(args.images)
imageRDD = stream.map(parse)
# print("{0} ===== Stop".format(datetime.now().isoformat()))
def map_fun(args, ctx):
from tensorflowonspark import TFNode
from datetime import datetime
import math
import numpy
import tensorflow as tf
import time
worker_num = ctx.worker_num
job_name = ctx.job_name
task_index = ctx.task_index
# Delay PS nodes a bit, since workers seem to reserve GPUs more quickly/reliably (w/o conflict)
if job_name == "ps":
time.sleep((worker_num + 1) * 5)
# Parameters
hidden_units = 128
batch_size = args.batch_size
# Get TF cluster and server instances
cluster, server = TFNode.start_cluster_server(ctx, 1, args.rdma)
// ...
if job_name == "ps":
elif job_name == "worker":
# Assigns ops to the local worker by default.
with tf.device(tf.train.replica_device_setter(
worker_device="/job:worker/task:%d" % task_index,
# Variables of the hidden layer
hid_w = tf.Variable(tf.truncated_normal([IMAGE_PIXELS * IMAGE_PIXELS, hidden_units],
stddev=1.0 / IMAGE_PIXELS), name="hid_w")
hid_b = tf.Variable(tf.zeros([hidden_units]), name="hid_b")
tf.summary.histogram("hidden_weights", hid_w)
# Variables of the softmax layer
sm_w = tf.Variable(tf.truncated_normal([hidden_units, 10],
stddev=1.0 / math.sqrt(hidden_units)), name="sm_w")
sm_b = tf.Variable(tf.zeros([10]), name="sm_b")
tf.summary.histogram("softmax_weights", sm_w)
# Placeholders or QueueRunner/Readers for input data
x = tf.placeholder(tf.float32, [None, IMAGE_PIXELS * IMAGE_PIXELS], name="x")
y_ = tf.placeholder(tf.float32, [None, 10], name="y_")
x_img = tf.reshape(x, [-1, IMAGE_PIXELS, IMAGE_PIXELS, 1])
tf.summary.image("x_img", x_img)
hid_lin = tf.nn.xw_plus_b(x, hid_w, hid_b)
hid = tf.nn.relu(hid_lin)
y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b))
global_step = tf.Variable(0)
loss = -tf.reduce_sum(y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))
tf.summary.scalar("loss", loss)
train_op = tf.train.AdagradOptimizer(0.01).minimize(
loss, global_step=global_step)
# Test trained model
label = tf.argmax(y_, 1, name="label")
prediction = tf.argmax(y, 1, name="prediction")
correct_prediction = tf.equal(prediction, label)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name="accuracy")
tf.summary.scalar("acc", accuracy)
saver = tf.train.Saver()
summary_op = tf.summary.merge_all()
init_op = tf.global_variables_initializer()
# Create a "supervisor", which oversees the training process and stores model state into HDFS
logdir = TFNode.hdfs_path(ctx, args.model)
print("tensorflow model path: {0}".format(logdir))
summary_writer = tf.summary.FileWriter("tensorboard_%d" % worker_num, graph=tf.get_default_graph())
if args.mode == "train":
sv = tf.train.Supervisor(is_chief=(task_index == 0),
sv = tf.train.Supervisor(is_chief=(task_index == 0),
# The supervisor takes care of session initialization, restoring from
# a checkpoint, and closing when done or an error occurs.
with sv.managed_session(server.target) as sess:
print("{0} session ready".format(datetime.now().isoformat()))
# Loop until the supervisor shuts down or 1000000 steps have completed.
step = 0
tf_feed = TFNode.DataFeed(ctx.mgr, args.mode == "train")
while not sv.should_stop() and not tf_feed.should_stop() and step < args.steps:
# Run a training step asynchronously.
# See `tf.train.SyncReplicasOptimizer` for additional details on how to
# perform *synchronous* training.
# using feed_dict
batch_xs, batch_ys = feed_dict(tf_feed.next_batch(batch_size))
feed = {x: batch_xs, y_: batch_ys}
if len(batch_xs) > 0:
if args.mode == "train":
_, summary, step = sess.run([train_op, summary_op, global_step], feed_dict=feed)
# print accuracy and save model checkpoint to HDFS every 100 steps
if (step % 100 == 0):
print("{0} step: {1} accuracy: {2}".format(datetime.now().isoformat(), step, sess.run(accuracy, {x: batch_xs, y_: batch_ys})))
if sv.is_chief:
summary_writer.add_summary(summary, step)
else: # args.mode == "inference"
labels, preds, acc = sess.run([label, prediction, accuracy], feed_dict=feed)
results = ["{0} Label: {1}, Prediction: {2}".format(datetime.now().isoformat(), l, p) for l, p in zip(labels, preds)]
print("acc: {0}".format(acc))
if sv.should_stop() or step >= args.steps:
# Ask for all the services to stop.
print("{0} stopping supervisor".format(datetime.now().isoformat()))