TensorFlow2.0 使用tf.function装饰器将动

2021-01-24  本文已影响0人  又双叒叕苟了一天

这里只是一些简单的介绍

tf.function简单使用

假设我们要把一个模型的前向传播转化成静态图:

import tensorflow as tf
import tensorflow.keras as keras
import tensorflow.keras.layers as layers

class model(keras.Model):
    def __init__(self):
        pass

    @tf.function
    def call(x):
        pass

这个装饰器对任何只包含tensor操作的函数都有效.

动态图和静态图转换时需要注意的区别

for, while

在eager执行模式下, 可以使用普通的python语法对流程进行控制, 但是在tf.function装饰的函数中, 要对上面的2种方式进行转化.

for

def funa():
    for i in range(10):
        pass

@tf.function
def funb():
    for i in tf.range(10):
        pass

while

def funa():
    i = 0
    while i < 10:
        i += 1
        pass

@tf.function
def funb():
    i = tf.constant(0)
    while i < tf.constant(10):
        i += 1
        pass

使用1.x的tf.cond, tf.while_loop的方式进行控制应该也是可以的.

print

在使用tf.function装饰的函数中print只会在最初执行1次, tf.Variable()也是. 如果要每次都执行需要使用tf.print

def funa():
    i = 0
    while i < 10:
        print(i)
        i += 1

@tf.function
def funb():
    i = tf.constant(0)
    while i < tf.constant(10):
        tf.print(i)
        i += 1

TensorArry

如果要使用类似python中类似list的数据结构, 可以使用tf.TensorArray

def funa():
    i = 0
    res = []
    while i < 10:
        print(i)
        res.append(i)
        i += 1

@tf.function
def funb():
    i = tf.constant(0)
    res = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size)
    while i < tf.constant(10):
        tf.print(i)
        res = res.write(i, i)  # 注意这个`=`, 如果只写res.write(i, i)会出错
        i += 1

input_signature

@tf.function是支持多态的, 假设有以下函数

@tf.function
def fun(x, y, training=True):
    return x + y

x=tf.constant(0)y=tf.constant(1), x=tf.constant(0.0)y=tf.constant(1.0)的情况下是会产生两个不同的静态图的, 甚至x=tf.constant(0)y=tf.constant(1), x=tf.constant(1)y=tf.constant(1) 都是两个不同的静态图, 因为他们的数据类型不同, 或者数值不同都会造成静态图不同, 这时候静态图可能比eager执行方式更加费时, 因为需要retracing是哪一张静态图. 所以在使用@tf.function时最好指定输入数据的类型和shape, 类似于tensorflow1.x中tf.placehold的效果:

@tf.function(input_signature=(tf.TensorSpec(shape=None, dtype=tf.int32), tf.TensorSpec(shape=None, dtype=tf.int32))
def fun(x, y, training=True):
    return x + y

此时输入x=tf.constant(0)y=tf.constant(1), x=tf.constant(1)y=tf.constant(1)都会调用同一张静态图. 另外, 传入的每一个python类型也都会构造一个图, 所以最好把training=True改为training=tf.constant(True).

shape

和tensorflow1.x中tf.shape于get_shape()/shape的区别类似, 在tf.function装饰的函数中, 需要使用tf.shape()获取tensor的shape, 而不能使用get_shape()或者shape. 否则会产生NoneType错误.

上一篇下一篇

猜你喜欢

热点阅读