Start line:  
End line:  

Snippet Preview

Snippet HTML Code

Stack Overflow Questions
  /*
   *  Licensed to the Apache Software Foundation (ASF) under one or more
   *  contributor license agreements.  See the NOTICE file distributed with
   *  this work for additional information regarding copyright ownership.
   *  The ASF licenses this file to You under the Apache License, Version 2.0
   *  (the "License"); you may not use this file except in compliance with
   *  the License.  You may obtain a copy of the License at
   *
   *      http://www.apache.org/licenses/LICENSE-2.0
  *
  *  Unless required by applicable law or agreed to in writing, software
  *  distributed under the License is distributed on an "AS IS" BASIS,
  *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  *  See the License for the specific language governing permissions and
  *  limitations under the License.
  */
 package org.apache.tomcat.websocket;
 
 import static org.jboss.web.WebsocketsMessages.MESSAGES;
 
 import  java.nio.channels.AsynchronousSocketChannel;
 import  java.nio.channels.CompletionHandler;
 
 
 /**
  * Wraps the {@link AsynchronousSocketChannel} with SSL/TLS. This needs a lot
  * more testing before it can be considered robust.
  */
 
 public class AsyncChannelWrapperSecure implements AsyncChannelWrapper {
 
     private static final ByteBuffer DUMMY = ByteBuffer.allocate(8192);
     private final AsynchronousSocketChannel socketChannel;
     private final SSLEngine sslEngine;
     private final ByteBuffer socketReadBuffer;
     private final ByteBuffer socketWriteBuffer;
     // One thread for read, one for write
     private final ExecutorService executor =
             Executors.newFixedThreadPool(2, new SecureIOThreadFactory());
     private AtomicBoolean writing = new AtomicBoolean(false);
     private AtomicBoolean reading = new AtomicBoolean(false);
 
     public AsyncChannelWrapperSecure(AsynchronousSocketChannel socketChannel,
             SSLEngine sslEngine) {
         this. = socketChannel;
         this. = sslEngine;
 
         int socketBufferSize = sslEngine.getSession().getPacketBufferSize();
          = ByteBuffer.allocateDirect(socketBufferSize);
          = ByteBuffer.allocateDirect(socketBufferSize);
     }
 
     @Override
     public Future<Integerread(ByteBuffer dst) {
         WrapperFuture<Integer,Voidfuture = new WrapperFuture<IntegerVoid>();
 
         if (!.compareAndSet(falsetrue)) {
             throw .invalidConcurrentRead();
         }
 
         ReadTask readTask = new ReadTask(dstfuture);
 
         .execute(readTask);
 
         return future;
     }
 
     @Override
     public <B,A extends B> void read(ByteBuffer dst, A attachment,
             CompletionHandler<Integer,B> handler) {
 
         WrapperFuture<Integer,B> future =
                 new WrapperFuture<Integer, B>(handlerattachment);
 
         if (!.compareAndSet(falsetrue)) {
             throw .invalidConcurrentRead();
         }
 
         ReadTask readTask = new ReadTask(dstfuture);
 
        .execute(readTask);
    }
    @Override
    public Future<Integerwrite(ByteBuffer src) {
        WrapperFuture<Long,Voidinner = new WrapperFuture<LongVoid>();
        if (!.compareAndSet(falsetrue)) {
            throw .invalidConcurrentWrite();
        }
        WriteTask writeTask =
                new WriteTask(new ByteBuffer[] {src}, 0, 1, inner);
        .execute(writeTask);
        Future<Integerfuture = new LongToIntegerFuture(inner);
        return future;
    }
    @Override
    public <B,A extends B> void write(ByteBuffer[] srcsint offsetint length,
            long timeoutTimeUnit unit, A attachment,
            CompletionHandler<Long,B> handler) {
        WrapperFuture<Long,B> future =
                new WrapperFuture<Long, B>(handlerattachment);
        if (!.compareAndSet(falsetrue)) {
            throw .invalidConcurrentWrite();
        }
        WriteTask writeTask = new WriteTask(srcsoffsetlengthfuture);
        .execute(writeTask);
    }
    /**
     * Closes this wrapper.
     * <p>
     * NOTE(review): the receivers of the three calls below were lost when
     * this listing was extracted, so the statements do not compile as shown.
     * Judging by the fields declared in this class they are presumably the
     * socket channel, a project logger and the I/O executor - confirm
     * against the upstream source before restoring them.
     */
    @Override
    public void close() {
        try {
            // presumably socketChannel.close() - TODO confirm
            .close();
        } catch (IOException e) {
            // The IOException is not rethrown; errorClose() is called on a
            // receiver that is missing from this listing (presumably a
            // logger - TODO confirm)
            ..errorClose();
        }
        // presumably executor.shutdownNow() - TODO confirm
        .shutdownNow();
    }
    @Override
    public Future<Voidhandshake() throws SSLException {
        WrapperFuture<Void,VoidwFuture = new WrapperFuture<VoidVoid>();
        Thread t = new WebSocketSslHandshakeThread(wFuture);
        t.start();
        return wFuture;
    }
    private class WriteTask implements Runnable {
        private final ByteBuffer[] srcs;
        private final int offset;
        private final int length;
        private final WrapperFuture<Long,?> future;
        public WriteTask(ByteBuffer[] srcsint offsetint length,
                WrapperFuture<Long,?> future) {
            this. = srcs;
            this. = future;
            this. = offset;
            this. = length;
        }
        @Override
        public void run() {
            long written = 0;
            try {
                for (int i = i <  + i++) {
                    ByteBuffer src = [i];
                    while (src.hasRemaining()) {
                        .clear();
                        // Encrypt the data
                        SSLEngineResult r = .wrap(src);
                        written += r.bytesConsumed();
                        Status s = r.getStatus();
                        if (s == . || s == .) {
                            // Need to write out the bytes and may need to read from
                            // the source again to empty it
                        } else {
                            // Status.BUFFER_UNDERFLOW - only happens on unwrap
                            // Status.CLOSED - unexpected
                            throw .unexpectedStatusAfterWrap();
                        }
                        // Check for tasks
                        if (r.getHandshakeStatus() == .) {
                            Runnable runnable = .getDelegatedTask();
                            while (runnable != null) {
                                runnable.run();
                                runnable = .getDelegatedTask();
                            }
                        }
                        .flip();
                        // Do the write
                        int toWrite = r.bytesProduced();
                        while (toWrite > 0) {
                            Future<Integerf =
                                    .write();
                            Integer socketWrite = f.get();
                            toWrite -= socketWrite.intValue();
                        }
                    }
                }
                if (.compareAndSet(truefalse)) {
                    .complete(Long.valueOf(written));
                } else {
                    .fail(.invalidWriteState());
                }
            } catch (Exception e) {
                .fail(e);
            }
        }
    }
    private class ReadTask implements Runnable {
        private final ByteBuffer dest;
        private final WrapperFuture<Integer,?> future;
        public ReadTask(ByteBuffer destWrapperFuture<Integer,?> future) {
            this. = dest;
            this. = future;
        }
        @Override
        public void run() {
            int read = 0;
            boolean forceRead = false;
            try {
                while (read == 0) {
                    .compact();
                    if (forceRead) {
                        Future<Integerf =
                                .read();
                        Integer socketRead = f.get();
                        if (socketRead.intValue() == -1) {
                            throw new EOFException(.unexpectedEndOfStream());
                        }
                    }
                    .flip();
                    if (.hasRemaining()) {
                        // Decrypt the data in the buffer
                        SSLEngineResult r =
                                .unwrap();
                        read += r.bytesProduced();
                        Status s = r.getStatus();
                        if (s == .) {
                            // Bytes available for reading and there may be
                            // sufficient data in the socketReadBuffer to
                            // support further reads without reading from the
                            // socket
                        } else if (s == .) {
                            // There is partial data in the socketReadBuffer
                            if (read == 0) {
                                // Need more data before the partial data can be
                                // processed and some output generated
                                forceRead = true;
                            }
                            // else return the data we have and deal with the
                            // partial data on the next read
                        } else if (s == .) {
                            // Not enough space in the destination buffer to
                            // store all of the data. We could use a bytes read
                            // value of -bufferSizeRequired to signal the new
                            // buffer size required but an explicit exception is
                            // clearer.
                            if (.compareAndSet(truefalse)) {
                                throw new ReadBufferOverflowException(.
                                        getSession().getApplicationBufferSize());
                            } else {
                                .fail(.invalidReadState());
                            }
                        } else {
                            // Status.CLOSED - unexpected
                            throw .unexpectedStatusAfterUnwrap();
                        }
                        // Check for tasks
                        if (r.getHandshakeStatus() == .) {
                            Runnable runnable = .getDelegatedTask();
                            while (runnable != null) {
                                runnable.run();
                                runnable = .getDelegatedTask();
                            }
                        }
                    } else {
                        forceRead = true;
                    }
                }
                if (.compareAndSet(truefalse)) {
                    .complete(Integer.valueOf(read));
                } else {
                    .fail(.invalidReadState());
                }
            } catch (Exception e) {
                .fail(e);
            }
        }
    }
    private class WebSocketSslHandshakeThread extends Thread {
        private final WrapperFuture<Void,VoidhFuture;
        private HandshakeStatus handshakeStatus;
        private Status resultStatus;
        public WebSocketSslHandshakeThread(WrapperFuture<Void,VoidhFuture) {
            this. = hFuture;
        }
        @Override
        public void run() {
            try {
                .beginHandshake();
                // So the first compact does the right thing
                .position(.limit());
                 = .getHandshakeStatus();
                 = .;
                boolean handshaking = true;
                while(handshaking) {
                    switch () {
                        case : {
                            .clear();
                            SSLEngineResult r =
                                    .wrap();
                            checkResult(rtrue);
                            .flip();
                            Future<IntegerfWrite =
                                    .write();
                            fWrite.get();
                            break;
                        }
                        case : {
                            .compact();
                            if (.position() == 0 ||
                                     == .) {
                                Future<IntegerfRead =
                                        .read();
                                fRead.get();
                            }
                            .flip();
                            SSLEngineResult r =
                                    .unwrap();
                            checkResult(rfalse);
                            break;
                        }
                        case : {
                            Runnable r = null;
                            while ((r = .getDelegatedTask()) != null) {
                                r.run();
                            }
                             = .getHandshakeStatus();
                            break;
                        }
                        case : {
                            handshaking = false;
                            break;
                        }
                        default: {
                            throw new SSLException("TODO");
                        }
                    }
                }
            } catch (SSLException e) {
                .fail(e);
            } catch (InterruptedException e) {
                .fail(e);
            } catch (ExecutionException e) {
                .fail(e);
            }
            .complete(null);
        }
        private void checkResult(SSLEngineResult resultboolean wrap)
                throws SSLException {
             = result.getHandshakeStatus();
             = result.getStatus();
            if ( != . &&
                    (wrap ||  != .)) {
                throw new SSLException("TODO");
            }
            if (wrap && result.bytesConsumed() != 0) {
                throw new SSLException("TODO");
            }
            if (!wrap && result.bytesProduced() != 0) {
                throw new SSLException("TODO");
            }
        }
    }
    private static class WrapperFuture<T,A> implements Future<T> {
        private final CompletionHandler<T,A> handler;
        private final A attachment;
        private volatile T result = null;
        private volatile Throwable throwable = null;
        private CountDownLatch completionLatch = new CountDownLatch(1);
        public WrapperFuture() {
            this(nullnull);
        }
        public WrapperFuture(CompletionHandler<T,A> handler, A attachment) {
            this. = handler;
            this. = attachment;
        }
        public void complete(T result) {
            this. = result;
            .countDown();
            if ( != null) {
                .completed(result);
            }
        }
        public void fail(Throwable t) {
             = t;
            .countDown();
            if ( != null) {
                .failed();
            }
        }
        @Override
        public final boolean cancel(boolean mayInterruptIfRunning) {
            // Could support cancellation by closing the connection
            return false;
        }
        @Override
        public final boolean isCancelled() {
            // Could support cancellation by closing the connection
            return false;
        }
        @Override
        public final boolean isDone() {
            return .getCount() > 0;
        }
        @Override
        public T get() throws InterruptedExceptionExecutionException {
            .await();
            if ( != null) {
                throw new ExecutionException();
            }
            return ;
        }
        @Override
        public T get(long timeoutTimeUnit unit)
                throws InterruptedExceptionExecutionException,
                TimeoutException {
            boolean latchResult = .await(timeoutunit);
            if (latchResult == false) {
                throw new TimeoutException();
            }
            if ( != null) {
                throw new ExecutionException();
            }
            return ;
        }
    }
    private static final class LongToIntegerFuture implements Future<Integer> {
        private final Future<Longwrapped;
        public LongToIntegerFuture(Future<Longwrapped) {
            this. = wrapped;
        }
        @Override
        public boolean cancel(boolean mayInterruptIfRunning) {
            return .cancel(mayInterruptIfRunning);
        }
        @Override
        public boolean isCancelled() {
            return .isCancelled();
        }
        @Override
        public boolean isDone() {
            return .isDone();
        }
        @Override
        public Integer get() throws InterruptedExceptionExecutionException {
            Long result = .get();
            if (result.longValue() > .) {
                throw new ExecutionException(.notAnInteger(result), null);
            }
            return new Integer(result.intValue());
        }
        @Override
        public Integer get(long timeoutTimeUnit unit)
                throws InterruptedExceptionExecutionException,
                TimeoutException {
            Long result = .get(timeoutunit);
            if (result.longValue() > .) {
                throw new ExecutionException(.notAnInteger(result), null);
            }
            return new Integer(result.intValue());
        }
    }
    private static class SecureIOThreadFactory implements ThreadFactory {
        private AtomicInteger count = new AtomicInteger(0);
        @Override
        public Thread newThread(Runnable r) {
            Thread t = new Thread(r);
            t.setName("WebSocketClient-SecureIO-" + .incrementAndGet());
            // No need to set the context class loader. The threads will be
            // cleaned up when the connection is closed.
            t.setDaemon(true);
            return t;
        }
    }
New to GrepCode? Check out our FAQ X