diff --git a/src/main/org/apache/tools/ant/AntClassLoader.java b/src/main/org/apache/tools/ant/AntClassLoader.java index e7bc33d8c..0eb41fa11 100644 --- a/src/main/org/apache/tools/ant/AntClassLoader.java +++ b/src/main/org/apache/tools/ant/AntClassLoader.java @@ -1,7 +1,7 @@ /* * The Apache Software License, Version 1.1 * - * Copyright (c) 1999 The Apache Software Foundation. All rights + * Copyright (c) 2000-2001 The Apache Software Foundation. All rights * reserved. * * Redistribution and use in source and binary forms, with or without @@ -54,6 +54,7 @@ package org.apache.tools.ant; +import java.lang.reflect.*; import java.util.*; import java.util.zip.*; import java.io.*; @@ -65,7 +66,8 @@ import org.apache.tools.ant.types.Path; * system classpath by using the forceLoadClass method. Any subsequent classes loaded by that * class will then use this loader rather than the system class loader. * - * @author Conor MacNeill + * @author Conor MacNeill + * @author Jesse Glick */ public class AntClassLoader extends ClassLoader { /** @@ -101,6 +103,19 @@ public class AntClassLoader extends ClassLoader { */ private Vector loaderPackages = new Vector(); + private static Method getProtectionDomain = null; + private static Method defineClassProtectionDomain = null; + static { + try { + getProtectionDomain = Class.class.getMethod("getProtectionDomain", new Class[0]); + Class protectionDomain = Class.forName("java.security.ProtectionDomain"); + Class[] args = new Class[] {String.class, byte[].class, Integer.TYPE, Integer.TYPE, protectionDomain}; + defineClassProtectionDomain = ClassLoader.class.getDeclaredMethod("defineClass", args); + } + catch (Exception e) {} + } + + /** * Create a classloader for the given project using the classpath given. * @@ -198,7 +213,7 @@ public class AntClassLoader extends ClassLoader { Class theClass = findLoadedClass(classname); if (theClass == null) { - theClass = findSystemClass(classname); + theClass = findBaseClass(classname); } return theClass; @@ -325,7 +340,7 @@ public class AntClassLoader extends ClassLoader { if (theClass == null) { if (useSystemFirst) { try { - theClass = findSystemClass(classname); + theClass = findBaseClass(classname); project.log("Class " + classname + " loaded from system loader", Project.MSG_DEBUG); } catch (ClassNotFoundException cnfe) { @@ -339,7 +354,7 @@ public class AntClassLoader extends ClassLoader { project.log("Class " + classname + " loaded from ant loader", Project.MSG_DEBUG); } catch (ClassNotFoundException cnfe) { - theClass = findSystemClass(classname); + theClass = findBaseClass(classname); project.log("Class " + classname + " loaded from system loader", Project.MSG_DEBUG); } } @@ -387,10 +402,33 @@ public class AntClassLoader extends ClassLoader { byte[] classData = baos.toByteArray(); - return defineClass(classname, classData, 0, classData.length); + // Simply put: + // defineClass(classname, classData, 0, classData.length, Project.class.getProtectionDomain()); + // Made more elaborate to be 1.1-safe. + if (defineClassProtectionDomain != null) { + try { + Object domain = getProtectionDomain.invoke(Project.class, new Object[0]); + Object[] args = new Object[] {classname, classData, new Integer(0), new Integer(classData.length), domain}; + return (Class)defineClassProtectionDomain.invoke(this, args); + } + catch (InvocationTargetException ite) { + Throwable t = ite.getTargetException(); + if (t instanceof ClassFormatError) { + throw (ClassFormatError)t; + } + else { + throw new IOException(t.toString()); + } + } + catch (Exception e) { + throw new IOException(e.toString()); + } + } + else { + return defineClass(classname, classData, 0, classData.length); + } } - /** * Search for and load a class on the classpath of this class loader. * @@ -447,4 +485,17 @@ public class AntClassLoader extends ClassLoader { catch (IOException e) {} } } + + /** + * Find a system class (which should be loaded from the same classloader as the Ant core). + */ + private Class findBaseClass(String name) throws ClassNotFoundException { + ClassLoader base = AntClassLoader.class.getClassLoader(); + if (base == null) { + return findSystemClass(name); + } + else { + return base.loadClass(name); + } + } }