001/*
002 * $Id: TaggedSocketChannel.java 3982 2012-07-12 21:00:59Z kredel $
003 */
004
005package edu.jas.util;
006
007
008import java.io.IOException;
009import java.io.Serializable;
010import java.util.HashMap;
011import java.util.Map;
012import java.util.Map.Entry;
013import java.util.concurrent.BlockingQueue;
014import java.util.concurrent.LinkedBlockingQueue;
015import java.util.concurrent.atomic.AtomicInteger;
016
017import org.apache.log4j.Logger;
018
019
020/**
021 * TaggedSocketChannel provides a communication channel with message tags for
022 * Java objects using TCP/IP sockets.
023 * @author Heinz Kredel.
024 */
025public class TaggedSocketChannel extends Thread {
026
027
028    private static final Logger logger = Logger.getLogger(TaggedSocketChannel.class);
029
030
031    private static final boolean debug = logger.isDebugEnabled();
032
033
034    /**
035     * Flag if receiver is running.
036     */
037    private volatile boolean isRunning = false;
038
039
040    /**
041     * End message.
042     */
043    private final static String DONE = "TaggedSocketChannel Done";
044
045
046    /**
047     * Blocked threads count.
048     */
049    private final AtomicInteger blockedCount;
050
051
052    /**
053     * Underlying socket channel.
054     */
055    protected final SocketChannel sc;
056
057
058    /**
059     * Queues for each message tag.
060     */
061    protected final Map<Integer, BlockingQueue> queues;
062
063
064    /**
065     * Constructs a tagged socket channel on the given socket channel s.
066     * @param s A socket channel object.
067     */
068    public TaggedSocketChannel(SocketChannel s) {
069        sc = s;
070        blockedCount = new AtomicInteger(0);
071        queues = new HashMap<Integer, BlockingQueue>();
072    }
073
074
075    /**
076     * thread initialization and start.
077     */
078    public void init() {
079        synchronized (queues) {
080            if ( ! isRunning ) {
081                this.start();
082                isRunning = true;
083            }
084        }
085        logger.info("TaggedSocketChannel at " + sc);
086    }
087
088
089    /**
090     * Get the SocketChannel
091     */
092    public SocketChannel getSocket() {
093        return sc;
094    }
095
096
097    /**
098     * Sends an object.
099     * @param tag message tag
100     * @param v object to send
101     * @throws IOException
102     */
103    public void send(Integer tag, Object v) throws IOException {
104        if (tag == null) {
105            throw new IllegalArgumentException("tag null not allowed");
106        }
107        if (v instanceof Exception) {
108            throw new IllegalArgumentException("message " + v + " not allowed");
109        }
110        TaggedMessage tm = new TaggedMessage(tag, v);
111        sc.send(tm);
112    }
113
114
115    /**
116     * Receive an object.
117     * @param tag message tag
118     * @return object received
119     * @throws InterruptedException
120     * @throws IOException
121     * @throws ClassNotFoundException
122     */
123    public Object receive(Integer tag) throws InterruptedException, IOException, ClassNotFoundException {
124        BlockingQueue tq = null;
125        int i = 0;
126        do {
127            synchronized (queues) {
128                tq = queues.get(tag);
129                if (tq == null) {
130                    if ( ! isRunning ) { // avoid dead-lock
131                        throw new IOException("receiver not running for " + this);
132                    }
133                    //tq = new LinkedBlockingQueue();
134                    //queues.put(tag, tq);
135                    try {
136                        logger.debug("receive wait, tag = " + tag);
137                        i = blockedCount.incrementAndGet();
138                        queues.wait();
139                    } catch (InterruptedException e) {
140                        logger.info("receive wait exception, tag = " + tag + ", blockedCount = " + i);
141                        throw e;
142                    } finally {
143                        i = blockedCount.decrementAndGet();
144                    }
145                }
146            }
147        } while ( tq == null );
148        Object v = null;
149        try {
150            i = blockedCount.incrementAndGet();
151            v = tq.take();
152        } finally {
153            i = blockedCount.decrementAndGet();
154        }
155        if ( v instanceof IOException ) {
156            throw (IOException) v;
157        }
158        if ( v instanceof ClassNotFoundException ) {
159            throw (ClassNotFoundException) v;
160        }
161        if ( v instanceof Exception ) {
162            throw new RuntimeException(v.toString());
163        }
164        return v;
165    }
166
167
168    /**
169     * Closes the channel.
170     */
171    public void close() {
172        terminate();
173    }
174
175
176    /**
177     * To string.
178     * @see java.lang.Thread#toString()
179     */
180    @Override
181    public String toString() {
182        return "socketChannel(" + sc + ", tags = " + queues.keySet() + ")";
183        //return "socketChannel(" + sc + ", tags = " + queues.keySet() + ", values = " + queues.values() + ")";
184    }
185
186
187    /**
188     * Number of tags.
189     * @return size of key set.
190     */
191    public int tagSize() {
192        return queues.keySet().size();
193    }
194
195
196    /**
197     * Number of messages.
198     * @return sum of all messages in queues.
199     */
200    public int messages() {
201        int m = 0;
202        synchronized (queues) {
203            for ( BlockingQueue tq : queues.values() ) {
204                m += tq.size();
205            }
206        }
207        return m;
208    }
209
210
211    /**
212     * Run receive() in an infinite loop.
213     * @see java.lang.Thread#run()
214     */
215    @Override
216    public void run() {
217        if (sc == null) {
218            isRunning = false;
219            return; // nothing to do
220        }
221        isRunning = true;
222        while (isRunning) {
223            try {
224                Object r = null;
225                try {
226                    logger.debug("waiting for tagged object");
227                    r = sc.receive();
228                    if (this.isInterrupted()) {
229                        //r = new InterruptedException();
230                        isRunning = false;
231                    }
232                } catch (IOException e) {
233                    r = e;
234                } catch (ClassNotFoundException e) {
235                    r = e;
236                } catch (Exception e) {
237                    r = e;
238                }
239                //logger.debug("Socket = " +s);
240                logger.debug("object recieved");
241                if (r instanceof TaggedMessage) {
242                    TaggedMessage tm = (TaggedMessage) r;
243                    BlockingQueue tq = null;
244                    synchronized (queues) {
245                        tq = queues.get(tm.tag);
246                        if (tq == null) {
247                            tq = new LinkedBlockingQueue();
248                            queues.put(tm.tag, tq);
249                            queues.notifyAll();
250                        }
251                    }
252                    tq.put(tm.msg);
253                } else if ( r instanceof Exception ){
254                    if (debug) {
255                        logger.debug("exception " + r);
256                    }
257                    synchronized (queues) { // deliver to all queues
258                        isRunning = false;
259                        for ( BlockingQueue q : queues.values() ) {
260                            final int bc = blockedCount.get();
261                            for ( int i = 0; i <= bc; i++ ) { // one more
262                                q.put(r);
263                            }
264                            if (bc > 0) {
265                                logger.debug("put exception to queue, blockedCount = " + bc);
266                            }
267                        }
268                        queues.notifyAll();
269                    }
270                    //return;
271                } else {
272                    if (debug) {
273                        logger.debug("no tagged message and no exception " + r);
274                    }
275                    synchronized (queues) { // deliver to all queues
276                        isRunning = false;
277                        Exception e;
278                        if ( r.equals(DONE) ) {
279                            e = new Exception("DONE message");
280                        } else {
281                            e = new IllegalArgumentException("no tagged message and no exception '" + r + "'");
282                        }
283                        for ( BlockingQueue q : queues.values() ) {
284                            final int bc = blockedCount.get();
285                            for ( int i = 0; i <= bc; i++ ) { // one more
286                                q.put(e);
287                            }
288                            if (bc > 0) {
289                                logger.debug("put '" + e.toString() + "' to queue, blockedCount = " + bc);
290                            }
291                        }
292                        queues.notifyAll();
293                    }
294                    if ( r.equals(DONE) ) {
295                         logger.info("run terminating by request");
296                         try {
297                             sc.send(DONE); // terminate other end
298                         } catch (IOException e) {
299                             logger.warn("send other done failed " + e);
300                         }
301                         return;
302                    }
303                }
304            } catch (InterruptedException e) {
305                // unfug Thread.currentThread().interrupt();
306                //logger.debug("ChannelFactory IE terminating");
307                if (debug) {
308                    logger.debug("exception " + e);
309                }
310                synchronized (queues) { // deliver to all queues
311                    isRunning = false;
312                    for ( BlockingQueue q : queues.values() ) {
313                        try {
314                            final int bc = blockedCount.get();
315                            for ( int i = 0; i <= bc; i++ ) { // one more
316                                q.put(e);
317                            }
318                            if (bc > 0) {
319                                logger.debug("put interrupted to queue, blockCount = " + bc);
320                            }
321                        } catch (InterruptedException ignored) {
322                        }
323                    }
324                    queues.notifyAll();
325                }
326                //return via isRunning
327            }
328        }
329        if (this.isInterrupted()) {
330            Exception e = new InterruptedException("terminating via interrupt");
331            synchronized (queues) { // deliver to all queues
332                for ( BlockingQueue q : queues.values() ) {
333                    try {
334                        final int bc = blockedCount.get();
335                        for ( int i = 0; i <= bc; i++ ) { // one more
336                            q.put(e);
337                        }
338                        if (bc > 0) {
339                            logger.debug("put terminating via interrupt to queue, blockCount = " + bc);
340                        }
341                    } catch (InterruptedException ignored) {
342                    }
343                }
344                queues.notifyAll();
345            }
346        }
347        logger.info("run terminated");
348    }
349
350
351    /**
352     * Terminate the TaggedSocketChannel.
353     */
354    public void terminate() {
355        isRunning = false;
356        this.interrupt();
357        if (sc != null) {
358            //sc.close();
359            try {
360                sc.send(DONE);
361            } catch (IOException e) {
362                logger.warn("send done failed " + e);
363            }
364            logger.debug(sc + " not yet closed");
365        }
366        this.interrupt();
367        synchronized(queues) {
368            isRunning = false;
369            for (Entry<Integer, BlockingQueue> tq : queues.entrySet()) {
370                BlockingQueue q = tq.getValue();
371                if (q.size() != 0) {
372                    logger.info("queue for tag " + tq.getKey() + " not empty " + q);
373                } 
374                int bc = 0;
375                try {
376                    bc = blockedCount.get();
377                    for ( int i = 0; i <= bc; i++ ) { // one more
378                        q.put(new IOException("queue terminate"));
379                    }
380                } catch (InterruptedException ignored) {
381                }
382                if ( bc > 0 ) {
383                    logger.debug("put IO-end to queue for tag " + tq.getKey() + ", blockCount = " + bc);
384                }
385            }
386            queues.notifyAll();
387        }
388        try {
389            this.join();
390        } catch (InterruptedException e) {
391            // unfug Thread.currentThread().interrupt();
392        }
393        logger.info("terminated");
394    }
395
396}
397
398
/**
 * Serializable pair of a message tag and a payload object. This is the unit
 * that TaggedSocketChannel sends over, and reads from, its socket channel.
 * <p>
 * NOTE(review): no explicit serialVersionUID is declared. Adding one (or any
 * non-private member) changes the default UID and may break compatibility
 * with peers running the current class, so it is deliberately left as is.
 * @author kredel
 */
class TaggedMessage implements Serializable {


    /** Tag selecting the receive queue the payload is delivered to. */
    public final Integer tag;


    /** The payload carried under {@link #tag}. */
    public final Object msg;


    /**
     * Creates a tagged message; both values are stored as given (null allowed).
     * @param messageTag message tag
     * @param payload message object
     */
    public TaggedMessage(Integer messageTag, Object payload) {
        msg = payload;
        tag = messageTag;
    }

}