Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package io.hypersistence.utils.hibernate.query;

import io.hypersistence.utils.hibernate.util.ReflectionUtils;
import jakarta.persistence.Parameter;
import jakarta.persistence.Query;
import org.hibernate.query.spi.DomainQueryExecutionContext;
import org.hibernate.query.spi.QueryImplementor;
import org.hibernate.query.spi.QueryInterpretationCache;
import org.hibernate.query.spi.QueryParameterBindings;
import org.hibernate.query.spi.SelectQueryPlan;
import org.hibernate.query.sqm.internal.ConcreteSqmSelectQueryPlan;
import org.hibernate.query.sqm.internal.DomainParameterXref;
Expand All @@ -13,7 +15,14 @@
import org.hibernate.query.sqm.tree.select.SqmSelectStatement;
import org.hibernate.sql.exec.spi.JdbcOperationQuerySelect;

import java.lang.reflect.Field;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Proxy;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.function.Supplier;
import java.util.stream.Collectors;

/**
* The {@link SQLExtractor} allows you to extract the
Expand All @@ -36,42 +45,96 @@ protected SQLExtractor() {
* @return the underlying SQL generated by the provided JPA query
*/
public static String from(Query query) {
if(query instanceof SqmInterpretationsKey.InterpretationsKeySource &&
query instanceof QueryImplementor &&
query instanceof QuerySqmImpl) {
QueryInterpretationCache.Key cacheKey = SqmInterpretationsKey.createInterpretationsKey((SqmInterpretationsKey.InterpretationsKeySource) query);
QuerySqmImpl querySqm = (QuerySqmImpl) query;
Supplier buildSelectQueryPlan = () -> ReflectionUtils.invokeMethod(querySqm, "buildSelectQueryPlan");
SelectQueryPlan plan = cacheKey != null ? ((QueryImplementor) query).getSession().getFactory().getQueryEngine()
.getInterpretationCache()
.resolveSelectQueryPlan(cacheKey, buildSelectQueryPlan) :
(SelectQueryPlan) buildSelectQueryPlan.get();
if(plan instanceof ConcreteSqmSelectQueryPlan) {
ConcreteSqmSelectQueryPlan selectQueryPlan = (ConcreteSqmSelectQueryPlan) plan;
Object cacheableSqmInterpretation = ReflectionUtils.getFieldValueOrNull(selectQueryPlan, "cacheableSqmInterpretation");
if(cacheableSqmInterpretation == null) {
DomainQueryExecutionContext domainQueryExecutionContext = DomainQueryExecutionContext.class.cast(querySqm);
cacheableSqmInterpretation = ReflectionUtils.invokeStaticMethod(
ReflectionUtils.getMethod(
ConcreteSqmSelectQueryPlan.class,
"buildCacheableSqmInterpretation",
SqmSelectStatement.class,
DomainParameterXref.class,
DomainQueryExecutionContext.class
),
ReflectionUtils.getFieldValueOrNull(selectQueryPlan, "sqm"),
ReflectionUtils.getFieldValueOrNull(selectQueryPlan, "domainParameterXref"),
domainQueryExecutionContext
);
}
if (cacheableSqmInterpretation != null) {
JdbcOperationQuerySelect jdbcSelect = ReflectionUtils.getFieldValueOrNull(cacheableSqmInterpretation, "jdbcSelect");
if (jdbcSelect != null) {
return jdbcSelect.getSqlString();
}
return getSqmQueryOptional(query)
.map(SQLExtractor::getSQLFromSqmQuery)
.orElseGet(() -> ReflectionUtils.invokeMethod(query, "getQueryString"));
}

private static String getSQLFromSqmQuery(QuerySqmImpl<?> querySqm) {
QueryInterpretationCache.Key cacheKey = SqmInterpretationsKey.createInterpretationsKey(querySqm);
Supplier<SelectQueryPlan<Object>> buildSelectQueryPlan = () -> ReflectionUtils.invokeMethod(querySqm, "buildSelectQueryPlan");
SelectQueryPlan<Object> plan = cacheKey != null ? ((QueryImplementor<?>) querySqm).getSession().getFactory().getQueryEngine()
.getInterpretationCache()
.resolveSelectQueryPlan(cacheKey, buildSelectQueryPlan) :
buildSelectQueryPlan.get();
if (plan instanceof ConcreteSqmSelectQueryPlan) {
ConcreteSqmSelectQueryPlan<?> selectQueryPlan = (ConcreteSqmSelectQueryPlan<?>) plan;
Object cacheableSqmInterpretation = ReflectionUtils.getFieldValueOrNull(selectQueryPlan, "cacheableSqmInterpretation");
if (cacheableSqmInterpretation == null) {
cacheableSqmInterpretation = ReflectionUtils.invokeStaticMethod(
ReflectionUtils.getMethod(
ConcreteSqmSelectQueryPlan.class,
"buildCacheableSqmInterpretation",
SqmSelectStatement.class,
DomainParameterXref.class,
DomainQueryExecutionContext.class
),
ReflectionUtils.getFieldValueOrNull(selectQueryPlan, "sqm"),
ReflectionUtils.getFieldValueOrNull(selectQueryPlan, "domainParameterXref"),
querySqm
);
}
if (cacheableSqmInterpretation != null) {
JdbcOperationQuerySelect jdbcSelect = ReflectionUtils.getFieldValueOrNull(cacheableSqmInterpretation, "jdbcSelect");
if (jdbcSelect != null) {
return jdbcSelect.getSqlString();
}
}
}
return ReflectionUtils.invokeMethod(query, "getQueryString");
return querySqm.getQueryString();
}

public static List<Object> getSQLParameterValues(Query query) {
return getSqmQueryOptional(query)
.map(SQLExtractor::getParametersFromInternalQuerySqm)
.orElseGet(() -> getSQLParametersFromJPAQuery(query));
}

/**
* Retrieves the parameters from the internal query SQM.
*
* @param querySqm the internal query SQM object
* @return a list of parameter values
*/
private static List<Object> getParametersFromInternalQuerySqm(QuerySqmImpl<?> querySqm) {
List<Object> parameterValues = new ArrayList<>();

QueryParameterBindings parameterBindings = querySqm.getParameterBindings();
parameterBindings.visitBindings((queryParameterImplementor, queryParameterBinding) -> {
Object value = queryParameterBinding.getBindValue();
parameterValues.add(value);
});

return parameterValues;
}

/**
* Get parameters from JPA query without any magic or Hibernate implementation tricks. Order is probably lost in current Hibernate versions.
*
* @param query
* @return
*/
private static List<Object> getSQLParametersFromJPAQuery(Query query) {
return query.getParameters()
.stream()
.map(Parameter::getPosition)
.map(query::getParameter)
.collect(Collectors.toList());
}


/**
* Get the unproxied hibernate query underlying the provided query object.
*
* @param query JPA query
* @return the unproxied Hibernate query, or original query if there is no proxy, or null if it's not an Hibernate query of required type
*/
private static Optional<QuerySqmImpl<?>> getSqmQueryOptional(Query query) {
Query unwrappedQuery = query.unwrap(Query.class);
if (unwrappedQuery instanceof QuerySqmImpl) {
QuerySqmImpl<?> querySqm = (QuerySqmImpl<?>) unwrappedQuery;
return Optional.of(querySqm);
}
return Optional.empty();
}
}
Original file line number Diff line number Diff line change
@@ -1,15 +1,29 @@
package io.hypersistence.utils.hibernate.query;

import io.hypersistence.utils.hibernate.util.AbstractPostgreSQLIntegrationTest;
import jakarta.persistence.*;
import jakarta.persistence.Column;
import jakarta.persistence.Entity;
import jakarta.persistence.EntityManager;
import jakarta.persistence.FetchType;
import jakarta.persistence.Id;
import jakarta.persistence.ManyToOne;
import jakarta.persistence.Query;
import jakarta.persistence.Table;
import jakarta.persistence.Tuple;
import jakarta.persistence.TypedQuery;
import jakarta.persistence.criteria.CriteriaBuilder;
import jakarta.persistence.criteria.CriteriaQuery;
import jakarta.persistence.criteria.Join;
import jakarta.persistence.criteria.Root;
import org.junit.Test;

import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.time.LocalDate;
import java.util.List;

import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;

/**
Expand All @@ -28,15 +42,7 @@ protected Class<?>[] entities() {
@Test
public void testJPQL() {
doInJPA(entityManager -> {
Query jpql = entityManager
.createQuery(
"select " +
" YEAR(p.createdOn) as year, " +
" count(p) as postCount " +
"from " +
" Post p " +
"group by " +
" YEAR(p.createdOn)", Tuple.class);
Query jpql = createTestJPQL(entityManager);

String sql = SQLExtractor.from(jpql);

Expand All @@ -49,28 +55,30 @@ public void testJPQL() {
);
});
}

@Test
public void testCriteriaAPI() {
doInJPA(entityManager -> {
CriteriaBuilder builder = entityManager.getCriteriaBuilder();

CriteriaQuery<PostComment> criteria = builder.createQuery(PostComment.class);
Query criteriaQuery = createTestCriteriaQuery(entityManager);

Root<PostComment> postComment = criteria.from(PostComment.class);
Join<PostComment, Post> post = postComment.join("post");
String sql = SQLExtractor.from(criteriaQuery);

criteria.where(
builder.like(post.get("title"), "%Java%")
);
assertNotNull(sql);

criteria.orderBy(
builder.asc(postComment.get("id"))
LOGGER.info(
"The Criteria API query: [\n{}\n]\ngenerates the following SQL query: [\n{}\n]",
criteriaQuery.unwrap(org.hibernate.query.Query.class).getQueryString(),
sql
);
});
}

Query criteriaQuery = entityManager.createQuery(criteria);
@Test
public void testCriteriaAPIWithProxy() {
doInJPA(entityManager -> {
Query criteriaQuery = createTestCriteriaQuery(entityManager);
Query proxiedQuery = proxy(criteriaQuery);

String sql = SQLExtractor.from(criteriaQuery);
String sql = SQLExtractor.from(proxiedQuery);

assertNotNull(sql);

Expand All @@ -82,6 +90,77 @@ public void testCriteriaAPI() {
});
}

@Test
public void testJPQLGetSQLParameters() {
doInJPA(entityManager -> {
Query jpql = createTestJPQL(entityManager);

List<?> parameters = SQLExtractor.getSQLParameterValues(jpql);

assertFalse(parameters.isEmpty());

LOGGER.info(
"The Criteria API query: [\n{}\n]\nhas following SQL parameters: \n{}\n",
jpql.unwrap(org.hibernate.query.Query.class).getQueryString(),
parameters
);
});
}

@Test
public void testCriteriaGetSQLParameters() {
doInJPA(entityManager -> {
Query criteriaQuery = createTestCriteriaQuery(entityManager);

List<?> parameters = SQLExtractor.getSQLParameterValues(criteriaQuery);

assertFalse(parameters.isEmpty());

LOGGER.info(
"The Criteria API query: [\n{}\n]\nhas following SQL parameters: \n{}\n",
criteriaQuery.unwrap(org.hibernate.query.Query.class).getQueryString(),
parameters
);
});
}

private static Query proxy(Query criteriaQuery) {
return (Query) Proxy.newProxyInstance(Query.class.getClassLoader(), new Class[]{Query.class}, new HibernateLikeInvocationHandler(criteriaQuery));
}

private static Query createTestJPQL(EntityManager entityManager) {
Query jpql = entityManager
.createQuery(
"select " +
" YEAR(p.createdOn) as year, " +
" count(p) as postCount " +
"from Post p " +
"where p.title like :titleTemplate " +
"group by YEAR(p.createdOn) ",
Tuple.class);
jpql.setParameter("titleTemplate", "%Java%");
return jpql;
}

private static Query createTestCriteriaQuery(EntityManager entityManager) {
CriteriaBuilder builder = entityManager.getCriteriaBuilder();

CriteriaQuery<PostComment> criteria = builder.createQuery(PostComment.class);

Root<PostComment> postComment = criteria.from(PostComment.class);
Join<PostComment, Post> post = postComment.join("post");

criteria.where(
builder.like(post.get("title"), "%Java%")
);

criteria.orderBy(
builder.asc(postComment.get("id"))
);

return entityManager.createQuery(criteria);
}

@Entity(name = "Post")
@Table(name = "post")
public static class Post {
Expand Down Expand Up @@ -161,4 +240,17 @@ public PostComment setReview(String review) {
return this;
}
}

private static class HibernateLikeInvocationHandler implements InvocationHandler {
private final Query target; // has to be named "target" because this is how Hibernate implements it, and the extracting code has to be quite invasive to get the query from the Hibernate proxy

public HibernateLikeInvocationHandler(Query query) {
this.target = query;
}

@Override
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
return method.invoke(target, args);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import java.sql.Connection;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.Date;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;
Expand Down Expand Up @@ -103,6 +104,8 @@ public static class Post {
@Type(PostgreSQLEnumType.class)
private PostStatus status;

private Date createdOn;

public Long getId() {
return id;
}
Expand All @@ -126,5 +129,13 @@ public PostStatus getStatus() {
public void setStatus(PostStatus status) {
this.status = status;
}

public Date getCreatedOn() {
return createdOn;
}

public void setCreatedOn(Date createdOn) {
this.createdOn = createdOn;
}
}
}