Add support for JTA transaction scoped entity managers

This commit is contained in:
Stuart Douglas
2018-09-19 12:58:59 +10:00
parent 8996c58d94
commit e7b792346b
13 changed files with 591 additions and 34 deletions

View File

@@ -4,7 +4,7 @@
http://xmlns.jcp.org/xml/ns/persistence/persistence_2_1.xsd"
version="2.1">
<persistence-unit name="templatePU" transaction-type="RESOURCE_LOCAL">
<persistence-unit name="templatePU" transaction-type="JTA">
<description>Hibernate test case template Persistence Unit</description>

View File

@@ -16,6 +16,7 @@ import javax.servlet.annotation.WebServlet;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.transaction.UserTransaction;
/**
* Various tests for the JPA integration.
@@ -26,6 +27,9 @@ public class JPATestEMInjectionEndpoint extends HttpServlet {
@Inject
private EntityManager em;
@Inject
private UserTransaction transaction;
@Override
protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws IOException {
try {
@@ -43,8 +47,7 @@ public class JPATestEMInjectionEndpoint extends HttpServlet {
}
private void doStuffWithHibernate() {
EntityTransaction transaction = em.getTransaction();
private void doStuffWithHibernate() throws Exception {
transaction.begin();
persistNewPerson(em);
@@ -52,7 +55,6 @@ public class JPATestEMInjectionEndpoint extends HttpServlet {
listExistingPersons(em);
transaction.commit();
em.close();
}
private static void listExistingPersons(EntityManager em) {

View File

@@ -4,7 +4,7 @@
http://xmlns.jcp.org/xml/ns/persistence/persistence_2_1.xsd"
version="2.1">
<persistence-unit name="templatePU" transaction-type="RESOURCE_LOCAL">
<persistence-unit name="templatePU" transaction-type="JTA">
<description>Hibernate test case template Persistence Unit</description>

View File

@@ -1,7 +1,7 @@
package org.jboss.protean.gizmo;
public interface FieldCreator extends MemberCreator<FieldCreator> {
public interface FieldCreator extends MemberCreator<FieldCreator>,AnnotatedElement {
FieldDescriptor getFieldDescriptor();

View File

@@ -1,11 +1,18 @@
package org.jboss.protean.gizmo;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.objectweb.asm.AnnotationVisitor;
import org.objectweb.asm.ClassWriter;
import org.objectweb.asm.FieldVisitor;
import org.objectweb.asm.Opcodes;
class FieldCreatorImpl implements FieldCreator {
private final FieldDescriptor fieldDescriptor;
private final List<AnnotationCreatorImpl> annotations = new ArrayList<>();
private int modifiers;
@@ -32,7 +39,21 @@ class FieldCreatorImpl implements FieldCreator {
@Override
public void write(ClassWriter file) {
file.visitField(modifiers, fieldDescriptor.getName(), fieldDescriptor.getType(), null, null);
FieldVisitor fieldVisitor = file.visitField(modifiers, fieldDescriptor.getName(), fieldDescriptor.getType(), null, null);
for(AnnotationCreatorImpl annotation : annotations) {
AnnotationVisitor av = fieldVisitor.visitAnnotation(DescriptorUtils.extToInt(annotation.getAnnotationType()), true);
for(Map.Entry<String, Object> e : annotation.getValues().entrySet()) {
av.visit(e.getKey(), e.getValue());
}
av.visitEnd();
}
fieldVisitor.visitEnd();
}
@Override
public AnnotationCreator addAnnotation(String annotationType) {
AnnotationCreatorImpl ac = new AnnotationCreatorImpl(annotationType);
annotations.add(ac);
return ac;
}
}

View File

@@ -27,7 +27,13 @@ public final class HibernateEntityEnhancer implements Function<String, Function<
Objects.requireNonNull(classnameWhitelist);
this.classnameWhitelist = classnameWhitelist;
BytecodeProvider provider = new org.hibernate.bytecode.internal.bytebuddy.BytecodeProviderImpl();
this.enhancer = provider.getEnhancer(new DefaultEnhancementContext());
DefaultEnhancementContext enhancementContext = new DefaultEnhancementContext() {
@Override
public ClassLoader getLoadingClassLoader() {
return Thread.currentThread().getContextClassLoader();
}
};
this.enhancer = provider.getEnhancer(enhancementContext);
}
@Override

View File

@@ -35,6 +35,10 @@ final class HibernateReflectiveNeeds {
simpleConstructor(org.hibernate.resource.transaction.backend.jdbc.internal.JdbcResourceLocalTransactionCoordinatorBuilderImpl.class);
simpleConstructor(org.hibernate.id.enhanced.SequenceStyleGenerator.class);
simpleConstructor(org.hibernate.boot.model.naming.ImplicitNamingStrategyJpaCompliantImpl.class);
simpleConstructor(org.hibernate.resource.transaction.backend.jta.internal.JtaTransactionCoordinatorBuilderImpl.class);
processorContext.addReflectiveClass(true, false, com.arjuna.ats.jta.UserTransaction.class.getName());
processorContext.addReflectiveClass(true, false, com.arjuna.ats.jta.TransactionManager.class.getName());
//FIXME following is not Hibernate specific?
simpleConstructor("com.sun.xml.internal.stream.events.XMLEventFactoryImpl");
//ANTLR tokens:

View File

@@ -1,6 +1,7 @@
package org.jboss.shamrock.jpa;
import java.io.IOException;
import java.io.Serializable;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
@@ -22,6 +23,7 @@ import org.jboss.jandex.IndexView;
import org.jboss.shamrock.deployment.ArchiveContext;
import org.jboss.shamrock.deployment.ProcessorContext;
import org.jboss.shamrock.deployment.codegen.BytecodeRecorder;
import org.jboss.shamrock.jpa.runtime.JPADeploymentTemplate;
/**
* Scan the Jandex index to find JPA entities (and embeddables supporting entity models).
@@ -144,7 +146,7 @@ final class JpaJandexScavenger {
}
ClassInfo classInfo = index.getClassByName(className);
if (classInfo == null) {
if (className == ClassType.OBJECT_TYPE.name()) {
if (className == ClassType.OBJECT_TYPE.name() || className.toString().equals(Serializable.class.getName())) {
return;
}
else {

View File

@@ -7,6 +7,7 @@ import java.util.concurrent.atomic.AtomicReference;
import javax.enterprise.context.ApplicationScoped;
import javax.enterprise.context.Dependent;
import javax.enterprise.context.RequestScoped;
import javax.enterprise.inject.Disposes;
import javax.enterprise.inject.Produces;
import javax.inject.Inject;
@@ -15,12 +16,16 @@ import javax.persistence.EntityManagerFactory;
import javax.persistence.Persistence;
import javax.persistence.PersistenceContext;
import javax.persistence.PersistenceUnit;
import javax.transaction.TransactionManager;
import javax.transaction.TransactionSynchronizationRegistry;
import org.jboss.jandex.AnnotationInstance;
import org.jboss.jandex.AnnotationTarget;
import org.jboss.jandex.AnnotationValue;
import org.jboss.jandex.DotName;
import org.jboss.protean.gizmo.ClassCreator;
import org.jboss.protean.gizmo.ClassOutput;
import org.jboss.protean.gizmo.FieldCreator;
import org.jboss.protean.gizmo.FieldDescriptor;
import org.jboss.protean.gizmo.MethodCreator;
import org.jboss.protean.gizmo.MethodDescriptor;
@@ -30,6 +35,8 @@ import org.jboss.shamrock.deployment.BeanArchiveIndex;
import org.jboss.shamrock.deployment.BeanDeployment;
import org.jboss.shamrock.deployment.ProcessorContext;
import org.jboss.shamrock.deployment.ResourceProcessor;
import org.jboss.shamrock.jpa.runtime.cdi.SystemEntityManager;
import org.jboss.shamrock.jpa.runtime.cdi.TransactionScopedEntityManager;
public class HibernateCdiResourceProcessor implements ResourceProcessor {
@@ -45,6 +52,7 @@ public class HibernateCdiResourceProcessor implements ResourceProcessor {
@Override
public void process(ArchiveContext archiveContext, ProcessorContext processorContext) throws Exception {
Set<String> knownUnitNames = new HashSet<>();
Set<String> knownContextNames = new HashSet<>();
scanForAnnotations(archiveContext, knownUnitNames, PERSISTENCE_UNIT);
@@ -57,7 +65,7 @@ public class HibernateCdiResourceProcessor implements ResourceProcessor {
Set<String> allKnownNames = new HashSet<>(knownUnitNames);
allKnownNames.addAll(knownContextNames);
for (String name : knownContextNames) {
for (String name : allKnownNames) {
String className = getClass().getName() + "$$EMFProducer-" + name;
AtomicReference<byte[]> bytes = new AtomicReference<>();
try (ClassCreator creator = new ClassCreator(new InMemoryClassOutput(bytes, processorContext), className, null, Object.class.getName())) {
@@ -80,27 +88,66 @@ public class HibernateCdiResourceProcessor implements ResourceProcessor {
}
for (String name : knownUnitNames) {
for (String name : knownContextNames) {
String className = getClass().getName() + "$$EMProducer-" + name;
AtomicReference<byte[]> bytes = new AtomicReference<>();
//we need to know if transactions are present or not
// if (processorContext.isCapabilityPresent("transactions")) {
//
// } else {
//TODO: this should be based on if a PU is JTA enabled or not
if (processorContext.isCapabilityPresent("transactions")) {
try (ClassCreator creator = new ClassCreator(new InMemoryClassOutput(bytes, processorContext), className, null, Object.class.getName())) {
creator.addAnnotation(Dependent.class);
FieldCreator emfField = creator.getFieldCreator("emf", EntityManagerFactory.class);
emfField.addAnnotation(Inject.class);
if (!knownUnitNames.contains(name)) {
emfField.addAnnotation(SystemEntityManager.class);
}
FieldDescriptor emf = emfField.getFieldDescriptor();
FieldCreator tsrField = creator.getFieldCreator("tsr", TransactionSynchronizationRegistry.class);
tsrField.addAnnotation(Inject.class);
FieldDescriptor tsr = tsrField.getFieldDescriptor();
FieldCreator tmField = creator.getFieldCreator("tm", TransactionManager.class);
tmField.addAnnotation(Inject.class);
FieldDescriptor tm = tmField.getFieldDescriptor();
MethodCreator producer = creator.getMethodCreator("producerMethod", EntityManager.class);
producer.addAnnotation(Produces.class);
producer.addAnnotation(RequestScoped.class);
ResultHandle emfRh = producer.readInstanceField(emf, producer.getThis());
ResultHandle tsrRh = producer.readInstanceField(tsr, producer.getThis());
ResultHandle tmRh = producer.readInstanceField(tm, producer.getThis());
producer.returnValue(producer.newInstance(MethodDescriptor.ofConstructor(TransactionScopedEntityManager.class, TransactionManager.class, TransactionSynchronizationRegistry.class, EntityManagerFactory.class), tmRh, tsrRh, emfRh));
MethodCreator disposer = creator.getMethodCreator("disposerMethod", void.class, EntityManager.class);
disposer.getParameterAnnotations(0).addAnnotation(Disposes.class);
disposer.invokeVirtualMethod(MethodDescriptor.ofMethod(TransactionScopedEntityManager.class, "requestDone", void.class), disposer.getMethodParam(0));
disposer.returnValue(null);
}
beanDeployment.addGeneratedBean(className, bytes.get());
} else {
//if there is no TX support then we just use a super simple approach, and produce a normal EM
try (ClassCreator creator = new ClassCreator(new InMemoryClassOutput(bytes, processorContext), className, null, Object.class.getName())) {
creator.addAnnotation(Dependent.class);
FieldDescriptor emf = creator.getFieldCreator("emf", EntityManagerFactory.class).getFieldDescriptor();
MethodCreator setter = creator.getMethodCreator("setEmf", void.class, EntityManagerFactory.class);
setter.writeInstanceField(emf, setter.getThis(), setter.getMethodParam(0));
setter.addAnnotation(Inject.class);
FieldCreator emfField = creator.getFieldCreator("emf", EntityManagerFactory.class);
emfField.addAnnotation(Inject.class);
if (!knownUnitNames.contains(name)) {
setter.addAnnotation(SystemEntityManager.class);
emfField.addAnnotation(SystemEntityManager.class);
}
setter.returnValue(null);
FieldDescriptor emf = emfField.getFieldDescriptor();
MethodCreator producer = creator.getMethodCreator("producerMethod", EntityManager.class);
producer.addAnnotation(Produces.class);
@@ -117,28 +164,32 @@ public class HibernateCdiResourceProcessor implements ResourceProcessor {
}
beanDeployment.addGeneratedBean(className, bytes.get());
// }
}
}
}
private void scanForAnnotations(ArchiveContext archiveContext, Set<String> knownUnitNames, DotName nm) {
for (AnnotationInstance anno : archiveContext.getCombinedIndex().getAnnotations(nm)) {
AnnotationValue unitName = anno.value("unitName");
if(unitName == null) {
continue;
}
if (anno.target().kind() == AnnotationTarget.Kind.METHOD) {
if (anno.target().asMethod().hasAnnotation(PRODUCES)) {
knownUnitNames.add(anno.value("unitName").asString());
knownUnitNames.add(unitName.asString());
}
} else if (anno.target().kind() == AnnotationTarget.Kind.FIELD) {
for (AnnotationInstance i : anno.target().asField().annotations()) {
if (i.name().equals(PRODUCES)) {
knownUnitNames.add(anno.value("unitName").asString());
knownUnitNames.add(unitName.asString());
break;
}
}
} else if (anno.target().kind() == AnnotationTarget.Kind.CLASS) {
for (AnnotationInstance i : anno.target().asClass().classAnnotations()) {
if (i.name().equals(PRODUCES)) {
knownUnitNames.add(anno.value("unitName").asString());
knownUnitNames.add(unitName.asString());
break;
}
}

View File

@@ -1,4 +1,4 @@
package org.jboss.shamrock.jpa;
package org.jboss.shamrock.jpa.runtime;
import org.hibernate.protean.Hibernate;

View File

@@ -1,4 +1,4 @@
package org.jboss.shamrock.jpa.cdi;
package org.jboss.shamrock.jpa.runtime.cdi;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;

View File

@@ -0,0 +1,460 @@
package org.jboss.shamrock.jpa.runtime.cdi;
import java.util.List;
import java.util.Map;
import javax.persistence.EntityGraph;
import javax.persistence.EntityManager;
import javax.persistence.EntityManagerFactory;
import javax.persistence.EntityTransaction;
import javax.persistence.FlushModeType;
import javax.persistence.LockModeType;
import javax.persistence.Query;
import javax.persistence.StoredProcedureQuery;
import javax.persistence.TypedQuery;
import javax.persistence.criteria.CriteriaBuilder;
import javax.persistence.criteria.CriteriaDelete;
import javax.persistence.criteria.CriteriaQuery;
import javax.persistence.criteria.CriteriaUpdate;
import javax.persistence.metamodel.Metamodel;
import javax.transaction.Status;
import javax.transaction.Synchronization;
import javax.transaction.TransactionManager;
import javax.transaction.TransactionSynchronizationRegistry;
public class TransactionScopedEntityManager implements EntityManager {
private final TransactionManager transactionManager;
private final TransactionSynchronizationRegistry tsr;
private final EntityManagerFactory emf;
private static final Object transactionKey = new Object();
private EntityManager fallbackEntityManager;
public TransactionScopedEntityManager(TransactionManager transactionManager, TransactionSynchronizationRegistry tsr, EntityManagerFactory emf) {
this.transactionManager = transactionManager;
this.tsr = tsr;
this.emf = emf;
}
public void requestDone() {
if(fallbackEntityManager != null) {
fallbackEntityManager.close();
}
}
EntityManagerResult getEntityManager() {
if (isInTransaction()) {
EntityManager em = (EntityManager) tsr.getResource(transactionKey);
if (em != null) {
return new EntityManagerResult(em, false);
}
EntityManager newEm = emf.createEntityManager();
tsr.putResource(transactionKey, newEm);
tsr.registerInterposedSynchronization(new Synchronization() {
@Override
public void beforeCompletion() {
newEm.flush();
newEm.close();
}
@Override
public void afterCompletion(int i) {
}
});
return new EntityManagerResult(newEm, false);
} else {
if(fallbackEntityManager == null) {
fallbackEntityManager = emf.createEntityManager();
}
return new EntityManagerResult(emf.createEntityManager(), false);
}
}
private boolean isInTransaction() {
try {
switch (transactionManager.getStatus()) {
case Status.STATUS_ACTIVE:
case Status.STATUS_COMMITTING:
case Status.STATUS_MARKED_ROLLBACK:
case Status.STATUS_PREPARED:
case Status.STATUS_PREPARING:
return true;
default:
return false;
}
} catch (Exception e) {
throw new RuntimeException(e);
}
}
@Override
public void persist(Object entity) {
try (EntityManagerResult emr = getEntityManager()) {
emr.em.persist(entity);
}
}
@Override
public <T> T merge(T entity) {
try (EntityManagerResult emr = getEntityManager()) {
return emr.em.merge(entity);
}
}
@Override
public void remove(Object entity) {
try (EntityManagerResult emr = getEntityManager()) {
emr.em.remove(entity);
}
}
@Override
public <T> T find(Class<T> entityClass, Object primaryKey) {
try (EntityManagerResult emr = getEntityManager()) {
return emr.em.find(entityClass, primaryKey);
}
}
@Override
public <T> T find(Class<T> entityClass, Object primaryKey, Map<String, Object> properties) {
try (EntityManagerResult emr = getEntityManager()) {
return emr.em.find(entityClass, primaryKey, properties);
}
}
@Override
public <T> T find(Class<T> entityClass, Object primaryKey, LockModeType lockMode) {
try (EntityManagerResult emr = getEntityManager()) {
return emr.em.find(entityClass, primaryKey, lockMode);
}
}
@Override
public <T> T find(Class<T> entityClass, Object primaryKey, LockModeType lockMode, Map<String, Object> properties) {
try (EntityManagerResult emr = getEntityManager()) {
return emr.em.find(entityClass, primaryKey, lockMode, properties);
}
}
@Override
public <T> T getReference(Class<T> entityClass, Object primaryKey) {
try (EntityManagerResult emr = getEntityManager()) {
return emr.em.getReference(entityClass, primaryKey);
}
}
@Override
public void flush() {
try (EntityManagerResult emr = getEntityManager()) {
emr.em.flush();
}
}
@Override
public void setFlushMode(FlushModeType flushMode) {
try (EntityManagerResult emr = getEntityManager()) {
emr.em.setFlushMode(flushMode);
}
}
@Override
public FlushModeType getFlushMode() {
try (EntityManagerResult emr = getEntityManager()) {
return emr.em.getFlushMode();
}
}
@Override
public void lock(Object entity, LockModeType lockMode) {
try (EntityManagerResult emr = getEntityManager()) {
emr.em.lock(entity, lockMode);
}
}
@Override
public void lock(Object entity, LockModeType lockMode, Map<String, Object> properties) {
try (EntityManagerResult emr = getEntityManager()) {
emr.em.lock(entity, lockMode, properties);
}
}
@Override
public void refresh(Object entity) {
try (EntityManagerResult emr = getEntityManager()) {
emr.em.refresh(entity);
}
}
@Override
public void refresh(Object entity, Map<String, Object> properties) {
try (EntityManagerResult emr = getEntityManager()) {
emr.em.refresh(entity, properties);
}
}
@Override
public void refresh(Object entity, LockModeType lockMode) {
try (EntityManagerResult emr = getEntityManager()) {
emr.em.refresh(entity, lockMode);
}
}
@Override
public void refresh(Object entity, LockModeType lockMode, Map<String, Object> properties) {
try (EntityManagerResult emr = getEntityManager()) {
emr.em.refresh(entity, lockMode, properties);
}
}
@Override
public void clear() {
try (EntityManagerResult emr = getEntityManager()) {
emr.em.clear();
}
}
@Override
public void detach(Object entity) {
try (EntityManagerResult emr = getEntityManager()) {
emr.em.detach(entity);
}
}
@Override
public boolean contains(Object entity) {
try (EntityManagerResult emr = getEntityManager()) {
return emr.em.contains(entity);
}
}
@Override
public LockModeType getLockMode(Object entity) {
try (EntityManagerResult emr = getEntityManager()) {
return emr.em.getLockMode(entity);
}
}
@Override
public void setProperty(String propertyName, Object value) {
try (EntityManagerResult emr = getEntityManager()) {
emr.em.setProperty(propertyName, value);
}
}
@Override
public Map<String, Object> getProperties() {
try (EntityManagerResult emr = getEntityManager()) {
return emr.em.getProperties();
}
}
@Override
public Query createQuery(String qlString) {
//TODO: this needs some thought for how it works outside a tx
try (EntityManagerResult emr = getEntityManager()) {
return emr.em.createQuery(qlString);
}
}
@Override
public <T> TypedQuery<T> createQuery(CriteriaQuery<T> criteriaQuery) {
try (EntityManagerResult emr = getEntityManager()) {
return emr.em.createQuery(criteriaQuery);
}
}
@Override
public Query createQuery(CriteriaUpdate updateQuery) {
try (EntityManagerResult emr = getEntityManager()) {
return emr.em.createQuery(updateQuery);
}
}
@Override
public Query createQuery(CriteriaDelete deleteQuery) {
try (EntityManagerResult emr = getEntityManager()) {
return emr.em.createQuery(deleteQuery);
}
}
@Override
public <T> TypedQuery<T> createQuery(String qlString, Class<T> resultClass) {
try (EntityManagerResult emr = getEntityManager()) {
return emr.em.createQuery(qlString, resultClass);
}
}
@Override
public Query createNamedQuery(String name) {
try (EntityManagerResult emr = getEntityManager()) {
return emr.em.createNamedQuery(name);
}
}
@Override
public <T> TypedQuery<T> createNamedQuery(String name, Class<T> resultClass) {
try (EntityManagerResult emr = getEntityManager()) {
return emr.em.createNamedQuery(name, resultClass);
}
}
@Override
public Query createNativeQuery(String sqlString) {
try (EntityManagerResult emr = getEntityManager()) {
return emr.em.createNativeQuery(sqlString);
}
}
@Override
public Query createNativeQuery(String sqlString, Class resultClass) {
try (EntityManagerResult emr = getEntityManager()) {
return emr.em.createNativeQuery(sqlString, resultClass);
}
}
@Override
public Query createNativeQuery(String sqlString, String resultSetMapping) {
try (EntityManagerResult emr = getEntityManager()) {
return emr.em.createNativeQuery(sqlString, resultSetMapping);
}
}
@Override
public StoredProcedureQuery createNamedStoredProcedureQuery(String name) {
try (EntityManagerResult emr = getEntityManager()) {
return emr.em.createNamedStoredProcedureQuery(name);
}
}
@Override
public StoredProcedureQuery createStoredProcedureQuery(String procedureName) {
try (EntityManagerResult emr = getEntityManager()) {
return emr.em.createStoredProcedureQuery(procedureName);
}
}
@Override
public StoredProcedureQuery createStoredProcedureQuery(String procedureName, Class... resultClasses) {
try (EntityManagerResult emr = getEntityManager()) {
return emr.em.createStoredProcedureQuery(procedureName, resultClasses);
}
}
@Override
public StoredProcedureQuery createStoredProcedureQuery(String procedureName, String... resultSetMappings) {
try (EntityManagerResult emr = getEntityManager()) {
return emr.em.createStoredProcedureQuery(procedureName, resultSetMappings);
}
}
@Override
public void joinTransaction() {
try (EntityManagerResult emr = getEntityManager()) {
emr.em.joinTransaction();
}
}
@Override
public boolean isJoinedToTransaction() {
try (EntityManagerResult emr = getEntityManager()) {
return emr.em.isJoinedToTransaction();
}
}
@Override
public <T> T unwrap(Class<T> cls) {
try (EntityManagerResult emr = getEntityManager()) {
return emr.em.unwrap(cls);
}
}
@Override
public Object getDelegate() {
try (EntityManagerResult emr = getEntityManager()) {
return emr.em.getDelegate();
}
}
@Override
public void close() {
throw new IllegalStateException("Not supported for transaction scoped entity managers");
}
@Override
public boolean isOpen() {
return true;
}
@Override
public EntityTransaction getTransaction() {
throw new IllegalStateException("Not supported for JTA entity managers");
}
@Override
public EntityManagerFactory getEntityManagerFactory() {
try (EntityManagerResult emr = getEntityManager()) {
return emr.em.getEntityManagerFactory();
}
}
@Override
public CriteriaBuilder getCriteriaBuilder() {
try (EntityManagerResult emr = getEntityManager()) {
return emr.em.getCriteriaBuilder();
}
}
@Override
public Metamodel getMetamodel() {
try (EntityManagerResult emr = getEntityManager()) {
return emr.em.getMetamodel();
}
}
@Override
public <T> EntityGraph<T> createEntityGraph(Class<T> rootType) {
try (EntityManagerResult emr = getEntityManager()) {
return emr.em.createEntityGraph(rootType);
}
}
@Override
public EntityGraph<?> createEntityGraph(String graphName) {
try (EntityManagerResult emr = getEntityManager()) {
return emr.em.createEntityGraph(graphName);
}
}
@Override
public EntityGraph<?> getEntityGraph(String graphName) {
try (EntityManagerResult emr = getEntityManager()) {
return emr.em.getEntityGraph(graphName);
}
}
@Override
public <T> List<EntityGraph<? super T>> getEntityGraphs(Class<T> entityClass) {
try (EntityManagerResult emr = getEntityManager()) {
return emr.em.getEntityGraphs(entityClass);
}
}
static class EntityManagerResult implements AutoCloseable {
private final EntityManager em;
private final boolean closeOnEnd;
EntityManagerResult(EntityManager em, boolean closeOnEnd) {
this.em = em;
this.closeOnEnd = closeOnEnd;
}
@Override
public void close() {
if (closeOnEnd) {
em.close();
}
}
}
}

View File

@@ -17,6 +17,7 @@ public class ShamrockTest extends BlockJUnit4ClassRunner {
private static boolean first = true;
private static boolean started = false;
private static boolean failed = false;
/**
* Creates a BlockJUnit4ClassRunner to run {@code klass}
@@ -42,19 +43,29 @@ public class ShamrockTest extends BlockJUnit4ClassRunner {
try {
notifier.addListener(new RunListener() {
@Override
public void testStarted(Description description) {
if (failed) {
notifier.fireTestFailure(new Failure(description, new AssertionError("Startup failed")));
return;
}
if (!started) {
started = true;
//TODO: so much hacks...
Class<?> theClass = description.getTestClass();
String classFileName = theClass.getName().replace(".", "/") + ".class";
URL resource = theClass.getClassLoader().getResource(classFileName);
String testClassLocation = resource.getPath().substring(0, resource.getPath().length() - classFileName.length());
String appClassLocation = testClassLocation.replace("test-classes", "classes");
Path appRoot = Paths.get(appClassLocation);
RuntimeRunner runtimeRunner = new RuntimeRunner(getClass().getClassLoader(), appRoot, Paths.get(testClassLocation), new ArchiveContextBuilder());
runtimeRunner.run();
try {
Class<?> theClass = description.getTestClass();
String classFileName = theClass.getName().replace(".", "/") + ".class";
URL resource = theClass.getClassLoader().getResource(classFileName);
String testClassLocation = resource.getPath().substring(0, resource.getPath().length() - classFileName.length());
String appClassLocation = testClassLocation.replace("test-classes", "classes");
Path appRoot = Paths.get(appClassLocation);
RuntimeRunner runtimeRunner = new RuntimeRunner(getClass().getClassLoader(), appRoot, Paths.get(testClassLocation), new ArchiveContextBuilder());
runtimeRunner.run();
} catch (RuntimeException e) {
failed = true;
throw e;
}
}
}
});