001 package edu.rice.cs.cunit.instrumentors.threadCheck;
002
003 import edu.rice.cs.cunit.classFile.ClassFile;
004 import edu.rice.cs.cunit.classFile.ClassFileTools;
005 import edu.rice.cs.cunit.classFile.MethodInfo;
006 import edu.rice.cs.cunit.classFile.attributes.CodeAttributeInfo;
007 import edu.rice.cs.cunit.classFile.code.InstructionList;
008 import edu.rice.cs.cunit.classFile.code.Opcode;
009 import edu.rice.cs.cunit.classFile.code.instructions.GenericInstruction;
010 import edu.rice.cs.cunit.classFile.code.instructions.ReferenceInstruction;
011 import edu.rice.cs.cunit.classFile.constantPool.*;
012 import edu.rice.cs.cunit.classFile.constantPool.visitors.ADefaultPoolInfoVisitor;
013 import edu.rice.cs.cunit.classFile.constantPool.visitors.CheckMethodVisitor;
014 import edu.rice.cs.cunit.classFile.constantPool.visitors.CheckUTFVisitor;
015 import edu.rice.cs.cunit.threadCheck.OnlyRunBy;
016 import edu.rice.cs.cunit.threadCheck.ThreadCheck;
017 import edu.rice.cs.cunit.util.Types;
018
019 import java.io.IOException;
020 import java.util.ArrayList;
021 import java.util.List;
022
023 /**
024 * Instrumentor to add calls to ThreadCheck.checkCurrentThreadName/Id/Group to check if the current thread is not
025 * allowed to execute a class or method.
026 * <p/>
027 * This instrumentor checks for every method if there are @NotRunBy or @OnlyRunBy annotations attached to the method,
028 * the containing class, the same method in one of the superclasses or interfaces, or a superclass or interface, and
029 * then at the beginning of the method inserts calls to ThreadCheck..
030 *
031 * @author Mathias Ricken
032 */
033 public class AddThreadCheckStrategy extends AAddThreadCheckStrategy {
034 /**
035 * Constructor for this strategy.
036 * @param shared data shared among all AThreadCheckStrategy instances
037 * @param sharedAdd data for all AAddThreadCheckStrategy instances
038 */
039 public AddThreadCheckStrategy(SharedData shared, SharedAddData sharedAdd) {
040 this(new ArrayList<String>(), shared, sharedAdd);
041 }
042
043 /**
044 * Constructor for this strategy.
045 * @param parameters parameters for the instrumentors
046 * @param shared data shared among all AThreadCheckStrategy instances
047 * @param sharedAdd data for all AAddThreadCheckStrategy instances
048 */
049 public AddThreadCheckStrategy(List<String> parameters, SharedData shared, SharedAddData sharedAdd) {
050 super(parameters, shared, sharedAdd);
051 }
052
053 /**
054 * Instrument the class.
055 *
056 * @param cf class file info
057 */
058 public void instrument(final ClassFile cf) {
059 _sharedData.setCurrentClassName(cf.getThisClassName());
060 ConstantPool cp = cf.getConstantPool();
061 ReferenceInstruction checkCallInstr = new ReferenceInstruction(Opcode.INVOKESTATIC, (short)0);
062 ReferenceInstruction loadStringInstr = new ReferenceInstruction(Opcode.LDC_W, (short)0);
063 ReferenceInstruction loadLongInstr = new ReferenceInstruction(Opcode.LDC2_W, (short)0);
064 ReferenceInstruction okCheckCallInstr = new ReferenceInstruction(Opcode.INVOKESTATIC, (short)0);
065 ReferenceInstruction addNameCallInstr = new ReferenceInstruction(Opcode.INVOKESTATIC, (short)0);
066 ReferenceInstruction addIdCallInstr = new ReferenceInstruction(Opcode.INVOKESTATIC, (short)0);
067 ReferenceInstruction addGroupCallInstr = new ReferenceInstruction(Opcode.INVOKESTATIC, (short)0);
068 ReferenceInstruction setEventThreadCallInstr = new ReferenceInstruction(Opcode.INVOKESTATIC, (short)0);
069 ReferenceInstruction isDisplayableCallInstr = new ReferenceInstruction(Opcode.INVOKEVIRTUAL, (short)0);
070 int checkNameCallIndex = 0;
071 int checkIdCallIndex = 0;
072 int checkGroupCallIndex = 0;
073 int okCheckCallIndex = 0;
074 int addNameCallIndex = 0;
075 int addGroupCallIndex = 0;
076 int addIdCallIndex = 0;
077 int setEventThreadIndex = 0;
078 int isDisplayableCallIndex = 0;
079
080 // process all methods in this class
081 for(MethodInfo mi : cf.getMethods()) {
082 // proces if not a native or abstract method, should have a body
083 if ((mi.getAccessFlags() & (ClassFile.ACC_NATIVE | ClassFile.ACC_ABSTRACT)) == 0) {
084 long beginMillis = System.currentTimeMillis();
085 ThreadCheckAnnotationRecord methodAR = getMethodAnnotations(cf, mi);
086 long endMillis = System.currentTimeMillis();
087 _sharedAddData.cacheInfo.addTimeSpent(endMillis-beginMillis);
088
089 if (!methodAR.empty()) {
090 boolean changed = false;
091 InstructionList il = new InstructionList(mi.getCodeAttributeInfo().getCode());
092
093 // if this is a constructor ("<init>"), then we should perhaps wait until
094 // after the this() or super() call; try to find it
095 if (mi.getName().toString().equals("<init>")) {
096 // check if this ctor calls other ctors
097 boolean ctorCalled = false;
098 do {
099 if (il.getOpcode() == Opcode.INVOKESPECIAL) {
100 ReferenceInstruction ri = (ReferenceInstruction)il.getInstr();
101 short method = Types.shortFromBytes(ri.getBytecode(), 1);
102 MethodPoolInfo mpi = cf.getConstantPoolItem(method).execute(CheckMethodVisitor.singleton(), null);
103 if (mpi.getNameAndType().getName().toString().equals("<init>")) {
104 ClassFile curcf = cf;
105 while((curcf.getThisClassName()!=null) && (!curcf.getThisClassName().equals(""))) {
106 if (curcf.getThisClass().toString().equals(mpi.getClassInfo().getName().toString())) {
107 // super() call
108 ctorCalled = true;
109 break;
110 }
111 ClassFileTools.ClassLocation cl = null;
112 try {
113 cl = ClassFileTools.findClassFile(curcf.getSuperClassName(), _sharedData.getClassPath());
114 if (cl != null) {
115 curcf = cl.getClassFile();
116 }
117 else {
118 _sharedData.addClassNotFoundWarning(new ClassNotFoundWarning(curcf.getSuperClassName(),_sharedData.getCurrentClassName()));
119 //Debug.out.println("Warning: Could not find " + curcf.getSuperClassName());
120 break;
121 }
122 }
123 finally {
124 try { if (cl!=null) cl.close(); }
125 catch(IOException e) { /* ignore; shouldn't cause any problems except on Windows with read locks */ }
126 }
127 }
128 if (ctorCalled) { break; }
129 }
130 }
131 } while(il.advanceIndex());
132 if (ctorCalled) {
133 // super call found
134 // advance one past the super call
135 boolean res = il.advanceIndex();
136 assert res == true;
137 }
138 else {
139 // no this() or super() call found, start at the beginning
140 il.setIndex(0);
141 _sharedAddData.otherWarnings.add(new OnlyAfterRealizedWarning("ignored, no this() or super() call found in constructor "+
142 cf.getThisClassName()+"."+mi.getName()+mi.getDescriptor()));
143 methodAR.allowEventThread = OnlyRunBy.EVENT_THREAD.NO;
144 }
145 }
146
147 // ========
148 // NotRunBy
149 // ========
150
151 // add checks for thread names from NotRunBy
152 for(String s: methodAR.denyThreadNames) {
153 changed = true;
154
155 if (checkNameCallIndex==0) {
156 checkNameCallIndex = cf.addMethodToConstantPool(ThreadCheck.class.getName().replace('.','/'),
157 "checkCurrentThreadName",
158 "(Ljava/lang/String;)V");
159 }
160 checkCallInstr.setReference(checkNameCallIndex);
161 il.insertInstr(checkCallInstr, mi.getCodeAttributeInfo());
162
163 // add the string
164 AUTFPoolInfo utfpi = new ASCIIPoolInfo(s, cp);
165 int[] l = cf.addConstantPoolItems(new APoolInfo[]{utfpi});
166 utfpi = cf.getConstantPoolItem(l[0]).execute(CheckUTFVisitor.singleton(), null);
167
168 StringPoolInfo spi = new StringPoolInfo(utfpi, cp);
169 l = cf.addConstantPoolItems(new APoolInfo[]{spi});
170 spi = cf.getConstantPoolItem(l[0]).execute(new ADefaultPoolInfoVisitor<StringPoolInfo, Object>() {
171 public StringPoolInfo defaultCase(APoolInfo host, Object o) {
172 throw new ClassFormatError("Info is of type " + host.getClass().getName() + ", needs to be StringPoolInfo");
173 }
174 public StringPoolInfo stringCase(StringPoolInfo host, Object o) {
175 return host;
176 }
177 }, null);
178
179 loadStringInstr.setReference(l[0]);
180 il.insertInstr(loadStringInstr, mi.getCodeAttributeInfo());
181
182 // now we have inserted the following bytecode:
183 // ldc_w (string)
184 // invokestatic (checkCurrentThreadName)
185 }
186
187 // add checks for thread ids from NotRunBy
188 for(long id: methodAR.denyThreadIds) {
189 changed = true;
190 if (checkIdCallIndex==0) {
191 checkIdCallIndex = cf.addMethodToConstantPool(ThreadCheck.class.getName().replace('.','/'),
192 "checkCurrentThreadId",
193 "(J)V");
194 }
195 checkCallInstr.setReference(checkIdCallIndex);
196 il.insertInstr(checkCallInstr, mi.getCodeAttributeInfo());
197
198 // add the long
199 loadLongInstr.setReference(cf.addLongToConstantPool(id));
200
201 il.insertInstr(loadLongInstr, mi.getCodeAttributeInfo());
202 }
203
204 // add checks for thread groups from NotRunBy
205 for(String s: methodAR.denyThreadGroups) {
206 changed = true;
207 if (checkGroupCallIndex==0) {
208 checkGroupCallIndex = cf.addMethodToConstantPool(ThreadCheck.class.getName().replace('.','/'),
209 "checkCurrentThreadGroup",
210 "(Ljava/lang/String;)V");
211 }
212 checkCallInstr.setReference(checkGroupCallIndex);
213 il.insertInstr(checkCallInstr, mi.getCodeAttributeInfo());
214
215 // add the string
216 AUTFPoolInfo utfpi = new ASCIIPoolInfo(s, cp);
217 int[] l = cf.addConstantPoolItems(new APoolInfo[]{utfpi});
218 utfpi = cf.getConstantPoolItem(l[0]).execute(CheckUTFVisitor.singleton(), null);
219
220 StringPoolInfo spi = new StringPoolInfo(utfpi, cp);
221 l = cf.addConstantPoolItems(new APoolInfo[]{spi});
222 spi = cf.getConstantPoolItem(l[0]).execute(new ADefaultPoolInfoVisitor<StringPoolInfo, Object>() {
223 public StringPoolInfo defaultCase(APoolInfo host, Object o) {
224 throw new ClassFormatError("Info is of type " + host.getClass().getName() + ", needs to be StringPoolInfo");
225 }
226 public StringPoolInfo stringCase(StringPoolInfo host, Object o) {
227 return host;
228 }
229 }, null);
230
231 loadStringInstr.setReference(l[0]);
232 il.insertInstr(loadStringInstr, mi.getCodeAttributeInfo());
233 }
234
235 // =========
236 // OnlyRunBy
237 // =========
238
239 int lengthBeforeOnlyRunBy = il.getLength();
240
241 // add checks for event thread from OnlyRunBy
242 if (methodAR.allowEventThread!=OnlyRunBy.EVENT_THREAD.NO) {
243 boolean insertInstr = true;
244 if(methodAR.allowEventThread==OnlyRunBy.EVENT_THREAD.ONLY_AFTER_REALIZED) {
245 if ((mi.getAccessFlags() & ClassFile.ACC_STATIC) != 0) {
246 _sharedAddData.otherWarnings.add(new OnlyAfterRealizedWarning(
247 "ignored for static method "+
248 cf.getThisClassName()+"."+mi.getName()+mi.getDescriptor()));
249 insertInstr = false;
250 }
251 else {
252 // we need to check if this is a subclass of java.awt.Component
253 boolean found = false;
254 ClassFile curcf = cf;
255 while((curcf.getThisClassName()!=null) && (!curcf.getThisClassName().equals(""))) {
256 if (curcf.getThisClassName().equals("java.awt.Component")) {
257 found = true;
258 break;
259 }
260 ClassFileTools.ClassLocation cl = null;
261 try {
262 cl = ClassFileTools.findClassFile(curcf.getSuperClassName(), _sharedData.getClassPath());
263 if (cl != null) {
264 curcf = cl.getClassFile();
265 }
266 else {
267 _sharedData.addClassNotFoundWarning(new ClassNotFoundWarning(curcf.getSuperClassName(),_sharedData.getCurrentClassName()));
268 //Debug.out.println("Warning: Could not find " + curcf.getSuperClassName());
269 break;
270 }
271 }
272 finally {
273 try { if (cl!=null) cl.close(); }
274 catch(IOException e) { /* ignore; shouldn't cause any problems except on Windows with read locks */ }
275 }
276 }
277 if (!found) {
278 _sharedAddData.otherWarnings.add(new OnlyAfterRealizedWarning(
279 "class not subclass of java.awt.Component, ignored for method "+
280 cf.getThisClassName()+"."+mi.getName()+mi.getDescriptor()));
281 insertInstr = false;
282 }
283 }
284 }
285
286 if (insertInstr) {
287 changed = true;
288
289 if (setEventThreadIndex==0) {
290 setEventThreadIndex = cf.addMethodToConstantPool(ThreadCheck.class.getName().replace('.','/'),
291 "setAllowedEventThread_OnlyRunBy",
292 "(Z)V");
293 }
294 setEventThreadCallInstr.setReference(setEventThreadIndex);
295 il.insertInstr(setEventThreadCallInstr, mi.getCodeAttributeInfo());
296
297 if(methodAR.allowEventThread==OnlyRunBy.EVENT_THREAD.ONLY_AFTER_REALIZED) {
298 if (isDisplayableCallIndex==0) {
299 isDisplayableCallIndex = cf.addMethodToConstantPool("java/awt/Component",
300 "isDisplayable",
301 "()Z");
302 }
303 isDisplayableCallInstr.setReference(isDisplayableCallIndex);
304 il.insertInstr(isDisplayableCallInstr, mi.getCodeAttributeInfo());
305 il.insertInstr(new GenericInstruction(Opcode.ALOAD_0), mi.getCodeAttributeInfo());
306 }
307 else {
308 il.insertInstr(new GenericInstruction(Opcode.ICONST_1), mi.getCodeAttributeInfo());
309 }
310 }
311 }
312
313 // add checks for thread groups from OnlyRunBy
314 for(String s: methodAR.allowThreadNames) {
315 changed = true;
316
317 if (addNameCallIndex==0) {
318 addNameCallIndex = cf.addMethodToConstantPool(ThreadCheck.class.getName().replace('.','/'),
319 "addAllowedName_OnlyRunBy",
320 "(Ljava/lang/String;)V");
321 }
322 addNameCallInstr.setReference(addNameCallIndex);
323 il.insertInstr(addNameCallInstr, mi.getCodeAttributeInfo());
324
325 // add the string
326 AUTFPoolInfo utfpi = new ASCIIPoolInfo(s, cp);
327 int[] l = cf.addConstantPoolItems(new APoolInfo[]{utfpi});
328 utfpi = cf.getConstantPoolItem(l[0]).execute(CheckUTFVisitor.singleton(), null);
329
330 StringPoolInfo spi = new StringPoolInfo(utfpi, cp);
331 l = cf.addConstantPoolItems(new APoolInfo[]{spi});
332 spi = cf.getConstantPoolItem(l[0]).execute(new ADefaultPoolInfoVisitor<StringPoolInfo, Object>() {
333 public StringPoolInfo defaultCase(APoolInfo host, Object o) {
334 throw new ClassFormatError("Info is of type " + host.getClass().getName() + ", needs to be StringPoolInfo");
335 }
336 public StringPoolInfo stringCase(StringPoolInfo host, Object o) {
337 return host;
338 }
339 }, null);
340
341 loadStringInstr.setReference(l[0]);
342 il.insertInstr(loadStringInstr, mi.getCodeAttributeInfo());
343
344 // now we have inserted the following bytecode:
345 // ldc_w (string)
346 // invokestatic (addAllowedName_OnlyRunBy)
347 }
348
349 // add checks for thread ids from OnlyRunBy
350 for(long id: methodAR.allowThreadIds) {
351 changed = true;
352
353 if (addIdCallIndex==0) {
354 addIdCallIndex = cf.addMethodToConstantPool(ThreadCheck.class.getName().replace('.','/'),
355 "addAllowedId_OnlyRunBy",
356 "(J)V");
357 }
358 addIdCallInstr.setReference(addIdCallIndex);
359 il.insertInstr(addIdCallInstr, mi.getCodeAttributeInfo());
360
361 // add the long
362 loadLongInstr.setReference(cf.addLongToConstantPool(id));
363
364 il.insertInstr(loadLongInstr, mi.getCodeAttributeInfo());
365
366 // now we have inserted the following bytecode:
367 // ldc2_w (id)
368 // invokestatic (addAllowedId_OnlyRunBy)
369 }
370
371 // add checks for thread groups from OnlyRunBy
372 for(String s: methodAR.allowThreadGroups) {
373 changed = true;
374
375 if (addGroupCallIndex==0) {
376 addGroupCallIndex = cf.addMethodToConstantPool(ThreadCheck.class.getName().replace('.','/'),
377 "addAllowedGroup_OnlyRunBy",
378 "(Ljava/lang/String;)V");
379 }
380 addGroupCallInstr.setReference(addGroupCallIndex);
381 il.insertInstr(addGroupCallInstr, mi.getCodeAttributeInfo());
382
383 // add the string
384 AUTFPoolInfo utfpi = new ASCIIPoolInfo(s, cp);
385 int[] l = cf.addConstantPoolItems(new APoolInfo[]{utfpi});
386 utfpi = cf.getConstantPoolItem(l[0]).execute(CheckUTFVisitor.singleton(), null);
387
388 StringPoolInfo spi = new StringPoolInfo(utfpi, cp);
389 l = cf.addConstantPoolItems(new APoolInfo[]{spi});
390 spi = cf.getConstantPoolItem(l[0]).execute(new ADefaultPoolInfoVisitor<StringPoolInfo, Object>() {
391 public StringPoolInfo defaultCase(APoolInfo host, Object o) {
392 throw new ClassFormatError("Info is of type " + host.getClass().getName() + ", needs to be StringPoolInfo");
393 }
394 public StringPoolInfo stringCase(StringPoolInfo host, Object o) {
395 return host;
396 }
397 }, null);
398
399 loadStringInstr.setReference(l[0]);
400 il.insertInstr(loadStringInstr, mi.getCodeAttributeInfo());
401
402 // now we have inserted the following bytecode:
403 // ldc_w (string)
404 // invokestatic (addAllowedName_OnlyRunBy)
405 }
406
407 int lengthAfterOnlyRunBy = il.getLength();
408 if (lengthAfterOnlyRunBy>lengthBeforeOnlyRunBy) {
409 // add check call for OnlyRunBy
410 if (okCheckCallIndex==0) {
411 okCheckCallIndex = cf.addMethodToConstantPool(ThreadCheck.class.getName().replace('.','/'),
412 "checkCurrentThread_OnlyRunBy",
413 "()V");
414 }
415 okCheckCallInstr.setReference(okCheckCallIndex);
416 il.advanceIndex(lengthAfterOnlyRunBy-lengthBeforeOnlyRunBy);
417 il.insertInstr(okCheckCallInstr, mi.getCodeAttributeInfo());
418 // invokestatic (checkCurrentThread_OnlyRunBy)
419 }
420
421 if (changed) {
422 // write code back
423 mi.getCodeAttributeInfo().setCode(il.getCode());
424
425 // make sure we have at least five slots on the stack
426 CodeAttributeInfo.CodeProperties cProps = mi.getCodeAttributeInfo().getProperties();
427 cProps.maxStack = Math.max(5, cProps.maxStack);
428 mi.getCodeAttributeInfo().setProperties(cProps.maxStack, cProps.maxLocals);
429 }
430 }
431 }
432 }
433 }
434 }