001/*
002 * $Id: TaggedSocketChannel.java 4944 2014-10-05 18:35:23Z axelclk $
003 */
004
005package edu.jas.util;
006
007
008import java.io.IOException;
009import java.io.Serializable;
010import java.util.HashMap;
011import java.util.Map;
012import java.util.Map.Entry;
013import java.util.concurrent.BlockingQueue;
014import java.util.concurrent.LinkedBlockingQueue;
015import java.util.concurrent.atomic.AtomicInteger;
016
017import org.apache.log4j.Logger;
018
019
020/**
021 * TaggedSocketChannel provides a communication channel with message tags for
022 * Java objects using TCP/IP sockets.
023 * @author Heinz Kredel.
024 */
025public class TaggedSocketChannel extends Thread {
026
027
028    private static final Logger logger = Logger.getLogger(TaggedSocketChannel.class);
029
030
031    private static final boolean debug = logger.isDebugEnabled();
032
033
034    /**
035     * Flag if receiver is running.
036     */
037    private volatile boolean isRunning = false;
038
039
040    /**
041     * End message.
042     */
043    private final static String DONE = "TaggedSocketChannel Done";
044
045
046    /**
047     * Blocked threads count.
048     */
049    private final AtomicInteger blockedCount;
050
051
052    /**
053     * Underlying socket channel.
054     */
055    protected final SocketChannel sc;
056
057
058    /**
059     * Queues for each message tag.
060     */
061    protected final Map<Integer, BlockingQueue> queues;
062
063
064    /**
065     * Constructs a tagged socket channel on the given socket channel s.
066     * @param s A socket channel object.
067     */
068    public TaggedSocketChannel(SocketChannel s) {
069        sc = s;
070        blockedCount = new AtomicInteger(0);
071        queues = new HashMap<Integer, BlockingQueue>();
072    }
073
074
075    /**
076     * thread initialization and start.
077     */
078    public void init() {
079        synchronized (queues) {
080            if ( ! isRunning ) {
081                this.start();
082                isRunning = true;
083            }
084        }
085        logger.info("TaggedSocketChannel at " + sc);
086    }
087
088
089    /**
090     * Get the SocketChannel
091     */
092    public SocketChannel getSocket() {
093        return sc;
094    }
095
096
097    /**
098     * Sends an object.
099     * @param tag message tag
100     * @param v object to send
101     * @throws IOException
102     */
103    public void send(Integer tag, Object v) throws IOException {
104        if (tag == null) {
105            throw new IllegalArgumentException("tag null not allowed");
106        }
107        if (v instanceof Exception) {
108            throw new IllegalArgumentException("message " + v + " not allowed");
109        }
110        TaggedMessage tm = new TaggedMessage(tag, v);
111        sc.send(tm);
112    }
113
114
115    /**
116     * Receive an object.
117     * @param tag message tag
118     * @return object received
119     * @throws InterruptedException
120     * @throws IOException
121     * @throws ClassNotFoundException
122     */
123    public Object receive(Integer tag) throws InterruptedException, IOException, ClassNotFoundException {
124        BlockingQueue tq = null;
125        int i = 0;
126        do {
127            synchronized (queues) {
128                tq = queues.get(tag);
129                if (tq == null) {
130                    if ( ! isRunning ) { // avoid dead-lock
131                        throw new IOException("receiver not running for " + this);
132                    }
133                    //tq = new LinkedBlockingQueue();
134                    //queues.put(tag, tq);
135                    try {
136                        logger.debug("receive wait, tag = " + tag);
137                        i = blockedCount.incrementAndGet();
138                        queues.wait();
139                    } catch (InterruptedException e) {
140                        logger.info("receive wait exception, tag = " + tag + ", blockedCount = " + i);
141                        throw e;
142                    } finally {
143                        i = blockedCount.decrementAndGet();
144                    }
145                }
146            }
147        } while ( tq == null );
148        Object v = null;
149        try {
150            i = blockedCount.incrementAndGet();
151            v = tq.take();
152        } finally {
153            i = blockedCount.decrementAndGet();
154        }
155        if ( v instanceof IOException ) {
156            throw (IOException) v;
157        }
158        if ( v instanceof ClassNotFoundException ) {
159            throw (ClassNotFoundException) v;
160        }
161        if ( v instanceof Exception ) {
162            //throw new RuntimeException(v.toString());
163            throw new IOException(v.toString());
164        }
165        return v;
166    }
167
168
169    /**
170     * Closes the channel.
171     */
172    public void close() {
173        terminate();
174    }
175
176
177    /**
178     * To string.
179     * @see java.lang.Thread#toString()
180     */
181    @Override
182    public String toString() {
183        return "socketChannel(" + sc + ", tags = " + queues.keySet() + ")";
184        //return "socketChannel(" + sc + ", tags = " + queues.keySet() + ", values = " + queues.values() + ")";
185    }
186
187
188    /**
189     * Number of tags.
190     * @return size of key set.
191     */
192    public int tagSize() {
193        return queues.size();
194    }
195
196
197    /**
198     * Number of messages.
199     * @return sum of all messages in queues.
200     */
201    public int messages() {
202        int m = 0;
203        synchronized (queues) {
204            for ( BlockingQueue tq : queues.values() ) {
205                m += tq.size();
206            }
207        }
208        return m;
209    }
210
211
212    /**
213     * Run receive() in an infinite loop.
214     * @see java.lang.Thread#run()
215     */
216    @Override
217    @SuppressWarnings("unchecked")
218    public void run() {
219        if (sc == null) {
220            isRunning = false;
221            return; // nothing to do
222        }
223        isRunning = true;
224        while (isRunning) {
225            try {
226                Object r = null;
227                try {
228                    logger.debug("waiting for tagged object");
229                    r = sc.receive();
230                    if (this.isInterrupted()) {
231                        //r = new InterruptedException();
232                        isRunning = false;
233                    }
234                } catch (IOException e) {
235                    r = e;
236                } catch (ClassNotFoundException e) {
237                    r = e;
238                } catch (Exception e) {
239                    r = e;
240                }
241                //logger.debug("Socket = " +s);
242                logger.debug("object recieved");
243                if (r instanceof TaggedMessage) {
244                    TaggedMessage tm = (TaggedMessage) r;
245                    BlockingQueue tq = null;
246                    synchronized (queues) {
247                        tq = queues.get(tm.tag);
248                        if (tq == null) {
249                            tq = new LinkedBlockingQueue();
250                            queues.put(tm.tag, tq);
251                            queues.notifyAll();
252                        }
253                    }
254                    tq.put(tm.msg);
255                } else if ( r instanceof Exception ){
256                    if (debug) {
257                        logger.debug("exception " + r);
258                    }
259                    synchronized (queues) { // deliver to all queues
260                        isRunning = false;
261                        for ( BlockingQueue q : queues.values() ) {
262                            final int bc = blockedCount.get();
263                            for ( int i = 0; i <= bc; i++ ) { // one more
264                                q.put(r);
265                            }
266                            if (bc > 0) {
267                                logger.debug("put exception to queue, blockedCount = " + bc);
268                            }
269                        }
270                        queues.notifyAll();
271                    }
272                    //return;
273                } else {
274                    if (debug) {
275                        logger.debug("no tagged message and no exception " + r);
276                    }
277                    synchronized (queues) { // deliver to all queues
278                        isRunning = false;
279                        Exception e;
280                        if ( r.equals(DONE) ) {
281                            e = new Exception("DONE message");
282                        } else {
283                            e = new IllegalArgumentException("no tagged message and no exception '" + r + "'");
284                        }
285                        for ( BlockingQueue q : queues.values() ) {
286                            final int bc = blockedCount.get();
287                            for ( int i = 0; i <= bc; i++ ) { // one more
288                                q.put(e);
289                            }
290                            if (bc > 0) {
291                                logger.debug("put '" + e.toString() + "' to queue, blockedCount = " + bc);
292                            }
293                        }
294                        queues.notifyAll();
295                    }
296                    if ( r.equals(DONE) ) {
297                         logger.info("run terminating by request");
298                         try {
299                             sc.send(DONE); // terminate other end
300                         } catch (IOException e) {
301                             logger.warn("send other done failed " + e);
302                         }
303                         return;
304                    }
305                }
306            } catch (InterruptedException e) {
307                // unfug Thread.currentThread().interrupt();
308                //logger.debug("ChannelFactory IE terminating");
309                if (debug) {
310                    logger.debug("exception " + e);
311                }
312                synchronized (queues) { // deliver to all queues
313                    isRunning = false;
314                    for ( BlockingQueue q : queues.values() ) {
315                        try {
316                            final int bc = blockedCount.get();
317                            for ( int i = 0; i <= bc; i++ ) { // one more
318                                q.put(e);
319                            }
320                            if (bc > 0) {
321                                logger.debug("put interrupted to queue, blockCount = " + bc);
322                            }
323                        } catch (InterruptedException ignored) {
324                        }
325                    }
326                    queues.notifyAll();
327                }
328                //return via isRunning
329            }
330        }
331        if (this.isInterrupted()) {
332            Exception e = new InterruptedException("terminating via interrupt");
333            synchronized (queues) { // deliver to all queues
334                for ( BlockingQueue q : queues.values() ) {
335                    try {
336                        final int bc = blockedCount.get();
337                        for ( int i = 0; i <= bc; i++ ) { // one more
338                            q.put(e);
339                        }
340                        if (bc > 0) {
341                            logger.debug("put terminating via interrupt to queue, blockCount = " + bc);
342                        }
343                    } catch (InterruptedException ignored) {
344                    }
345                }
346                queues.notifyAll();
347            }
348        }
349        logger.info("run terminated");
350    }
351
352
353    /**
354     * Terminate the TaggedSocketChannel.
355     */
356    @SuppressWarnings("unchecked")
357    public void terminate() {
358        isRunning = false;
359        this.interrupt();
360        if (sc != null) {
361            //sc.close();
362            try {
363                sc.send(DONE);
364            } catch (IOException e) {
365                logger.warn("send done failed " + e);
366            }
367            logger.debug(sc + " not yet closed");
368        }
369        this.interrupt();
370        synchronized(queues) {
371            isRunning = false;
372            for (Entry<Integer, BlockingQueue> tq : queues.entrySet()) {
373                BlockingQueue q = tq.getValue();
374                if (q.size() != 0) {
375                    logger.info("queue for tag " + tq.getKey() + " not empty " + q);
376                } 
377                int bc = 0;
378                try {
379                    bc = blockedCount.get();
380                    for ( int i = 0; i <= bc; i++ ) { // one more
381                        q.put(new IOException("queue terminate"));
382                    }
383                } catch (InterruptedException ignored) {
384                }
385                if ( bc > 0 ) {
386                    logger.debug("put IO-end to queue for tag " + tq.getKey() + ", blockCount = " + bc);
387                }
388            }
389            queues.notifyAll();
390        }
391        try {
392            this.join();
393        } catch (InterruptedException e) {
394            // unfug Thread.currentThread().interrupt();
395        }
396        logger.info("terminated");
397    }
398
399}
400
401
402/**
403 * TaggedMessage container.
404 * @author kredel
405 * 
406 */
407class TaggedMessage implements Serializable {
408
409
410    public final Integer tag;
411
412
413    public final Object msg;
414
415
416    /**
417     * Constructor.
418     * @param tag message tag
419     * @param msg message object
420     */
421    public TaggedMessage(Integer tag, Object msg) {
422        this.tag = tag;
423        this.msg = msg;
424    }
425
426}