2020-05-01 TensorFlow Example,SequenceExample

2020-05-01  本文已影响0人  破阵子沙场秋点兵

版本信息
python version=3.6.10
tensorflow version=1.15.0

import tensorflow as tf
import numpy as np
# Turn on eager execution (TF 1.x API) before any other TensorFlow call.
tf.enable_eager_execution()

数据编码

基本数据类型

Feature与FeatureList

# Base data used throughout the examples.
var_int = [1, 2, 3, 4, 5, 1, 2, 3, 4]  # integers
# np.float64 is the dtype the removed alias np.float resolved to.
var_float = np.random.randn(2, 3).astype(np.float64)  # floats, shape (2, 3)
var_str = list(set('12'))  # strings; order follows set iteration order
var_len_int = [[1], [2] * 2]  # variable-length Feature data
var_seq_int = [[1, 2, 3], [4, 5]]  # sequences of differing length

# Each *List message receives a flat list through `value`.
tf.train.Int64List(value=var_int)
# `value` only accepts a flat list, so a single row of the floats is used.
tf.train.FloatList(value=var_float[0])
# The list length must agree with the shape given to FixedLenFeature in the
# parsing code.
tf.train.BytesList(value=[bytes(s, encoding='utf8') for s in var_str])
tf.train.Int64List(value=var_len_int[0])
执行结果

Feature与FeatureList

非序列数据使用Feature,序列数据使用FeatureList

# Wrap each typed list in a tf.train.Feature and print the resulting message.
int_feature = tf.train.Feature(
    int64_list=tf.train.Int64List(value=var_int))
tf.print(int_feature)

float_feature = tf.train.Feature(
    float_list=tf.train.FloatList(value=var_float[0]))
tf.print(float_feature)

str_feature = tf.train.Feature(
    bytes_list=tf.train.BytesList(value=[s.encode('utf8') for s in var_str]))
tf.print(str_feature)

len_feature = tf.train.Feature(
    int64_list=tf.train.Int64List(value=var_len_int[0]))
tf.print(len_feature)

# A FeatureList is assembled from a list of Feature messages (one per
# sequence element), which differs from how Features is assembled.
seq_steps = [
    tf.train.Feature(int64_list=tf.train.Int64List(value=[step]))
    for step in var_seq_int[0]
]
seq_feature = tf.train.FeatureList(feature=seq_steps)
tf.print(seq_feature)
执行结果1.png
执行结果2.png

Features与FeatureLists

多个Feature组成Features

# Name each non-sequence Feature; tf.train.Features takes the mapping
# through its `feature` argument.
n_seq_feature = dict(
    int_feature=int_feature,
    float_feature=float_feature,
    str_feature=str_feature,
    len_feature=len_feature,
)
context_features = tf.train.Features(feature=n_seq_feature)
执行结果.png

多个FeatureList组成FeatureLists

# Name each FeatureList; tf.train.FeatureLists takes the mapping through
# its `feature_list` argument.
sequence_feature = dict(seq_feature=seq_feature)
sequence_features = tf.train.FeatureLists(feature_list=sequence_feature)
执行结果.png

Example与SequenceExample

# Build an Example from the Features message.
example = tf.train.Example(features=context_features)

执行结果

features {
  feature {
    key: "float_feature"
    value {
      float_list {
        value: -0.34236639738082886
        value: 0.028749553486704826
        value: 0.5070503354072571
      }
    }
  }
  feature {
    key: "int_feature"
    value {
      int64_list {
        value: 1
        value: 2
        value: 3
        value: 4
        value: 5
        value: 1
        value: 2
        value: 3
        value: 4
      }
    }
  }
  feature {
    key: "len_feature"
    value {
      int64_list {
        value: 1
      }
    }
  }
  feature {
    key: "str_feature"
    value {
      bytes_list {
        value: "2"
        value: "1"
      }
    }
  }
}
# Build a SequenceExample: the Features message goes in as `context`,
# the FeatureLists message as `feature_lists`.
sequence_example = tf.train.SequenceExample(context=context_features, feature_lists=sequence_features)

执行结果

context {
    feature {
        key: "float_feature"
        value {
            float_list {
                value: -0.34236639738082886
                value: 0.028749553486704826
                value: 0.5070503354072571
            }
        }
    }
    feature {
        key: "int_feature"
        value {
            int64_list {
                value: 1
                value: 2
                value: 3
                value: 4
                value: 5
                value: 1
                value: 2
                value: 3
                value: 4
            }
        }
    }
    feature {
        key: "len_feature"
        value {
            int64_list {
                value: 1
            }
        }
    }
    feature {
        key: "str_feature"
        value {
            bytes_list {
                value: "2"
                value: "1"
            }
        }
    }
}
feature_lists {
    feature_list {
        key: "seq_feature"
        value {
            feature {
                int64_list {
                    value: 1
                }
            }
            feature {
                int64_list {
                    value: 2
                }
            }
            feature {
                int64_list {
                    value: 3
                }
            }
        }
    }
}

序列化

# Serialize both messages with SerializeToString().
serialized_example = example.SerializeToString()
serialized_sequence_example = sequence_example.SerializeToString()
序列化后的结果.png

数据解码

parse_single_example与parse_single_sequence_example

parse_single_example只能解析Example
parse_single_sequence_example既可以解析Example也可以解析SequenceExample

# Parsing spec, one entry per key written into the Example. The fixed
# shapes match the number of values stored: 9 ints, 3 floats, 2 strings.
context_features = dict(
    int_feature=tf.io.FixedLenFeature(shape=[9], dtype=tf.int64),
    float_feature=tf.io.FixedLenFeature(shape=[3], dtype=tf.float32),
    str_feature=tf.io.FixedLenFeature(shape=[2], dtype=tf.string),
    len_feature=tf.io.VarLenFeature(dtype=tf.int64),
)
tf.io.parse_single_example(
    serialized=serialized_example,
    features=context_features,
)
执行结果.png
# Parsing spec for the sequence part: every element of `seq_feature`
# holds a single int64 value.
sequence_features = dict(
    seq_feature=tf.io.FixedLenSequenceFeature(shape=[1], dtype=tf.int64),
)
tf.io.parse_single_sequence_example(
    serialized=serialized_sequence_example,
    context_features=context_features,
    sequence_features=sequence_features,
)
执行结果.png

parse_example 与parse_sequence_example

与parse_single_example和parse_single_sequence_example相比,parse_example和parse_sequence_example可以处理batch数据,详细见下节

整理代码

为了讲解方便,完整代码与上述代码稍有出入,但不影响整体理解

import tensorflow as tf
import numpy as np

# Turn on eager execution (TF 1.x API) before any other TensorFlow call.
tf.enable_eager_execution()

var_int = [1, 2, 3, 4, 5, 1, 2, 3, 4]  # integers
# np.float64 is the dtype the removed alias np.float resolved to.
var_float = np.random.randn(2, 3).astype(np.float64)  # floats, one row per example
var_str = list(set('12'))  # strings, one per example
var_len_int = [[1], [2] * 2]  # variable-length Feature data
var_seq_int = [[1, 2, 3], [4, 5]]  # sequences of differing length

serialized_examples = []
serialized_sequence_examples = []
# Build and serialize one Example and one SequenceExample per row of the
# per-example data.
for float_data, str_data, len_data, seq_data in zip(
        var_float, var_str, var_len_int, var_seq_int):
    int_data = var_int  # the same integer list is reused for every example
    int_feature = tf.train.Feature(int64_list=tf.train.Int64List(value=int_data))
    float_feature = tf.train.Feature(float_list=tf.train.FloatList(value=float_data))
    # Encode the whole string as ONE bytes value so the feature always has
    # length 1, matching FixedLenFeature([1]) in the parsing spec below.
    str_feature = tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[str_data.encode('utf8')]))
    len_feature = tf.train.Feature(int64_list=tf.train.Int64List(value=len_data))
    # One single-value Feature per element of the sequence.
    seq_feature = tf.train.FeatureList(
        feature=[tf.train.Feature(int64_list=tf.train.Int64List(value=[v])) for v in seq_data])
    n_seq_features = {
        "int_feature": int_feature,
        "float_feature": float_feature,
        "str_feature": str_feature,
        "len_feature": len_feature}
    context_features = tf.train.Features(feature=n_seq_features)
    sequence_feature = {
        "seq_feature": seq_feature
    }
    sequence_features = tf.train.FeatureLists(feature_list=sequence_feature)
    example = tf.train.Example(features=context_features)
    sequence_example = tf.train.SequenceExample(context=context_features, feature_lists=sequence_features)
    serialized_example = example.SerializeToString()
    serialized_sequence_example = sequence_example.SerializeToString()
    serialized_examples.append(serialized_example)
    serialized_sequence_examples.append(serialized_sequence_example)

# From here on `context_features` / `sequence_features` are rebound to the
# parsing specs (they held the protobuf messages inside the loop).
context_features = {
    'int_feature': tf.io.FixedLenFeature([9], dtype=tf.int64),
    'float_feature': tf.io.FixedLenFeature([3], dtype=tf.float32),
    'str_feature': tf.io.FixedLenFeature([1], dtype=tf.string),
    'len_feature': tf.io.VarLenFeature(dtype=tf.int64)
}

# parse_example takes the whole list of serialized Examples at once.
context_parsed = tf.io.parse_example(serialized=serialized_examples, features=context_features)
tf.print('parse_example解析之后的结果')
tf.print('context_parsed', context_parsed, summarize=-1)

sequence_features = {
    'seq_feature': tf.io.FixedLenSequenceFeature([1], dtype=tf.int64)
}
# parse_sequence_example returns three values here; the third is not used.
context_parsed, sequence_parsed, _ = tf.io.parse_sequence_example(serialized=serialized_sequence_examples,
                                                                  context_features=context_features,
                                                                  sequence_features=sequence_features)
tf.print('parse_sequence_example解析之后的结果')

tf.print('context_parsed', context_parsed, summarize=-1)
tf.print('sequence_parsed', sequence_parsed, summarize=-1)

结果

parse_example解析之后的结果
context_parsed {'float_feature': [[0.911513448 -1.24205828 -0.316245258]
 [-0.255276382 0.210701108 1.45683241]],
 'int_feature': [[1 2 3 4 5 1 2 3 4]
 [1 2 3 4 5 1 2 3 4]],
 'len_feature': 'SparseTensor(indices=[[0 0]
 [1 0]
 [1 1]], values=[1 2 2], shape=[2 2])',
 'str_feature': [["1"]
 ["2"]]}
parse_sequence_example解析之后的结果
context_parsed {'float_feature': [[0.911513448 -1.24205828 -0.316245258]
 [-0.255276382 0.210701108 1.45683241]],
 'int_feature': [[1 2 3 4 5 1 2 3 4]
 [1 2 3 4 5 1 2 3 4]],
 'len_feature': 'SparseTensor(indices=[[0 0]
 [1 0]
 [1 1]], values=[1 2 2], shape=[2 2])',
 'str_feature': [["1"]
 ["2"]]}
sequence_parsed {'seq_feature': [[[1]
  [2]
  [3]]

 [[4]
  [5]
  [0]]]}

注意点

上一篇 下一篇

猜你喜欢

热点阅读