EntityFrameworkCore批量插入(PgSQL篇)

2021-11-09  本文已影响0人  野生DBNull

写在前面

批量操作中,除了批量新增之外,其余操作(如批量删除、批量更新)Z.EntityFramework.Plus 都已实现,只有批量新增是收费的,所以这里只介绍如何实现批量新增。

思路详解

一般的数据库驱动都自带批量新增功能,但 EntityFramework 支持多种数据库,要实现某种数据库的批量插入,只需找到对应数据库驱动的文档,大概就能知道用哪个函数来实现。EF Core 的 DbContext 可以拿到 DbConnection,拿到这个对象就意味着可以直接用 ADO.NET 来做这件事,而且连接的创建和释放仍由 EF 管理(连接池则由数据库驱动负责)。需要注意的是,EF 通常只在执行命令时才打开连接,直接使用这个连接之前要先确认它已经打开(例如调用 Database.OpenConnectionAsync)。

PgSQL批量新增

PgSQL 的 .NET 驱动一般是 Npgsql,在它的文档里就能找到对应的批量操作函数(基于 COPY 的 BeginBinaryImport),照着文档就能把这件事搞定。

Npgsql官网

image.png
核心代码分析

思路是先根据 EF 提供的 EntityType 构建列映射,再把数据逐条塞进 DataTable 中,
最后使用 BeginBinaryImport 方法(即 PostgreSQL 的 COPY ... FROM STDIN BINARY)一次性写入。

 // Core of the bulk insert (excerpt). NOTE(review): dbContext, tableName, fields (column names),
 // dataTable (one row per entity) and cancellationToken are assumed to be prepared beforehand,
 // as in the full example further below.
 var pgConnection = dbContext.Database.GetDbConnection() as NpgsqlConnection;
 // Build the COPY statement: quoted table name, comma-separated column list, binary format.
 // NOTE(review): the table name is not schema-qualified, so it resolves via the search_path.
 var commandFormat = string.Format("COPY \"{0}\"({1}) FROM STDIN BINARY", tableName, string.Join(",", fields));

 // BeginBinaryImport is the Npgsql function this approach relies on; it starts the COPY.
 // NOTE(review): Npgsql needs the connection to be open here; EF does not keep it open by default.
 using (var writer = pgConnection.BeginBinaryImport(commandFormat))
 {
     foreach (DataRow item in dataTable.Rows)
     {
         await writer.WriteRowAsync(cancellationToken, item.ItemArray); // write one row's values asynchronously
     }
     await writer.CompleteAsync(cancellationToken); // finish the import after all rows are written (per Npgsql docs, disposing without it cancels the COPY)
 }

  
示例代码(鄙人使用的是 ABP 框架,所以这里以 ABP 框架作为示例)
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.Extensions.Logging;
using Npgsql;
using System;
using System.Collections;
using System.Collections.Generic;
using System.Data;
using System.Linq;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using Volo.Abp.Data;
using Volo.Abp.Domain.Entities;
using Volo.Abp.Guids;
using Volo.Abp.ObjectExtending;
using Volo.Abp.DependencyInjection;
using Volo.Abp.EntityFrameworkCore;
using Volo.Abp.Uow;

namespace EntityFrameworkCore.BulkOperationProvider
{
    /// <summary>
    /// PostgreSQL implementation of <see cref="IEfCoreBulkOperationProvider"/>: inserts entities with
    /// Npgsql's binary import (<c>COPY ... FROM STDIN BINARY</c>) instead of per-row INSERT statements.
    /// </summary>
    [ExposeServices(typeof(IEfCoreBulkOperationProvider))]
    public class PgsqlEfCoreBulkOperationProvider : IEfCoreBulkOperationProvider
    {
        protected ILogger<PgsqlEfCoreBulkOperationProvider> Logger { get; private set; }
        protected IGuidGenerator GuidGenerator { get; private set; }

        /// <summary>
        /// Maximum number of rows sent per COPY command. Each batch is its own COPY statement, so
        /// without a surrounding transaction, batches already written are not undone if a later one fails.
        /// Values below 1 are treated as 1.
        /// </summary>
        protected virtual int BatchSize => 10000;

        /// <param name="logger">Logger used to report failed imports.</param>
        /// <param name="guidGenerator">Source of ids for entities whose Guid id is still default.</param>
        public PgsqlEfCoreBulkOperationProvider(ILogger<PgsqlEfCoreBulkOperationProvider> logger, IGuidGenerator guidGenerator)
        {
            Logger = logger;
            GuidGenerator = guidGenerator;
        }

        /// <summary>
        /// Inserts <paramref name="entities"/> into the table mapped for <typeparamref name="TEntity"/>
        /// using binary COPY, in batches of <see cref="BatchSize"/> rows.
        /// </summary>
        /// <param name="dbContext">Context supplying the table/column mapping; its connection must be an <see cref="NpgsqlConnection"/>.</param>
        /// <param name="entities">Entities to insert. The sequence is enumerated at most once; Guid ids that are still default are assigned before writing.</param>
        /// <param name="cancellationToken">Cancels opening the connection and writing rows.</param>
        /// <returns>The number of entities written.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="entities"/> is null.</exception>
        /// <exception cref="InvalidOperationException">The context's connection is not an Npgsql connection.</exception>
        /// <exception cref="NotSupportedException">A mapped property has no CLR property (shadow or field-only).</exception>
        [UnitOfWork(IsDisabled = true)]
        public async Task<int> BulkInsertAsync<TDbContext, TEntity>(TDbContext dbContext, IEnumerable<TEntity> entities, CancellationToken cancellationToken = default)
            where TEntity : class, IEntity
            where TDbContext : IEfCoreDbContext
        {
            if (entities == null)
                throw new ArgumentNullException(nameof(entities));

            // Materialize once so a deferred query is not re-executed (which would re-create the
            // entities and drop the Guid ids assigned below) and batches can be sliced by index.
            var entityList = entities as IList<TEntity> ?? entities.ToList();
            if (entityList.Count == 0) return 0;

            var entityType = dbContext.Set<TEntity>().EntityType;
            var entityProps = entityType.GetProperties().ToList();

            var tableName = entityType.GetTableName();
            var schema = entityType.GetSchema();
            var storeObjectIdentifier = StoreObjectIdentifier.Table(tableName, schema);

            var pgConnection = dbContext.Database.GetDbConnection() as NpgsqlConnection;
            if (pgConnection == null)
                throw new InvalidOperationException("DbConnection is not assignable to [NpgsqlConnection]");

            var batchSize = Math.Max(1, BatchSize);
            var openedHere = false;
            try
            {
                var needHandleExtraProps = typeof(TEntity).IsAssignableTo<IHasExtraProperties>();
                var dataTable = new DataTable();
                var fields = new List<string>(entityProps.Count);

                // Build the COPY column list and one typed DataTable column per mapped property.
                foreach (var property in entityProps)
                {
                    if (property.PropertyInfo == null)
                        throw new NotSupportedException(
                            $"Property '{property.Name}' of entity type '{entityType.Name}' has no CLR property (shadow or field-only) and cannot be bulk inserted.");

                    var colName = property.GetColumnName(storeObjectIdentifier);

                    // ExtraProperties is written as its serialized JSON string.
                    var propertyType = needHandleExtraProps && property.Name == nameof(IHasExtraProperties.ExtraProperties)
                        ? typeof(string)
                        : property.PropertyInfo.PropertyType;

                    fields.Add(QuoteIdentifier(colName));
                    // NOTE(review): the typed column also coerces values on insert (e.g. an enum column
                    // stores the underlying integral value), which determines what Npgsql writes.
                    dataTable.Columns.Add(new DataColumn(colName, Nullable.GetUnderlyingType(propertyType) ?? propertyType));
                }

                // Schema-qualify the target when the model maps one; otherwise rely on the search_path.
                var qualifiedTable = string.IsNullOrEmpty(schema)
                    ? QuoteIdentifier(tableName)
                    : QuoteIdentifier(schema) + "." + QuoteIdentifier(tableName);
                var commandText = $"COPY {qualifiedTable}({string.Join(",", fields)}) FROM STDIN BINARY";

                // Binary import requires an open connection; EF usually opens it only per command.
                if (pgConnection.State != ConnectionState.Open)
                {
                    await dbContext.Database.OpenConnectionAsync(cancellationToken);
                    openedHere = true;
                }

                for (var start = 0; start < entityList.Count;)
                {
                    // The table is reused across batches, so the previous batch's rows must be removed.
                    dataTable.Clear();

                    var count = Math.Min(batchSize, entityList.Count - start);
                    for (var i = start; i < start + count; i++)
                    {
                        var entity = entityList[i];
                        CheckAndSetId(entity); // assign a Guid id if it is still default

                        var values = new object[entityProps.Count];
                        for (var c = 0; c < entityProps.Count; c++)
                        {
                            var property = entityProps[c];
                            values[c] = needHandleExtraProps && property.Name == nameof(IHasExtraProperties.ExtraProperties)
                                ? SerializeExtraObject(((IHasExtraProperties)entity).ExtraProperties, typeof(TEntity))
                                : property.PropertyInfo.GetValue(entity, null);
                        }
                        dataTable.Rows.Add(values);
                    }

                    using (var writer = pgConnection.BeginBinaryImport(commandText))
                    {
                        foreach (DataRow row in dataTable.Rows)
                        {
                            await writer.WriteRowAsync(cancellationToken, row.ItemArray);
                        }
                        // Commits this batch; disposing the importer without Complete cancels the COPY.
                        await writer.CompleteAsync(cancellationToken);
                    }

                    start += count;
                }
            }
            catch (Exception ex)
            {
                Logger.LogError(ex, $"PG批量插入出错,Error->{ex.Message}");
                throw; // rethrow without resetting the stack trace
            }
            finally
            {
                // Only close a connection this method opened; leave caller-managed connections alone.
                if (openedHere)
                    await dbContext.Database.CloseConnectionAsync();
            }

            return entityList.Count;
        }

        /// <summary>
        /// Assigns an id to <paramref name="entity"/> when it is an <see cref="IEntity{Guid}"/>; other entities are left untouched.
        /// </summary>
        protected virtual void CheckAndSetId<TEntity>(TEntity entity)
        {
            if (entity is IEntity<Guid> entityWithGuidId)
            {
                TrySetGuidId(entityWithGuidId);
            }
        }

        /// <summary>
        /// Sets a new id from <see cref="GuidGenerator"/> via <c>EntityHelper.TrySetId</c> unless the entity already has a non-default id.
        /// </summary>
        protected virtual void TrySetGuidId(IEntity<Guid> entity)
        {
            if (entity.Id != default)
            {
                return;
            }

            EntityHelper.TrySetId(
                entity,
                () => GuidGenerator.Create(),
                true
            );
        }

        /// <summary>
        /// Serializes <paramref name="extraProperties"/> to JSON, omitting properties that the object-extension
        /// configuration for <paramref name="entityType"/> maps to their own EF Core fields.
        /// </summary>
        /// <param name="extraProperties">Extra properties to serialize; the dictionary itself is not modified.</param>
        /// <param name="entityType">Entity type whose object extension is consulted; may be null to skip filtering.</param>
        /// <returns>The JSON text written to the ExtraProperties column.</returns>
        protected virtual string SerializeExtraObject(ExtraPropertyDictionary extraProperties, Type entityType)
        {
            var copyDictionary = new Dictionary<string, object>(extraProperties);

            if (entityType != null)
            {
                var objectExtension = ObjectExtensionManager.Instance.GetOrNull(entityType);
                if (objectExtension != null)
                {
                    foreach (var property in objectExtension.GetProperties())
                    {
                        if (property.IsMappedToFieldForEfCore())
                        {
                            copyDictionary.Remove(property.Name);
                        }
                    }
                }
            }

            return JsonSerializer.Serialize(copyDictionary);
        }

        /// <summary>
        /// Quotes a PostgreSQL identifier, doubling any embedded double quotes.
        /// </summary>
        private static string QuoteIdentifier(string identifier)
            => "\"" + identifier.Replace("\"", "\"\"") + "\"";
    }
}


上一篇 下一篇

猜你喜欢

热点阅读