Pytorch 和Tensorflow之间的相互转化

2019-07-14  本文已影响0人  顾北向南

转载于:https://mp.weixin.qq.com/s/LzbJ_QuRqWJOj_ZRZ5P66A

1. 介绍

2.神奇的转换库 TfPyTh

3.TfPyTh 示例

import tensorflow as tf
import torch as th
import numpy as np
import tfpyth

session = tf.Session()
def get_torch_function():
    a = tf.placeholder(tf.float32, name='a')
    b = tf.placeholder(tf.float32, name='b')
    c = 3 * a + 4 * b * b

    f = tfpyth.torch_from_tensorflow(session, [a, b], c).apply
    return f

f = get_torch_function()
a = th.tensor(1, dtype=th.float32, requires_grad=True)
b = th.tensor(3, dtype=th.float32, requires_grad=True)
x = f(a, b)
assert x == 39.

x.backward()
assert np.allclose((a.grad, b.grad), (3., 24.))
上一篇 下一篇

猜你喜欢

热点阅读