001/*
002 * $Id$
003 */
004
005package edu.jas.util;
006
007
008import java.io.IOException;
009import java.io.Serializable;
010import java.util.HashMap;
011import java.util.Map;
012import java.util.Map.Entry;
013import java.util.concurrent.BlockingQueue;
014import java.util.concurrent.LinkedBlockingQueue;
015import java.util.concurrent.atomic.AtomicInteger;
016
017import org.apache.logging.log4j.Logger;
018import org.apache.logging.log4j.LogManager; 
019
020
021/**
022 * TaggedSocketChannel provides a communication channel with message tags for
023 * Java objects using TCP/IP sockets.
024 * @author Heinz Kredel
025 */
026public class TaggedSocketChannel extends Thread {
027
028
029    private static final Logger logger = LogManager.getLogger(TaggedSocketChannel.class);
030
031
032    private static final boolean debug = logger.isDebugEnabled();
033
034
035    /**
036     * Flag if receiver is running.
037     */
038    private volatile boolean isRunning = false;
039
040
041    /**
042     * End message.
043     */
044    private final static String DONE = "TaggedSocketChannel Done";
045
046
047    /**
048     * Blocked threads count.
049     */
050    private final AtomicInteger blockedCount;
051
052
053    /**
054     * Underlying socket channel.
055     */
056    protected final SocketChannel sc;
057
058
059    /**
060     * Queues for each message tag.
061     */
062    protected final Map<Integer, BlockingQueue> queues;
063
064
065    /**
066     * Constructs a tagged socket channel on the given socket channel s.
067     * @param s A socket channel object.
068     */
069    public TaggedSocketChannel(SocketChannel s) {
070        sc = s;
071        blockedCount = new AtomicInteger(0);
072        queues = new HashMap<Integer, BlockingQueue>();
073    }
074
075
076    /**
077     * thread initialization and start.
078     */
079    public void init() {
080        synchronized (queues) {
081            if ( ! isRunning ) {
082                this.start();
083                isRunning = true;
084            }
085        }
086        logger.info("TaggedSocketChannel at {}", sc);
087    }
088
089
090    /**
091     * Get the SocketChannel
092     */
093    public SocketChannel getSocket() {
094        return sc;
095    }
096
097
098    /**
099     * Sends an object.
100     * @param tag message tag
101     * @param v object to send
102     * @throws IOException
103     */
104    public void send(Integer tag, Object v) throws IOException {
105        if (tag == null) {
106            throw new IllegalArgumentException("tag null not allowed");
107        }
108        if (v instanceof Exception) {
109            throw new IllegalArgumentException("message " + v + " not allowed");
110        }
111        TaggedMessage tm = new TaggedMessage(tag, v);
112        sc.send(tm);
113    }
114
115
116    /**
117     * Receive an object.
118     * @param tag message tag
119     * @return object received
120     * @throws InterruptedException
121     * @throws IOException
122     * @throws ClassNotFoundException
123     */
124    public Object receive(Integer tag) throws InterruptedException, IOException, ClassNotFoundException {
125        BlockingQueue tq = null;
126        int i = 0;
127        do {
128            synchronized (queues) {
129                tq = queues.get(tag);
130                if (tq == null) {
131                    if ( ! isRunning ) { // avoid dead-lock
132                        throw new IOException("receiver not running for " + this);
133                    }
134                    //tq = new LinkedBlockingQueue();
135                    //queues.put(tag, tq);
136                    try {
137                        logger.debug("receive wait, tag = {}", tag);
138                        i = blockedCount.incrementAndGet();
139                        queues.wait();
140                    } catch (InterruptedException e) {
141                        logger.info("receive wait exception, tag = {}, blockedCount = {}", tag, i);
142                        throw e;
143                    } finally {
144                        i = blockedCount.decrementAndGet();
145                    }
146                }
147            }
148        } while ( tq == null );
149        Object v = null;
150        try {
151            i = blockedCount.incrementAndGet();
152            v = tq.take();
153        } finally {
154            i = blockedCount.decrementAndGet();
155        }
156        if ( v instanceof IOException ) {
157            throw (IOException) v;
158        }
159        if ( v instanceof ClassNotFoundException ) {
160            throw (ClassNotFoundException) v;
161        }
162        if ( v instanceof Exception ) {
163            //throw new RuntimeException(v.toString());
164            throw new IOException(v.toString());
165        }
166        return v;
167    }
168
169
170    /**
171     * Closes the channel.
172     */
173    public void close() {
174        terminate();
175    }
176
177
178    /**
179     * To string.
180     * @see java.lang.Thread#toString()
181     */
182    @Override
183    public String toString() {
184        return "socketChannel(" + sc + ", tags = " + queues.keySet() + ")";
185        //return "socketChannel(" + sc + ", tags = " + queues.keySet() + ", values = " + queues.values() + ")";
186    }
187
188
189    /**
190     * Number of tags.
191     * @return size of key set.
192     */
193    public int tagSize() {
194        return queues.size();
195    }
196
197
198    /**
199     * Number of messages.
200     * @return sum of all messages in queues.
201     */
202    public int messages() {
203        int m = 0;
204        synchronized (queues) {
205            for ( BlockingQueue tq : queues.values() ) {
206                m += tq.size();
207            }
208        }
209        return m;
210    }
211
212
213    /**
214     * Run receive() in an infinite loop.
215     * @see java.lang.Thread#run()
216     */
217    @Override
218    @SuppressWarnings("unchecked")
219    public void run() {
220        if (sc == null) {
221            isRunning = false;
222            return; // nothing to do
223        }
224        isRunning = true;
225        while (isRunning) {
226            try {
227                Object r = null;
228                try {
229                    logger.debug("waiting for tagged object");
230                    r = sc.receive();
231                    if (this.isInterrupted()) {
232                        //r = new InterruptedException();
233                        isRunning = false;
234                    }
235                } catch (IOException e) {
236                    r = e;
237                } catch (ClassNotFoundException e) {
238                    r = e;
239                } catch (Exception e) {
240                    r = e;
241                }
242                //logger.debug("Socket = {}", s);
243                logger.debug("object received");
244                if (r instanceof TaggedMessage) {
245                    TaggedMessage tm = (TaggedMessage) r;
246                    BlockingQueue tq = null;
247                    synchronized (queues) {
248                        tq = queues.get(tm.tag);
249                        if (tq == null) {
250                            tq = new LinkedBlockingQueue();
251                            queues.put(tm.tag, tq);
252                            queues.notifyAll();
253                        }
254                    }
255                    tq.put(tm.msg);
256                } else if ( r instanceof Exception ){
257                    if (debug) {
258                        logger.debug("exception {}", r);
259                    }
260                    synchronized (queues) { // deliver to all queues
261                        isRunning = false;
262                        for ( BlockingQueue q : queues.values() ) {
263                            final int bc = blockedCount.get();
264                            for ( int i = 0; i <= bc; i++ ) { // one more
265                                q.put(r);
266                            }
267                            if (bc > 0) {
268                                logger.debug("put exception to queue, blockedCount = {}", bc);
269                            }
270                        }
271                        queues.notifyAll();
272                    }
273                    //return;
274                } else {
275                    if (debug) {
276                        logger.debug("no tagged message and no exception {}", r);
277                    }
278                    synchronized (queues) { // deliver to all queues
279                        isRunning = false;
280                        Exception e;
281                        if ( r.equals(DONE) ) {
282                            e = new Exception("DONE message");
283                        } else {
284                            e = new IllegalArgumentException("no tagged message and no exception '" + r + "'");
285                        }
286                        for ( BlockingQueue q : queues.values() ) {
287                            final int bc = blockedCount.get();
288                            for ( int i = 0; i <= bc; i++ ) { // one more
289                                q.put(e);
290                            }
291                            if (bc > 0) {
292                                logger.debug("put '{}' to queue, blockedCount = {}", e, bc);
293                            }
294                        }
295                        queues.notifyAll();
296                    }
297                    if ( r.equals(DONE) ) {
298                         logger.info("run terminating by request");
299                         try {
300                             sc.send(DONE); // terminate other end
301                         } catch (IOException e) {
302                             logger.warn("send other done failed {}", e);
303                         }
304                         return;
305                    }
306                }
307            } catch (InterruptedException e) {
308                // unfug Thread.currentThread().interrupt();
309                //logger.debug("ChannelFactory IE terminating");
310                if (debug) {
311                    logger.debug("exception {}", e);
312                }
313                synchronized (queues) { // deliver to all queues
314                    isRunning = false;
315                    for ( BlockingQueue q : queues.values() ) {
316                        try {
317                            final int bc = blockedCount.get();
318                            for ( int i = 0; i <= bc; i++ ) { // one more
319                                q.put(e);
320                            }
321                            if (bc > 0) {
322                                logger.debug("put interrupted to queue, blockCount = {}", bc);
323                            }
324                        } catch (InterruptedException ignored) {
325                        }
326                    }
327                    queues.notifyAll();
328                }
329                //return via isRunning
330            }
331        }
332        if (this.isInterrupted()) {
333            Exception e = new InterruptedException("terminating via interrupt");
334            synchronized (queues) { // deliver to all queues
335                for ( BlockingQueue q : queues.values() ) {
336                    try {
337                        final int bc = blockedCount.get();
338                        for ( int i = 0; i <= bc; i++ ) { // one more
339                            q.put(e);
340                        }
341                        if (bc > 0) {
342                            logger.debug("put terminating via interrupt to queue, blockCount = {}", bc);
343                        }
344                    } catch (InterruptedException ignored) {
345                    }
346                }
347                queues.notifyAll();
348            }
349        }
350        logger.info("run terminated");
351    }
352
353
354    /**
355     * Terminate the TaggedSocketChannel.
356     */
357    @SuppressWarnings("unchecked")
358    public void terminate() {
359        isRunning = false;
360        this.interrupt();
361        if (sc != null) {
362            //sc.close();
363            try {
364                sc.send(DONE);
365            } catch (IOException e) {
366                logger.warn("send done failed {}", e);
367            }
368            logger.debug("{} not yet closed", sc);
369        }
370        this.interrupt();
371        synchronized(queues) {
372            isRunning = false;
373            for (Entry<Integer, BlockingQueue> tq : queues.entrySet()) {
374                BlockingQueue q = tq.getValue();
375                if (q.size() != 0) {
376                    logger.info("queue for tag {} not empty {}", tq.getKey(), q);
377                } 
378                int bc = 0;
379                try {
380                    bc = blockedCount.get();
381                    for ( int i = 0; i <= bc; i++ ) { // one more
382                        q.put(new IOException("queue terminate"));
383                    }
384                } catch (InterruptedException ignored) {
385                }
386                if ( bc > 0 ) {
387                    logger.debug("put IO-end to queue for tag {}, blockCount = {}", tq.getKey(), bc);
388                }
389            }
390            queues.notifyAll();
391        }
392        try {
393            this.join();
394        } catch (InterruptedException e) {
395            // unfug Thread.currentThread().interrupt();
396        }
397        logger.info("terminated");
398    }
399
400}
401
402
403/**
404 * TaggedMessage container.
405 */
406class TaggedMessage implements Serializable {
407
408
409    public final Integer tag;
410
411
412    public final Object msg;
413
414
415    /**
416     * Constructor.
417     * @param tag message tag
418     * @param msg message object
419     */
420    public TaggedMessage(Integer tag, Object msg) {
421        this.tag = tag;
422        this.msg = msg;
423    }
424
425}