001    /*
002     * $Id: TaggedSocketChannel.java 3297 2010-08-26 19:09:03Z kredel $
003     */
004    
005    package edu.jas.util;
006    
007    
008    import java.io.IOException;
009    import java.io.Serializable;
010    import java.util.HashMap;
011    import java.util.Map;
012    import java.util.Map.Entry;
013    import java.util.concurrent.BlockingQueue;
014    import java.util.concurrent.LinkedBlockingQueue;
015    import java.util.concurrent.atomic.AtomicInteger;
016    
017    import org.apache.log4j.Logger;
018    
019    
020    /**
021     * TaggedSocketChannel provides a communication channel with message tags for
022     * Java objects using TCP/IP sockets.
023     * @author Heinz Kredel.
024     */
025    public class TaggedSocketChannel extends Thread {
026    
027    
028        private static final Logger logger = Logger.getLogger(TaggedSocketChannel.class);
029    
030    
031        private static final boolean debug = logger.isDebugEnabled();
032    
033    
034        /**
035         * Flag if receiver is running.
036         */
037        private volatile boolean isRunning = false;
038    
039    
040        /**
041         * End message.
042         */
043        private final static String DONE = "TaggedSocketChannel Done";
044    
045    
046        /**
047         * Blocked threads count.
048         */
049        private final AtomicInteger blockedCount;
050    
051    
052        /**
053         * Underlying socket channel.
054         */
055        protected final SocketChannel sc;
056    
057    
058        /**
059         * Queues for each message tag.
060         */
061        protected final Map<Integer, BlockingQueue> queues;
062    
063    
064        /**
065         * Constructs a tagged socket channel on the given socket channel s.
066         * @param s A socket channel object.
067         */
068        public TaggedSocketChannel(SocketChannel s) {
069            sc = s;
070            blockedCount = new AtomicInteger(0);
071            queues = new HashMap<Integer, BlockingQueue>();
072        }
073    
074    
075        /**
076         * thread initialization and start.
077         */
078        public void init() {
079            synchronized (queues) {
080                if ( ! isRunning ) {
081                    this.start();
082                    isRunning = true;
083                }
084            }
085            logger.info("TaggedSocketChannel at " + sc);
086        }
087    
088    
089        /**
090         * Get the SocketChannel
091         */
092        public SocketChannel getSocket() {
093            return sc;
094        }
095    
096    
097        /**
098         * Sends an object.
099         * @param tag message tag
100         * @param v object to send
101         * @throws IOException
102         */
103        public void send(Integer tag, Object v) throws IOException {
104            if (tag == null) {
105                throw new IllegalArgumentException("tag " + tag + " not allowed");
106            }
107            if (v instanceof Exception) {
108                throw new IllegalArgumentException("message " + v + " not allowed");
109            }
110            TaggedMessage tm = new TaggedMessage(tag, v);
111            sc.send(tm);
112        }
113    
114    
115        /**
116         * Receive an object.
117         * @param tag message tag
118         * @return object received
119         * @throws InterruptedException
120         * @throws IOException
121         * @throws ClassNotFoundException
122         */
123        public Object receive(Integer tag) throws InterruptedException, IOException, ClassNotFoundException {
124            BlockingQueue tq = null;
125            int i = 0;
126            do {
127                synchronized (queues) {
128                    tq = queues.get(tag);
129                    if (tq == null) {
130                        if ( ! isRunning ) { // avoid dead-lock
131                            throw new IOException("receiver not running for " + this);
132                        }
133                        //tq = new LinkedBlockingQueue();
134                        //queues.put(tag, tq);
135                        try {
136                            logger.debug("receive wait, tag = " + tag);
137                            i = blockedCount.incrementAndGet();
138                            queues.wait();
139                        } catch (InterruptedException e) {
140                            logger.info("receive wait exception, tag = " + tag + ", blockedCount = " + i);
141                            throw e;
142                        } finally {
143                            i = blockedCount.decrementAndGet();
144                        }
145                    }
146                }
147            } while ( tq == null );
148            Object v = null;
149            try {
150                i = blockedCount.incrementAndGet();
151                v = tq.take();
152            } finally {
153                i = blockedCount.decrementAndGet();
154            }
155            if ( v instanceof IOException ) {
156                throw (IOException) v;
157            }
158            if ( v instanceof ClassNotFoundException ) {
159                throw (ClassNotFoundException) v;
160            }
161            if ( v instanceof Exception ) {
162                throw new RuntimeException(v.toString());
163            }
164            return v;
165        }
166    
167    
168        /**
169         * Closes the channel.
170         */
171        public void close() {
172            terminate();
173        }
174    
175    
176        /**
177         * To string.
178         * @see java.lang.Thread#toString()
179         */
180        @Override
181        public String toString() {
182            return "socketChannel(" + sc + ", tags = " + queues.keySet() + ")";
183            //return "socketChannel(" + sc + ", tags = " + queues.keySet() + ", values = " + queues.values() + ")";
184        }
185    
186    
187        /**
188         * Number of tags.
189         * @return size of key set.
190         */
191        public int tagSize() {
192            return queues.keySet().size();
193        }
194    
195    
196        /**
197         * Number of messages.
198         * @return sum of all messages in queues.
199         */
200        public int messages() {
201            int m = 0;
202            synchronized (queues) {
203                for ( BlockingQueue tq : queues.values() ) {
204                    m += tq.size();
205                }
206            }
207            return m;
208        }
209    
210    
211        /**
212         * Run receive() in an infinite loop.
213         * @see java.lang.Thread#run()
214         */
215        @Override
216        public void run() {
217            if (sc == null) {
218                isRunning = false;
219                return; // nothing to do
220            }
221            isRunning = true;
222            while (isRunning) {
223                try {
224                    Object r = null;
225                    try {
226                        logger.debug("waiting for tagged object");
227                        r = sc.receive();
228                        if (this.isInterrupted()) {
229                            //r = new InterruptedException();
230                            isRunning = false;
231                        }
232                    } catch (IOException e) {
233                        r = e;
234                    } catch (ClassNotFoundException e) {
235                        r = e;
236                    } catch (Exception e) {
237                        r = e;
238                    }
239                    //logger.debug("Socket = " +s);
240                    logger.debug("object recieved");
241                    if (r instanceof TaggedMessage) {
242                        TaggedMessage tm = (TaggedMessage) r;
243                        BlockingQueue tq = null;
244                        synchronized (queues) {
245                            tq = queues.get(tm.tag);
246                            if (tq == null) {
247                                tq = new LinkedBlockingQueue();
248                                queues.put(tm.tag, tq);
249                                queues.notifyAll();
250                            }
251                        }
252                        tq.put(tm.msg);
253                    } else if ( r instanceof Exception ){
254                        if (debug) {
255                            logger.debug("exception " + r);
256                        }
257                        synchronized (queues) { // deliver to all queues
258                            isRunning = false;
259                            for ( BlockingQueue q : queues.values() ) {
260                                final int bc = blockedCount.get();
261                                for ( int i = 0; i <= bc; i++ ) { // one more
262                                    q.put(r);
263                                }
264                                if (bc > 0) {
265                                    logger.debug("put exception to queue, blockedCount = " + bc);
266                                }
267                            }
268                            queues.notifyAll();
269                        }
270                        //return;
271                    } else {
272                        if (debug) {
273                            logger.debug("no tagged message and no exception " + r);
274                        }
275                        synchronized (queues) { // deliver to all queues
276                            isRunning = false;
277                            Exception e;
278                            if ( r.equals(DONE) ) {
279                                e = new Exception("DONE message");
280                            } else {
281                                e = new IllegalArgumentException("no tagged message and no exception '" + r + "'");
282                            }
283                            for ( BlockingQueue q : queues.values() ) {
284                                final int bc = blockedCount.get();
285                                for ( int i = 0; i <= bc; i++ ) { // one more
286                                    q.put(e);
287                                }
288                                if (bc > 0) {
289                                    logger.debug("put '" + e.toString() + "' to queue, blockedCount = " + bc);
290                                }
291                            }
292                            queues.notifyAll();
293                        }
294                        if ( r.equals(DONE) ) {
295                             logger.info("run terminating by request");
296                             try {
297                                 sc.send(DONE); // terminate other end
298                             } catch (IOException e) {
299                                 logger.warn("send other done failed " + e);
300                             }
301                             return;
302                        }
303                    }
304                } catch (InterruptedException e) {
305                    // unfug Thread.currentThread().interrupt();
306                    //logger.debug("ChannelFactory IE terminating");
307                    if (debug) {
308                        logger.debug("exception " + e);
309                    }
310                    synchronized (queues) { // deliver to all queues
311                        isRunning = false;
312                        for ( BlockingQueue q : queues.values() ) {
313                            try {
314                                final int bc = blockedCount.get();
315                                for ( int i = 0; i <= bc; i++ ) { // one more
316                                    q.put(e);
317                                }
318                                if (bc > 0) {
319                                    logger.debug("put interrupted to queue, blockCount = " + bc);
320                                }
321                            } catch (InterruptedException ignored) {
322                            }
323                        }
324                        queues.notifyAll();
325                    }
326                    //return via isRunning
327                }
328            }
329            if (this.isInterrupted()) {
330                Exception e = new InterruptedException("terminating via interrupt");
331                synchronized (queues) { // deliver to all queues
332                    for ( BlockingQueue q : queues.values() ) {
333                        try {
334                            final int bc = blockedCount.get();
335                            for ( int i = 0; i <= bc; i++ ) { // one more
336                                q.put(e);
337                            }
338                            if (bc > 0) {
339                                logger.debug("put terminating via interrupt to queue, blockCount = " + bc);
340                            }
341                        } catch (InterruptedException ignored) {
342                        }
343                    }
344                    queues.notifyAll();
345                }
346            }
347            logger.info("run terminated");
348        }
349    
350    
351        /**
352         * Terminate the TaggedSocketChannel.
353         */
354        public void terminate() {
355            isRunning = false;
356            this.interrupt();
357            if (sc != null) {
358                //sc.close();
359                try {
360                    sc.send(DONE);
361                } catch (IOException e) {
362                    logger.warn("send done failed " + e);
363                }
364                logger.debug(sc + " not yet closed");
365            }
366            this.interrupt();
367            synchronized(queues) {
368                isRunning = false;
369                for (Entry<Integer, BlockingQueue> tq : queues.entrySet()) {
370                    BlockingQueue q = tq.getValue();
371                    if (q.size() != 0) {
372                        logger.info("queue for tag " + tq.getKey() + " not empty " + q);
373                    } 
374                    int bc = 0;
375                    try {
376                        bc = blockedCount.get();
377                        for ( int i = 0; i <= bc; i++ ) { // one more
378                            q.put(new IOException("queue terminate"));
379                        }
380                    } catch (InterruptedException ignored) {
381                    }
382                    if ( bc > 0 ) {
383                        logger.debug("put IO-end to queue for tag " + tq.getKey() + ", blockCount = " + bc);
384                    }
385                }
386                queues.notifyAll();
387            }
388            try {
389                this.join();
390            } catch (InterruptedException e) {
391                // unfug Thread.currentThread().interrupt();
392            }
393            logger.info("terminated");
394        }
395    
396    }
397    
398    
/**
 * Serializable envelope pairing a message object with its tag.
 * TaggedSocketChannel wraps every object it sends in one of these and reads
 * both fields again on the receiving side.
 * <p>
 * No serialVersionUID is declared.
 * NOTE(review): the default UID is derived from the class shape, so adding
 * members here may break exchange with peers running an older build; confirm
 * before extending.
 * @author kredel
 */
class TaggedMessage implements Serializable {


    /**
     * Tag used to select the queue the message is delivered to.
     */
    public final Integer tag;


    /**
     * The transported message object.
     */
    public final Object msg;


    /**
     * Stores the given tag and message; no validation is done here.
     * @param tag message tag
     * @param msg message object
     */
    public TaggedMessage(Integer tag, Object msg) {
        this.msg = msg;
        this.tag = tag;
    }

}