Spring Boot集成Elasticsearch自定义序列化
2018-10-18 本文已影响0人
阿拉喵_d271
spring-boot-elasticsearch 在对搜索结果反序列化时无法获取 `_score` 值，而业务中又必须用到它，所以就改了一下框架里的序列化代码。
- 继承 `AbstractResultMapper`，增加读取/写入 `_score` 的代码。我是通过在实体里用 `@Score` 注解标记要写入的字段来实现的：
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
/**
 * Marks where a search hit's {@code _score} should be written on an entity.
 *
 * <p>Put it either on the field that holds the score or directly on the setter method. It is
 * retained at runtime because the result mapper discovers it through reflection.
 *
 * @author weizhiwen
 * @date 2018/7/4
 */
@Target(value = {ElementType.FIELD, ElementType.METHOD, ElementType.ANNOTATION_TYPE})
@Retention(value = RetentionPolicy.RUNTIME)
public @interface Score {
}
import com.fasterxml.jackson.core.JsonEncoding;
import com.fasterxml.jackson.core.JsonFactory;
import com.fasterxml.jackson.core.JsonGenerator;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang.StringUtils;
import org.apache.commons.text.WordUtils;
import org.elasticsearch.action.get.GetResponse;
import org.elasticsearch.action.get.MultiGetItemResponse;
import org.elasticsearch.action.get.MultiGetResponse;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHitField;
import org.springframework.data.domain.Pageable;
import org.springframework.data.elasticsearch.ElasticsearchException;
import org.springframework.data.elasticsearch.annotations.Document;
import org.springframework.data.elasticsearch.annotations.ScriptedField;
import org.springframework.data.elasticsearch.core.AbstractResultMapper;
import org.springframework.data.elasticsearch.core.DefaultEntityMapper;
import org.springframework.data.elasticsearch.core.EntityMapper;
import org.springframework.data.elasticsearch.core.aggregation.AggregatedPage;
import org.springframework.data.elasticsearch.core.aggregation.impl.AggregatedPageImpl;
import org.springframework.data.elasticsearch.core.mapping.ElasticsearchPersistentEntity;
import org.springframework.data.elasticsearch.core.mapping.ElasticsearchPersistentProperty;
import org.springframework.data.mapping.context.MappingContext;
import org.springframework.util.Assert;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.nio.charset.Charset;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
/**
 * Search-result mapper that mirrors Spring Data Elasticsearch's default mapping and additionally
 * copies each search hit's {@code _score} onto the entity property marked with {@link Score}.
 *
 * <p>Score injection is applied only to classes annotated with {@link Document}. The score is
 * written through a public {@code set*} method with exactly one {@code Float} or {@code float}
 * parameter. Either the method itself or the field it sets (declared on the class or any
 * superclass) must carry {@link Score}.
 *
 * @author weizhiwen
 * @date 2018/10/18
 */
@Slf4j
public class ZwResultMapper extends AbstractResultMapper {

    /**
     * Score setter methods per entity class.
     *
     * <p>They are resolved once by reflection and then reused. Cached lists are unmodifiable.
     */
    private static final Map<Class<?>, List<Method>> SCORE_SET_METHOD_CACHE = new ConcurrentHashMap<>();

    /** Name prefix that identifies setter methods. */
    public static final String SET = "set";

    /**
     * Used to write id/version back into mapped entities.
     *
     * <p>It is {@code null} when this mapper is constructed without a mapping context; in that
     * case id and version are not populated.
     */
    private MappingContext<? extends ElasticsearchPersistentEntity<?>, ElasticsearchPersistentProperty> mappingContext;

    public ZwResultMapper() {
        super(new DefaultEntityMapper());
    }

    public ZwResultMapper(MappingContext<? extends ElasticsearchPersistentEntity<?>, ElasticsearchPersistentProperty> mappingContext) {
        super(new DefaultEntityMapper());
        this.mappingContext = mappingContext;
    }

    public ZwResultMapper(EntityMapper entityMapper) {
        super(entityMapper);
    }

    public ZwResultMapper(
            MappingContext<? extends ElasticsearchPersistentEntity<?>, ElasticsearchPersistentProperty> mappingContext,
            EntityMapper entityMapper) {
        super(entityMapper);
        this.mappingContext = mappingContext;
    }

    /**
     * Maps every hit of a search response to {@code clazz}.
     *
     * <p>For each hit it fills in the score, id, version and scripted fields.
     *
     * @param response the search response to map
     * @param clazz    target entity type
     * @param pageable paging information to attach to the resulting page
     * @return a page of mapped entities that also carries the response's aggregations and scroll id
     */
    @Override
    public <T> AggregatedPage<T> mapResults(SearchResponse response, Class<T> clazz, Pageable pageable) {
        long totalHits = response.getHits().getTotalHits();
        List<T> results = new ArrayList<>();
        for (SearchHit hit : response.getHits()) {
            if (hit == null) {
                continue;
            }
            // Fall back to stored/doc-value fields when the hit carries no _source.
            T result = StringUtils.isNotBlank(hit.getSourceAsString())
                    ? mapEntity(hit.getSourceAsString(), clazz)
                    : mapEntity(hit.getFields().values(), clazz);
            // Addition over the default mapper: copy _score into the @Score property.
            setEntityScore(result, hit.getScore(), clazz);
            setPersistentEntityId(result, hit.getId(), clazz);
            setPersistentEntityVersion(result, hit.getVersion(), clazz);
            populateScriptFields(result, hit);
            results.add(result);
        }
        // Aggregations and the scroll id must be passed through. Without them, aggregation
        // queries return no aggregation results and scroll-based paging cannot continue.
        return new AggregatedPageImpl<>(results, pageable, totalHits, response.getAggregations(),
                response.getScrollId());
    }

    /**
     * Copies the values of requested script fields into fields annotated with {@link ScriptedField}.
     *
     * <p>Only fields declared directly on the result's class are considered.
     */
    private <T> void populateScriptFields(T result, SearchHit hit) {
        if (hit.getFields() != null && !hit.getFields().isEmpty() && result != null) {
            for (Field field : result.getClass().getDeclaredFields()) {
                ScriptedField scriptedField = field.getAnnotation(ScriptedField.class);
                if (scriptedField != null) {
                    String name = scriptedField.name().isEmpty() ? field.getName() : scriptedField.name();
                    SearchHitField searchHitField = hit.getFields().get(name);
                    if (searchHitField != null) {
                        field.setAccessible(true);
                        try {
                            field.set(result, searchHitField.getValue());
                        } catch (IllegalArgumentException e) {
                            throw new ElasticsearchException("failed to set scripted field: " + name + " with value: "
                                    + searchHitField.getValue(), e);
                        } catch (IllegalAccessException e) {
                            throw new ElasticsearchException("failed to access scripted field: " + name, e);
                        }
                    }
                }
            }
        }
    }

    /** Maps a hit that has no {@code _source} by first turning its fields into JSON. */
    private <T> T mapEntity(Collection<SearchHitField> values, Class<T> clazz) {
        return mapEntity(buildJSONFromFields(values), clazz);
    }

    /**
     * Serializes hit fields into a JSON object.
     *
     * <p>Multi-valued fields become JSON arrays; single-valued fields become plain properties.
     *
     * @return the JSON text, or {@code null} if serialization fails
     */
    private String buildJSONFromFields(Collection<SearchHitField> values) {
        JsonFactory nodeFactory = new JsonFactory();
        ByteArrayOutputStream stream = new ByteArrayOutputStream();
        // try-with-resources guarantees the generator is closed even when a write fails.
        try (JsonGenerator generator = nodeFactory.createGenerator(stream, JsonEncoding.UTF8)) {
            generator.writeStartObject();
            for (SearchHitField value : values) {
                if (value.getValues().size() > 1) {
                    generator.writeArrayFieldStart(value.getName());
                    for (Object val : value.getValues()) {
                        generator.writeObject(val);
                    }
                    generator.writeEndArray();
                } else {
                    generator.writeObjectField(value.getName(), value.getValue());
                }
            }
            generator.writeEndObject();
            generator.flush();
        } catch (IOException e) {
            log.warn("failed to build JSON from search hit fields", e);
            return null;
        }
        return new String(stream.toByteArray(), Charset.forName("UTF-8"));
    }

    /** Maps a single GET response and writes back its id and version. */
    @Override
    public <T> T mapResult(GetResponse response, Class<T> clazz) {
        T result = mapEntity(response.getSourceAsString(), clazz);
        if (result != null) {
            setPersistentEntityId(result, response.getId(), clazz);
            setPersistentEntityVersion(result, response.getVersion(), clazz);
        }
        return result;
    }

    /** Maps every successful, existing document of a multi-GET response. */
    @Override
    public <T> LinkedList<T> mapResults(MultiGetResponse responses, Class<T> clazz) {
        LinkedList<T> list = new LinkedList<>();
        for (MultiGetItemResponse response : responses.getResponses()) {
            if (!response.isFailed() && response.getResponse().isExists()) {
                T result = mapEntity(response.getResponse().getSourceAsString(), clazz);
                setPersistentEntityId(result, response.getResponse().getId(), clazz);
                setPersistentEntityVersion(result, response.getResponse().getVersion(), clazz);
                list.add(result);
            }
        }
        return list;
    }

    /**
     * Writes the document id into the entity's id property.
     *
     * <p>This only happens when a mapping context is available, the entity is non-null, and the
     * id property accepts a {@code String}.
     */
    private <T> void setPersistentEntityId(T result, String id, Class<T> clazz) {
        if (result != null && mappingContext != null && clazz.isAnnotationPresent(Document.class)) {
            ElasticsearchPersistentEntity<?> persistentEntity = mappingContext.getPersistentEntity(clazz);
            ElasticsearchPersistentProperty idProperty = persistentEntity.getIdProperty();
            // Only deal with String because ES generated Ids are strings !
            if (idProperty != null && idProperty.getType().isAssignableFrom(String.class)) {
                persistentEntity.getPropertyAccessor(result).setProperty(idProperty, id);
            }
        }
    }

    /**
     * Writes the hit's score into every {@link Score} setter of a {@link Document} entity.
     *
     * <p>Failures are logged rather than thrown, so a broken setter does not fail the whole search.
     */
    private <T> void setEntityScore(T result, float score, Class<T> clazz) {
        if (result == null || !clazz.isAnnotationPresent(Document.class)) {
            return;
        }
        List<Method> setters = SCORE_SET_METHOD_CACHE.computeIfAbsent(clazz, this::findScoreSetters);
        for (Method setter : setters) {
            try {
                setter.invoke(result, score);
            } catch (Exception e) {
                log.error("{} set score error", clazz.getSimpleName(), e);
            }
        }
    }

    /**
     * Finds the setters that should receive the score.
     *
     * <p>A setter qualifies when it is public, has a single {@code Float}/{@code float} parameter,
     * and is marked with {@link Score}.
     */
    private List<Method> findScoreSetters(Class<?> clazz) {
        List<Method> setters = new ArrayList<>();
        for (Method method : clazz.getMethods()) {
            if (StringUtils.startsWith(method.getName(), SET)
                    && method.getParameterCount() == 1
                    && isFloatType(method.getParameterTypes()[0])
                    && hasAnnotation(clazz, method)) {
                setters.add(method);
            }
        }
        return Collections.unmodifiableList(setters);
    }

    /** Returns whether {@code type} can receive a float score (boxed or primitive). */
    private static boolean isFloatType(Class<?> type) {
        return type == Float.class || type == float.class;
    }

    /**
     * Checks whether a setter is marked with {@link Score}.
     *
     * <p>The annotation counts when it is on the setter itself, or on the field whose name matches
     * the setter name without its {@code set} prefix (case-insensitive). That field may be
     * declared on {@code clazz} or on any of its superclasses.
     *
     * @param clazz  entity class to search
     * @param method candidate setter
     * @return {@code true} if the setter or its backing field carries {@link Score}
     */
    public boolean hasAnnotation(Class<?> clazz, Method method) {
        if (method.getAnnotation(Score.class) != null) {
            return true;
        }
        String findFieldName = StringUtils.removeStart(method.getName(), SET);
        for (Class<?> type = clazz; type != null && type != Object.class; type = type.getSuperclass()) {
            for (Field field : type.getDeclaredFields()) {
                if (field.getName().equalsIgnoreCase(findFieldName) && field.getAnnotation(Score.class) != null) {
                    return true;
                }
            }
        }
        return false;
    }

    /**
     * Writes the document version into the entity's version property.
     *
     * <p>This only happens when a mapping context is available, the entity is non-null, and the
     * version property accepts a {@code Long}.
     */
    private <T> void setPersistentEntityVersion(T result, long version, Class<T> clazz) {
        if (result != null && mappingContext != null && clazz.isAnnotationPresent(Document.class)) {
            ElasticsearchPersistentEntity<?> persistentEntity = mappingContext.getPersistentEntity(clazz);
            ElasticsearchPersistentProperty versionProperty = persistentEntity.getVersionProperty();
            // Only deal with Long because ES versions are longs !
            if (versionProperty != null && versionProperty.getType().isAssignableFrom(Long.class)) {
                // check that a version was actually returned in the response, -1 would indicate that
                // a search didn't request the version ids in the response, which would be an issue
                Assert.isTrue(version != -1, "Version in response is -1");
                persistentEntity.getPropertyAccessor(result).setProperty(versionProperty, version);
            }
        }
    }
}
- 定制 `ElasticsearchTemplate`，把默认使用的 `org.springframework.data.elasticsearch.core.DefaultResultMapper` 改成 `ZwResultMapper`：
import org.elasticsearch.client.Client;
import org.springframework.data.elasticsearch.core.ElasticsearchTemplate;
import org.springframework.data.elasticsearch.core.EntityMapper;
import org.springframework.data.elasticsearch.core.ResultsMapper;
import org.springframework.data.elasticsearch.core.convert.ElasticsearchConverter;
import org.springframework.data.elasticsearch.core.convert.MappingElasticsearchConverter;
import org.springframework.data.elasticsearch.core.mapping.SimpleElasticsearchMappingContext;
/**
 * {@link ElasticsearchTemplate} that maps search results with {@link ZwResultMapper}.
 *
 * <p>As a result, entity properties marked with {@link Score} receive each hit's
 * {@code _score}.
 *
 * <p>Every constructor that does not take an explicit {@link ResultsMapper} installs a
 * {@code ZwResultMapper}. Constructors that do take a {@code ResultsMapper} use it as-is.
 *
 * @author weizhiwen
 * @date 2018/7/4
 */
public class ZwElasticsearchTemplate extends ElasticsearchTemplate {

    public ZwElasticsearchTemplate(Client client) {
        // Delegating to super(client) would let ElasticsearchTemplate install its own
        // DefaultResultMapper and silently skip score mapping. Instead, route through the
        // converter-based constructor of this class.
        this(client, newDefaultConverter());
    }

    public ZwElasticsearchTemplate(Client client, EntityMapper entityMapper) {
        // Same reasoning as above: keep ZwResultMapper in charge of result mapping.
        this(client, newDefaultConverter(), entityMapper);
    }

    public ZwElasticsearchTemplate(Client client, ElasticsearchConverter elasticsearchConverter, EntityMapper entityMapper) {
        this(client, elasticsearchConverter,
                new ZwResultMapper(elasticsearchConverter.getMappingContext(), entityMapper));
    }

    public ZwElasticsearchTemplate(Client client, ResultsMapper resultsMapper) {
        super(client, resultsMapper);
    }

    public ZwElasticsearchTemplate(Client client, ElasticsearchConverter elasticsearchConverter) {
        this(client, elasticsearchConverter, new ZwResultMapper(elasticsearchConverter.getMappingContext()));
    }

    public ZwElasticsearchTemplate(Client client, ElasticsearchConverter elasticsearchConverter, ResultsMapper resultsMapper) {
        super(client, elasticsearchConverter, resultsMapper);
    }

    /**
     * Builds the converter used when the caller supplies none.
     *
     * <p>It is backed by a fresh mapping context, so id/version mapping keeps working.
     */
    private static ElasticsearchConverter newDefaultConverter() {
        return new MappingElasticsearchConverter(new SimpleElasticsearchMappingContext());
    }
}
- 注册 `ZwElasticsearchTemplate`：
/**
 * Exposes a {@link ZwElasticsearchTemplate} as the application's {@code ElasticsearchTemplate}
 * bean.
 *
 * <p>Any exception raised while constructing it is rethrown wrapped in an
 * {@link IllegalStateException}.
 */
@Bean
public ElasticsearchTemplate elasticsearchTemplate(Client client, ElasticsearchConverter converter) {
    ElasticsearchTemplate template;
    try {
        template = new ZwElasticsearchTemplate(client, converter);
    } catch (Exception ex) {
        throw new IllegalStateException(ex);
    }
    return template;
}
使用方法：
实体类加上 `@Document` 注解，接收得分的字段加上 `@Score` 注解，并且该字段类型为 `Float`。
import lombok.Data;
import org.springframework.data.elasticsearch.annotations.Document;
/**
 * Search DTO for the {@code asd} index.
 *
 * <p>Its {@link Score}-annotated {@code score} field receives the hit's {@code _score} when
 * results are mapped by the score-aware result mapper. Accessors are generated by Lombok's
 * {@code @Data}.
 *
 * <p>NOTE(review): {@code @Data} generates equals/hashCode from this class's own fields only
 * (Lombok's default {@code callSuper = false}). Confirm that ignoring {@code BaseAudioDocument}'s
 * fields is intended.
 *
 * @author weizhiwen
 * @date 2018/10/18
 */
@Document(indexName = "asd")
@Data
public class AudioDocumentDTO extends BaseAudioDocument {

    /** Relevance score of the search hit; stays {@code null} until the mapper populates it. */
    @Score
    private Float score;
}