Spark向量化读取Parquet文件源码

2021-08-17  本文已影响0人  阿猫阿狗Hakuna

原文:https://animeshtrivedi.github.io/spark-parquet-reading

Spark 如何读取Parquet文件

Apache Parquet 是一种流行的列式存储格式,它把数据存储为一堆文件。
Spark读取parquet依赖以下API:

val parquetFileDF = spark.read.parquet("test.parquet")

test.parquet文件格式为<int, Array[Byte]>。

关键对象

在 Spark SQL 中,各种操作都在各自的类中实现,其名称都以Exec作为后缀。

1.DataSourceScanExec类掌管的是对数据源的读取。读取Parquet文件的相关代码从这里开始,在ParquetFileFormat类中结束。

2.ParquetFileFormat中有一个buildReader函数,返回一个(PartitionedFile => Iterator[InternalRow])。此函数中生成了一个迭代器:

val iter = new RecordReaderIterator(parquetReader)

这里parquetReader是一个VectorizedParquetRecordReader。RecordReaderIterator包装了一个scala迭代器,以Hadoop RecordReader<K,V>风格。它由 VectorizedParquetRecordReader(及其基类 SpecificParquetRecordReaderBase<Object>)实现。

  1. VectorizedParquetRecordReader做了什么?根据文件中的comment:一个专门的RecordReader,直接使用Parquet column API 读入InternalRows或ColumnarBatches,基于parquet-mr的ColumnReader。VectorizedParquetRecordReader 对象分配后,调用initialize(split, hadoopAttemptContext)函数和initBatch(partitionSchema, file.partitionValues)函数。

4.VectorizedParquetRecordReader 中 RecordReader 接口的实现需要多加关注,它在使用步骤 2 中的迭代器时调用的是什么?在调用nextKeyValue()时,该函数首先调用了resultBatch(),然后调用nextBatch()。请记住,我们总是在Batch Mode下操作(returnColumnarBatch 设置为 true),nextBatch用数据填充columnarBatch,且这个变量会在getCurrentValue函数中返回。getCurrentKey 在 SpecificParquetRecordReaderBase 的基类中实现,且始终返回null。

现在,我们知道了迭代器中返回了什么变量。从这开始有两个方向,首先我们描述ColumnarBatch是怎么被Parquet数据填充。然后我们描述谁使用了步骤2中生成的iter迭代器。

ColumnarBatch 如何被填充?

在 VectorizedParquetRecordReader.nextBatch() 函数中,如果尚未读取所有行,则调用 checkEndOfRowGroup() 函数。然后,checkEndOfRowGroup 函数读取一个rowGroup(可以将rowGroup视为以列格式存储的一定数量行的集合),然后为requestedSchema 中的每个请求列分配一个VectorizedColumnReader 对象。VectorizedColumnReader 构造函数接受一个 ColumnDescriptor(可以在schema中找到)和一个 PageReader(可以从 rowGroup 中找到,一个 Parquet API 调用)。
另外,missingColumns是确实列的一个bitmap(可能是缺失的列或 Spark 不打算读取的列)。然后,在nextBatch中调用readBatch(num, columnarBatch.column(i)),会在之前checkEndOfRowGroup(基本上是每列)函数分配的所有VectorizedColumnReader对象上调用。(因此,ColumnarBatch 和 ColumnVector 只是 VectorizedColumnReader 使用的原始内存)。所以在 readBatch 中,传递了行数和 ColumnVector(存储在 ColumnarBatch 中)。什么是ColumnVector?我们可以将其视为一个类型数组,由 rowId 索引。

/**
 * An interface representing in-memory columnar data in Spark. This interface defines the main APIs
 * to access the data, as well as their batched versions. The batched versions are considered to be
 * faster and preferable whenever possible.
 *
 * Most of the APIs take the rowId as a parameter. This is the batch local 0-based row id for values
 * in this ColumnVector.
 *
 * Spark only calls specific `get` method according to the data type of this {@link ColumnVector},
 * e.g. if it's int type, Spark is guaranteed to only call {@link #getInt(int)} or
 * {@link #getInts(int, int)}.
 *
 * ColumnVector supports all the data types including nested types. To handle nested types,
 * ColumnVector can have children and is a tree structure. Please refer to {@link #getStruct(int)},
 * {@link #getArray(int)} and {@link #getMap(int)} for the details about how to implement nested
 * types.
 *
 * ColumnVector is expected to be reused during the entire data loading process, to avoid allocating
 * memory again and again.
 *
 * ColumnVector is meant to maximize CPU efficiency but not to minimize storage footprint.
 * Implementations should prefer computing efficiency over storage efficiency when design the
 * format. Since it is expected to reuse the ColumnVector instance while loading data, the storage
 * footprint is negligible.
 */
@Evolving
public abstract class ColumnVector implements AutoCloseable {

总之,原始数据存储在 ColumnVector 中,ColumnVector 本身存储在 ColumnBatch 对象中。ColumnVector 是在 readBatch 函数中作为存储空间传递的。 在 readBatch 函数内部,它首先调用 readPage() 函数,该函数查看我们正在读取哪个版本的 parquet 文件(v1 或 v2,我不知道区别),然后初始化一堆对象,即 defColumn: VectorizedRleValuesReader、replicationLevelColumn:ValuesReaderIntIterator、definitionLevelColumn:ValuesReaderIntIterator 和 dataColumn:VectorizedRleValuesReader。这些变量中的 ValuesReaderIntIterator 来自 parquet-mr,而 VectorizedRleValuesReader 来自 Spark。接下来,有一堆 read[Type]Batch() 函数被调用,这些函数又调用 defColumn.read[Type]s() 函数。 (这里的 [Type] 是一些类型,如 Int、Short、Binary 等)。 在 VectorizedRleValuesReader 上的这些函数中,数据被读取、解码(可能来自 RLE),然后插入到此处传递的 ColumnVector 中。

Scala[ColumnBatch] 迭代器在哪里被消费?

迭代器根据 reader 是否处于批处理模式返回两种不同的类型,code如下:

  @Override
  public Object getCurrentValue() {
    if (returnColumnarBatch) return columnarBatch;
    return columnarBatch.getRow(batchIdx - 1);
  }

其中,columnarBatch的类型是ColumnarBatch,columnarBatch.getRow 返回一个 ColumnarBatch.Row 类型的嵌套类。这个迭代器以某种方式传递给wholestage code generation。消费这个迭代器并且实例化UnsafeRow的code示例如下:

/* 001 */ public Object generate(Object[] references) {
/* 002 */   return new GeneratedIterator(references);
/* 003 */ }
/* 004 */
/* 005 */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator {
/* 006 */   private Object[] references;
/* 007 */   private scala.collection.Iterator[] inputs;
/* 008 */   private scala.collection.Iterator scan_input;
/* 009 */   private org.apache.spark.sql.execution.metric.SQLMetric scan_numOutputRows;
/* 010 */   private org.apache.spark.sql.execution.metric.SQLMetric scan_scanTime;
/* 011 */   private long scan_scanTime1;
/* 012 */   private org.apache.spark.sql.execution.vectorized.ColumnarBatch scan_batch;
/* 013 */   private int scan_batchIdx;
/* 014 */   private org.apache.spark.sql.execution.vectorized.ColumnVector scan_colInstance0;
/* 015 */   private org.apache.spark.sql.execution.vectorized.ColumnVector scan_colInstance1;
/* 016 */   private UnsafeRow scan_result;
/* 017 */   private org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder scan_holder;
/* 018 */   private org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter scan_rowWriter;
/* 019 */
/* 020 */   public GeneratedIterator(Object[] references) {
/* 021 */     this.references = references;
/* 022 */   }
/* 023 */
/* 024 */   public void init(int index, scala.collection.Iterator[] inputs) {
/* 025 */     partitionIndex = index;
/* 026 */     this.inputs = inputs;
/* 027 */     scan_input = inputs[0];
/* 028 */     this.scan_numOutputRows = (org.apache.spark.sql.execution.metric.SQLMetric) references[0];
/* 029 */     this.scan_scanTime = (org.apache.spark.sql.execution.metric.SQLMetric) references[1];
/* 030 */     scan_scanTime1 = 0;
/* 031 */     scan_batch = null;
/* 032 */     scan_batchIdx = 0;
/* 033 */     scan_colInstance0 = null;
/* 034 */     scan_colInstance1 = null;
/* 035 */     scan_result = new UnsafeRow(2);
/* 036 */     this.scan_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(scan_result, 32);
/* 037 */     this.scan_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(scan_holder, 2);
/* 038 */
/* 039 */   }
/* 040 */
/* 041 */   private void scan_nextBatch() throws java.io.IOException {
/* 042 */     long getBatchStart = System.nanoTime();
/* 043 */     if (scan_input.hasNext()) {
/* 044 */       scan_batch = (org.apache.spark.sql.execution.vectorized.ColumnarBatch)scan_input.next();
/* 045 */       scan_numOutputRows.add(scan_batch.numRows());
/* 046 */       scan_batchIdx = 0;
/* 047 */       scan_colInstance0 = scan_batch.column(0);
/* 048 */       scan_colInstance1 = scan_batch.column(1);
/* 049 */
/* 050 */     }
/* 051 */     scan_scanTime1 += System.nanoTime() - getBatchStart;
/* 052 */   }
/* 053 */
/* 054 */   protected void processNext() throws java.io.IOException {
/* 055 */     if (scan_batch == null) {
/* 056 */       scan_nextBatch();
/* 057 */     }
/* 058 */     while (scan_batch != null) {
/* 059 */       int numRows = scan_batch.numRows();
/* 060 */       while (scan_batchIdx < numRows) {
/* 061 */         int scan_rowIdx = scan_batchIdx++;
/* 062 */         boolean scan_isNull = scan_colInstance0.isNullAt(scan_rowIdx);
/* 063 */         int scan_value = scan_isNull ? -1 : (scan_colInstance0.getInt(scan_rowIdx));
/* 064 */         boolean scan_isNull1 = scan_colInstance1.isNullAt(scan_rowIdx);
/* 065 */         byte[] scan_value1 = scan_isNull1 ? null : (scan_colInstance1.getBinary(scan_rowIdx));
/* 066 */         scan_holder.reset();
/* 067 */
/* 068 */         scan_rowWriter.zeroOutNullBytes();
/* 069 */
/* 070 */         if (scan_isNull) {
/* 071 */           scan_rowWriter.setNullAt(0);
/* 072 */         } else {
/* 073 */           scan_rowWriter.write(0, scan_value);
/* 074 */         }
/* 075 */
/* 076 */         if (scan_isNull1) {
/* 077 */           scan_rowWriter.setNullAt(1);
/* 078 */         } else {
/* 079 */           scan_rowWriter.write(1, scan_value1);
/* 080 */         }
/* 081 */         scan_result.setTotalSize(scan_holder.totalSize());
/* 082 */         append(scan_result);
/* 083 */         if (shouldStop()) return;
/* 084 */       }
/* 085 */       scan_batch = null;
/* 086 */       scan_nextBatch();
/* 087 */     }
/* 088 */     scan_scanTime.add(scan_scanTime1 / (1000 * 1000));
/* 089 */     scan_scanTime1 = 0;
/* 090 */   }
/* 091 */ }

在scan_nextBatch方法中,我们通过调用next()读取一个新的ColumnarBatch。然后我们获取ColumnVectors对象(变量 scan_colInstance0/scan_colInstance1)。通过numRows()方法,我们可以得到ColumnarBatch的行数,通过调用ColumnVector对象的get[Type](rowId: Int)获取最终的值。
这些值在BufferHolder和UnsafeRowWriter对象的帮助下表示为UnsafeRow:

/* 035 */     scan_result = new UnsafeRow(2);
/* 036 */     this.scan_holder = new org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder(scan_result, 32);
/* 037 */     this.scan_rowWriter = new org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter(scan_holder, 2);
上一篇 下一篇

猜你喜欢

热点阅读