Start line:  
End line:  

Snippet Preview

Snippet HTML Code

Stack Overflow Questions
BEGIN LICENSE BLOCK ***** Version: CPL 1.0/GPL 2.0/LGPL 2.1 The contents of this file are subject to the Common Public License Version 1.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.eclipse.org/legal/cpl-v10.html Software distributed under the License is distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY KIND, either express or implied. See the License for the specific language governing rights and limitations under the License. Copyright (C) 2006, 2007 Ola Bini <ola@ologix.com> Alternatively, the contents of this file may be used under the terms of either of the GNU General Public License Version 2 or later (the "GPL"), or the GNU Lesser General Public License Version 2.1 or later (the "LGPL"), in which case the provisions of the GPL or the LGPL are applicable instead of those above. If you wish to allow use of your version of this file only under the terms of either the GPL or the LGPL, and not to allow others to use your version of this file under the terms of the CPL, indicate your decision by deleting the provisions above and replace them with the notice and other provisions required by the GPL or the LGPL. If you do not delete the provisions above, a recipient may use your version of this file under the terms of any one of the CPL, the GPL or the LGPL. END LICENSE BLOCK ***
 
 package org.jruby.ext.openssl;
 
 import java.util.Set;
 
 
 import org.jruby.Ruby;

Author(s):
Ola Bini
 
 /**
  * Java implementation of the Ruby class OpenSSL::SSL::SSLSocket. It wraps a Ruby
  * IO and uses a javax.net.ssl.SSLEngine to encrypt outgoing data (wrap) and
  * decrypt incoming data (unwrap).
  *
  * NOTE(review): this copy appears to have been scraped from a web page. Many
  * identifiers (field names, receivers of method calls, enum constants), the
  * commas between parameters/arguments and a few whole lines (one method
  * header, some field declarations, parts of two conditions) are missing, so it
  * does not compile as shown. Several public methods (connect, accept, cert, ...)
  * carry no @JRubyMethod annotation here even though createSSLSocket relies on
  * defineAnnotatedMethods; those annotations may have been lost as well.
  * Comments below hedge wherever a missing name matters; restore the file from
  * its upstream source before relying on them.
  */
 public class SSLSocket extends RubyObject {
     private static final long serialVersionUID = -2276327900350542644L;
 
     // Allocator that instantiates the Java SSLSocket peer for a Ruby class.
     private static ObjectAllocator SSLSOCKET_ALLOCATOR = new ObjectAllocator() {
         public IRubyObject allocate(Ruby runtimeRubyClass klass) {
             return new SSLSocket(runtimeklass);
         }
     };
 
     // Adapter for invoking Ruby methods from Java; presumably the receiver of the
     // callMethod/callSuper calls whose receiver is missing in this copy.
     private static RubyObjectAdapter api = JavaEmbedUtils.newObjectAdapter();
     
     /**
      * Defines OpenSSL::SSL::SSLSocket (superclass Object) under the given SSL
      * module. It adds read/write attributes io, context, sync_close and
      * hostname, aliases to_io to io, and registers this class's
      * @JRubyMethod-annotated methods.
      *
      * NOTE(review): the third argument of defineClassUnder is missing in this
      * copy; presumably it was SSLSOCKET_ALLOCATOR — confirm.
      */
     public static void createSSLSocket(Ruby runtimeRubyModule mSSL) {
         RubyClass cSSLSocket = mSSL.defineClassUnder("SSLSocket",runtime.getObject(),);
         cSSLSocket.addReadWriteAttribute(runtime.getCurrentContext(), "io");
         cSSLSocket.addReadWriteAttribute(runtime.getCurrentContext(), "context");
         cSSLSocket.addReadWriteAttribute(runtime.getCurrentContext(), "sync_close");
         cSSLSocket.addReadWriteAttribute(runtime.getCurrentContext(), "hostname");
         cSSLSocket.defineAlias("to_io","io");
         cSSLSocket.defineAnnotatedMethods(SSLSocket.class);
     }
 
     /** Creates the Java peer of a Ruby SSLSocket instance of class {@code type}. */
     public SSLSocket(Ruby runtimeRubyClass type) {
         super(runtime,type);
         // NOTE(review): both sides of this assignment are missing in this copy;
         // possibly it initialises verifyResult — confirm upstream.
          = .;
     }
    
    /** Returns (does not throw) a RaiseException for OpenSSL::SSL::SSLError with {@code message}. */
    public static RaiseException newSSLError(Ruby runtimeString message) {
        return Utils.newError(runtime"OpenSSL::SSL::SSLError"messagefalse);
    }
    /** Returns a RaiseException for OpenSSL::SSL::SSLErrorReadable; used by readWouldBlock(). */
    public static RaiseException newSSLErrorReadable(Ruby runtimeString message) {
        return Utils.newError(runtime"OpenSSL::SSL::SSLErrorReadable"messagefalse);
    }
    /** Returns a RaiseException for OpenSSL::SSL::SSLErrorWritable; used by writeWouldBlock(). */
    public static RaiseException newSSLErrorWritable(Ruby runtimeString message) {
        return Utils.newError(runtime"OpenSSL::SSL::SSLErrorWritable"messagefalse);
    }
    // The JSSE engine running the TLS protocol. Presumably created lazily by
    // ossl_ssl_setup() and null until then (several null checks below have lost
    // the name of the field they test).
    private SSLEngine engine;
    // The wrapped Ruby IO; presumably assigned from args[0] in initialize.
    private RubyIO io = null;
    // TLS buffers, sized from the SSLSession in ossl_ssl_setup(). Going by their
    // names: peerAppData holds decrypted application data from the peer,
    // peerNetData holds encrypted bytes read from the socket, netData holds
    // encrypted bytes waiting to be written, and dummy is a zero-length buffer
    // (presumably the empty source for handshake/shutdown wrap() calls). The
    // assignments have lost their target names, so confirm this mapping upstream.
    private ByteBuffer peerAppData;
    private ByteBuffer peerNetData;
    private ByteBuffer netData;
    private ByteBuffer dummy;
    
    // Presumably true while the first handshake is running: set when the handshake
    // is started in connectCommon/acceptCommon and cleared by finishInitialHandshake().
    private boolean initialHandshake = false;
    // Status of the most recent SSLEngineResult (presumably recorded by readAndUnwrap).
    private SSLEngineResult.Status status = null;
    // Verification result reported to Ruby by verify_result (presumably).
    int verifyResult;
    
    /**
     * Ruby {@code initialize(io [, ctx])}.
     *
     * With one argument a new OpenSSL::SSL::SSLContext is created via Ruby
     * {@code new}; with two, {@code args[1]} is used as the context.
     * {@code args[0]} must be an IO. The method then:
     * <ul>
     *   <li>sets the io, hostname ("") and context attributes;</li>
     *   <li>sets {@code sync = true} on the IO;</li>
     *   <li>sets sync_close to false;</li>
     *   <li>calls {@code setup} (receiver missing here; presumably the context);</li>
     *   <li>passes {@code args} to the Ruby superclass initialize.</li>
     * </ul>
     */
    @JRubyMethod(name = "initialize", rest = true, frame = true)
    public IRubyObject _initialize(IRubyObject[] argsBlock unused) {
        if (Arity.checkArgumentCount(getRuntime(), args, 1, 2) == 1) {
            RubyClass sslContext = Utils.getClassFromPath(getRuntime(), "OpenSSL::SSL::SSLContext");
             = (org.jruby.ext.openssl.SSLContext.callMethod(sslContext"new");
        } else {
             = (org.jruby.ext.openssl.SSLContextargs[1];
        }
        Utils.checkKind(getRuntime(), args[0], "IO");
         = (RubyIOargs[0];
        .callMethod(this"io=");
        .callMethod(this"hostname="getRuntime().newString(""));
        // This is a bit of a hack: SSLSocket should share code with RubyBasicSocket, which always sets sync to true.
        // Instead we set it here for now.
        .callMethod("sync="getRuntime().getTrue());
        .callMethod(this"context=");
        .callMethod(this"sync_close="getRuntime().getFalse());
        .setup();
        return .callSuper(thisargs);
    }
    // NOTE(review): the signature line of the method whose body follows is missing
    // in this copy. connectCommon and acceptCommon call ossl_ssl_setup(), which is
    // presumably this method. If the guarded field is still null, it:
    //   - creates an SSLEngine for the "hostname" attribute and the socket's
    //     remote port;
    //   - allocates three buffers from the session's packet/application buffer
    //     sizes and empties them with limit(0);
    //   - allocates a zero-length buffer.
        if(null == ) {
            Socket socket = getSocketChannel().socket();
            // Server Name Indication (SNI) RFC 3546
            // SNI support will not be attempted unless hostname is explicitly set by the caller
            String peerHost = .callMethod(this,"hostname").convertToString().toString();
            int peerPort = socket.getPort();
             = .createSSLEngine(peerHostpeerPort);
            SSLSession session = .getSession();
             = ByteBuffer.allocate(session.getPacketBufferSize());
             = ByteBuffer.allocate(session.getApplicationBufferSize());		
             = ByteBuffer.allocate(session.getPacketBufferSize());
            .limit(0);
            .limit(0);
            .limit(0);
             = ByteBuffer.allocate(0);
        }
    }
    /** Ruby {@code connect}: client-side handshake with blocking channel waits. */
    public IRubyObject connect(ThreadContext context) {
        return connectCommon(contexttrue);
    }
    /**
     * Ruby {@code connect_nonblock}: client-side handshake without blocking. It may
     * raise SSLErrorReadable/SSLErrorWritable or EAGAIN when the handshake cannot
     * progress yet.
     */
    public IRubyObject connect_nonblock(ThreadContext context) {
        return connectCommon(contextfalse);
    }
    /**
     * Shared implementation of connect / connect_nonblock.
     *
     * Raises SSLError unless the context (receiver missing here) is configured
     * for a client protocol. If the handshake has not been started yet (the
     * guard's field name is missing here), it runs ossl_ssl_setup(), switches
     * the engine to client mode, begins the handshake and caches the handshake
     * status. Every call then drives doHandshake(blocking).
     *
     * Handshake, algorithm, key-management and I/O failures force-close the
     * connection and are rethrown as SSL errors. For SSLHandshakeException the
     * cause chain is followed while the current exception is an
     * SSLHandshakeException. The reported error is therefore the first cause
     * that is not an SSLHandshakeException, or the last SSLHandshakeException
     * if it has no cause.
     *
     * @param blocking whether channel waits may block
     * @return this socket
     */
    private IRubyObject connectCommon(ThreadContext contextboolean blocking) {
        Ruby runtime = context.runtime;
        if (!.isProtocolForClient()) {
            throw newSSLError(runtime"called a function you should not call");
        }
        try {
            if (!) {
                ossl_ssl_setup();
                .setUseClientMode(true);
                .beginHandshake();
                 = .getHandshakeStatus();
                 = true;
            }
            doHandshake(blocking);
        } catch(SSLHandshakeException e) {
            // unlike server side, client should close outbound channel even if
            // we have remaining data to be sent.
            forceClose();
            Throwable v = e;
            while(v.getCause() != null && (v instanceof SSLHandshakeException)) {
                v = v.getCause();
            }
            throw SSL.newSSLError(runtimev);
        } catch (NoSuchAlgorithmException ex) {
            forceClose();
            throw SSL.newSSLError(runtimeex);
        } catch (KeyManagementException ex) {
            forceClose();
            throw SSL.newSSLError(runtimeex);
        } catch (IOException ex) {
            forceClose();
            throw SSL.newSSLError(runtimeex);
        }
        return this;
    }
    /** Ruby {@code accept}: server-side handshake with blocking channel waits. */
    public IRubyObject accept(ThreadContext context) {
        return acceptCommon(contexttrue);
    }
    /** Ruby {@code accept_nonblock}: server-side handshake without blocking. */
    public IRubyObject accept_nonblock(ThreadContext context) {
        return acceptCommon(contextfalse);
    }
    /**
     * Shared implementation of accept / accept_nonblock (server side).
     *
     * Raises SSLError unless the context is configured for a server protocol.
     * If the handshake has not been started yet (guard field name missing
     * here), it:
     * <ul>
     *   <li>runs ossl_ssl_setup();</li>
     *   <li>puts the engine in server mode;</li>
     *   <li>maps the context's verify_mode (when it is not nil) to
     *       client-authentication settings;</li>
     *   <li>begins the handshake.</li>
     * </ul>
     * It then drives doHandshake(blocking).
     *
     * verify_mode mapping:
     * <ul>
     *   <li>0 (VERIFY_NONE) turns off both need and want client auth;</li>
     *   <li>bit 1 (VERIFY_PEER) sets want client auth;</li>
     *   <li>bit 2 (VERIFY_FAIL_IF_NO_PEER_CERT) sets need client auth.</li>
     * </ul>
     *
     * Failures are rethrown as SSL errors. Unlike connectCommon, the connection
     * is not force-closed here.
     *
     * NOTE(review): this is public whereas connectCommon is private; consider
     * narrowing its visibility.
     */
    public IRubyObject acceptCommon(ThreadContext contextboolean blocking) {
        Ruby runtime = context.runtime;
        if (!.isProtocolForServer()) {
            throw newSSLError(runtime"called a function you should not call");
        }
        try {
            int vfy = 0;
            if (!) {
                ossl_ssl_setup();
                .setUseClientMode(false);
                if(!.isNil() && !.callMethod(context,"verify_mode").isNil()) {
                    vfy = RubyNumeric.fix2int(.callMethod(context,"verify_mode"));
                    if(vfy == 0) { //VERIFY_NONE
                        .setNeedClientAuth(false);
                        .setWantClientAuth(false);
                    }
                    if((vfy & 1) != 0) { //VERIFY_PEER
                        .setWantClientAuth(true);
                    }
                    if((vfy & 2) != 0) { //VERIFY_FAIL_IF_NO_PEER_CERT
                        .setNeedClientAuth(true);
                    }
                }
                .beginHandshake();
                 = .getHandshakeStatus();
                 = true;
            }
            doHandshake(blocking);
        } catch(SSLHandshakeException e) {
            throw SSL.newSSLError(runtimee);
        } catch (NoSuchAlgorithmException ex) {
            throw SSL.newSSLError(runtimeex);
        } catch (KeyManagementException ex) {
            throw SSL.newSSLError(runtimeex);
        } catch (IOException ex) {
            throw SSL.newSSLError(runtimeex);
        }
        return this;
    }
    /**
     * Ruby {@code verify_result}. If the checked field (name missing here;
     * presumably the engine) is null, it warns "SSL session is not started yet."
     * and returns nil. Otherwise it returns the stored verification result as
     * an Integer.
     */
    public IRubyObject verify_result() {
        if ( == null) {
            getRuntime().getWarnings().warn("SSL session is not started yet.");
            return getRuntime().getNil();
        }
        return getRuntime().newFixnum();
    }
    // This select impl is a copy of RubyThread.select, then blockingLock is
    // removed. This impl just set
    // SelectableChannel.configureBlocking(false) permanently instead of setting
    // temporarily. SSLSocket requires wrapping IO to be selectable so it should
    // be OK to set configureBlocking(false) permanently.
    /**
     * Waits until the wrapped IO's channel is ready for {@code operations}.
     *
     * In non-blocking mode, if selectNow() finds nothing ready, this raises
     * SSLErrorReadable/SSLErrorWritable via readWouldBlock()/writeWouldBlock()
     * instead of returning. Selector I/O failures are raised as a Ruby
     * RuntimeError ("Error with selector: ...").
     *
     * @param operations SelectionKey interest set to wait for
     * @param blocking whether to block in select() or only poll with selectNow()
     * @return true if the channel is not selectable or the first selected key is
     *         this channel's key; false if nothing relevant was selected or the
     *         wait was interrupted
     * @throws IOException from configureBlocking() or register()
     */
    private boolean waitSelect(final int operationsfinal boolean blockingthrows IOException {
        if (!(.getChannel() instanceof SelectableChannel)) {
            return true;
        }
        final Ruby runtime = getRuntime();
        RubyThread thread = runtime.getCurrentContext().getThread();
        SelectableChannel selectable = (SelectableChannel).getChannel();
        selectable.configureBlocking(false);
        final Selector selector = runtime.getSelectorPool().get();
        final SelectionKey key = selectable.register(selectoroperations);
        try {
            .addBlockingThread(thread);
            final int[] result = new int[1];
            thread.executeBlockingTask(new RubyThread.BlockingTask() {
                public void run() throws InterruptedException {
                    try {
                        if (!blocking) {
                            result[0] = selector.selectNow();
                            if (result[0] == 0) {
                                // Nothing is ready: raise the matching "would block" error.
                                // The interest-set constants are missing in this copy
                                // (presumably SelectionKey.OP_READ / OP_WRITE).
                                // NOTE(review): when nothing was selected, the ready set of a
                                // freshly registered key is normally empty, so the combined
                                // read+write case usually ends in readWouldBlock(). Confirm
                                // that the readable -> writeWouldBlock() mapping is intended.
                                if ((operations & .) != 0 && (operations & .) != 0) {
                                    if (key.isReadable()) {
                                        writeWouldBlock();
                                    } else if (key.isWritable()) {
                                        readWouldBlock();
                                    } else { //neither, pick one
                                        readWouldBlock();
                                    }
                                } else if ((operations & .) != 0) {
                                    readWouldBlock();
                                } else if ((operations & .) != 0) {
                                    writeWouldBlock();
                                }
                            }
                        } else {
                            result[0] = selector.select();
                        }
                    } catch (IOException ioe) {
                        throw runtime.newRuntimeError("Error with selector: " + ioe.getMessage());
                    }
                }
                public void wakeup() {
                    selector.wakeup();
                }
            });
            if (result[0] >= 1) {
                Set<SelectionKeykeySet = selector.selectedKeys();
                if (keySet.iterator().next() == key) {
                    return true;
                }
            }
            return false;
        } catch (InterruptedException ie) {
            // NOTE(review): the interrupt is swallowed without restoring the thread's
            // interrupt status; confirm JRuby's thread handling makes this safe.
            return false;
        } finally {
            // Note: I don't like ignoring these exceptions, but it's
            // unclear how likely they are to happen or what damage we
            // might do by ignoring them. Note that the pieces are separate
            // so that we can ensure one failing does not affect the others
            // running.
            // clean up the key in the selector
            try {
                if (key != nullkey.cancel();
                if (selector != nullselector.selectNow();
            } catch (Exception e) {
                // ignore
            }
            // return the selector to the runtime's selector pool
            try {
                if (selector != null) {
                    runtime.getSelectorPool().put(selector);
                }
            } catch (Exception e) {
                // ignore
            }
            // remove this thread as a blocker against the given IO
            .removeBlockingThread(thread);
            // clear thread state from blocking call
            thread.afterBlockingCall();
        }
    }
    /** Always throws SSLErrorReadable("read would block"); void so it can be used as a statement. */
    private void readWouldBlock() {
        Ruby runtime = getRuntime();
        throw newSSLErrorReadable(runtime"read would block");
    }
    /** Always throws SSLErrorWritable("write would block"); void so it can be used as a statement. */
    private void writeWouldBlock() {
        Ruby runtime = getRuntime();
        throw newSSLErrorWritable(runtime"write would block");
    }
    /**
     * Drives the SSLEngine handshake until the engine reports completion.
     *
     * Each iteration first waits for the channel to become readable or writable
     * (the interest-set constants are missing in this copy). In non-blocking
     * mode an unready channel raises EAGAIN (EAGAINWritable in 1.9 mode).
     *
     * It then dispatches on the cached handshake status. The case labels are
     * missing in this copy. Judging by the bodies they are presumably, in
     * order: FINISHED, NEED_TASK, NEED_UNWRAP, NEED_WRAP, NOT_HANDSHAKING.
     * Confirm upstream.
     *
     * @param blocking whether channel waits may block
     * @throws SSLHandshakeException ("Socket closed") if the peer closes mid-handshake
     * @throws IOException on channel errors
     */
    private void doHandshake(boolean blockingthrows IOException {
        while (true) {
            SSLEngineResult res;
            boolean ready = waitSelect(. | .blocking);
            // if not blocking, raise EAGAIN
            if (!blocking && !ready) {
                Ruby runtime = getRuntime();
                throw runtime.is1_9() ?
                        runtime.newErrnoEAGAINWritableError("Resource temporarily unavailable") :
                        runtime.newErrnoEAGAINError("Resource temporarily unavailable");
            }
            // otherwise, proceed as before
            switch () {
            case :
                // Handshake complete: leave the initial-handshake state if still in it.
                if () {
                    finishInitialHandshake();
                }
                return;
            case :
                // Engine needs delegated tasks run before it can continue.
                doTasks();
                break;
            case :
                // Engine needs more data from the peer.
                if (readAndUnwrap(blocking) == -1 &&  != ..) {
                    throw new SSLHandshakeException("Socket closed");
                }
                // during initialHandshake, calling readAndUnwrap that results UNDERFLOW
                // does not mean writable. we explicitly wait for readable channel to avoid
                // busy loop.
                if ( &&  == ..) {
                    waitSelect(.blocking);
                }
                break;
            case :
                // Engine has handshake data to send: flush pending output, wrap, flush.
                if (.hasRemaining()) {
                    // NOTE(review): flushData() returns true once its buffer is fully
                    // drained, and writeToChannel() writes nothing for an empty buffer.
                    // So after the pending bytes are written this loop keeps calling
                    // flushData() and never exits. Assuming the buffer checked above is
                    // the one flushData() drains, the loop condition looks inverted.
                    while (flushData(blocking)) {
                    }
                }
                .clear();
                res = .wrap();
                 = res.getHandshakeStatus();
                .flip();
                flushData(blocking);
                break;
            case :
                // Opposite side could close while unwrapping. Handle this as same as FINISHED
                return;
            default:
                throw new IllegalStateException("Unknown handshaking status: " + );
            }
        }
    }
    /**
     * Runs every pending delegated task synchronously on the calling thread,
     * then refreshes the cached handshake status (target name missing here).
     */
    private void doTasks() {
        Runnable task;
        while ((task = .getDelegatedTask()) != null) {
            task.run();
        }
         = .getHandshakeStatus();
    }
    /**
     * Writes the pending outbound buffer (name missing here; presumably netData)
     * to the channel. On IOException the unsent bytes are discarded (position
     * moved to limit) before the exception is rethrown.
     *
     * @param blocking passed through to writeToChannel()
     * @return true when the buffer has been fully drained, false if bytes remain
     * @throws IOException on channel errors
     */
    private boolean flushData(boolean blockingthrows IOException {
        try {
            writeToChannel(blocking);
        } catch (IOException ioe) {
            .position(.limit());
            throw ioe;
        }
        if (.hasRemaining()) {
            return false;
        }  else {
            return true;
        }
    }
    /**
     * Writes {@code buffer} to the socket channel. With {@code blocking} it keeps
     * writing until the buffer is empty; otherwise it makes a single write call.
     *
     * NOTE(review): waitSelect() switches the channel to non-blocking mode
     * permanently. A "blocking" write may therefore busy-spin here while the
     * socket's send buffer is full.
     *
     * @return total number of bytes written
     * @throws IOException on channel errors
     */
    private int writeToChannel(ByteBuffer bufferboolean blockingthrows IOException {
        int totalWritten = 0;
        while (buffer.hasRemaining()) {
            totalWritten += getSocketChannel().write(buffer);
            if (!blockingbreak// non-blocking: make only a single write attempt
        }
        return totalWritten;
    }
    /** Leaves the initial-handshake state (field name missing; presumably clears initialHandshake). */
    private void finishInitialHandshake() {
         = false;
    }
    /**
     * Encrypts {@code src} with the engine and writes the result to the channel.
     *
     * Throws IOException while the initial handshake is still in progress
     * (guard field name missing here). When {@code blocking} is false the
     * channel is switched to non-blocking mode for the call and its previous
     * mode is restored afterwards. Previously pending output is flushed first.
     *
     * NOTE(review): if that first flush only partly succeeds (non-blocking), the
     * following clear() appears to discard the unsent bytes. Confirm that
     * clear() targets the same buffer.
     *
     * @return number of bytes consumed from {@code src}
     * @throws IOException on channel errors or during the initial handshake
     */
    public int write(ByteBuffer srcboolean blockingthrows SSLExceptionIOException {
        if() {
            throw new IOException("Writing not possible during handshake");
        }
        SelectableChannel selectable = getSocketChannel();
        boolean blockingMode = selectable.isBlocking();
        if (!blockingselectable.configureBlocking(false);
        try {
            if(.hasRemaining()) {
                flushData(blocking);
            }
            .clear();
            SSLEngineResult res = .wrap(src);
            .flip();
            flushData(blocking);
            return res.bytesConsumed();
        } finally {
            if (!blockingselectable.configureBlocking(blockingMode);
        }
    }
    /**
     * Copies decrypted application data into {@code dst}, reading and unwrapping
     * more from the socket when none is buffered.
     *
     * NOTE(review): bytes are copied to dst's backing array starting at
     * arrayOffset(), and the position is then set to arrayOffset() + limit. This
     * ignores dst.position() and is only correct for a heap buffer positioned at
     * 0, as in do_sysread. The usual form is
     * get(dst.array(), dst.arrayOffset() + dst.position(), limit) followed by
     * dst.position(dst.position() + limit).
     *
     * @return 0 while the initial handshake is in progress; -1 once the engine's
     *         inbound side is done or readAndUnwrap() reports end of stream;
     *         0 if unwrapping produced no data; otherwise the number of bytes copied
     * @throws IOException on channel errors
     */
    public int read(ByteBuffer dstboolean blockingthrows IOException {
        if() {
            return 0;
        }
        if (.isInboundDone()) {
            return -1;
        }
        if (!.hasRemaining()) {
            int appBytesProduced = readAndUnwrap(blocking);
            if (appBytesProduced == -1 || appBytesProduced == 0) {
                return appBytesProduced;
            }
        }
        int limit = Math.min(.remaining(), dst.remaining());
        .get(dst.array(), dst.arrayOffset(), limit);
        dst.position(dst.arrayOffset() + limit);
        return limit;
    }
    /**
     * Reads from the socket channel into the inbound network buffer and unwraps
     * it into the application buffer. Receiver and argument names are missing in
     * this copy; presumably they are peerNetData and peerAppData.
     *
     * On end of stream the engine's inbound side is closed, either immediately
     * or once the buffered input has been consumed. If the recorded status
     * (presumably CLOSED) indicates closure, doShutdown() runs and -1 is
     * returned. Outside the initial handshake (presumably), a pending handshake
     * status is handled by re-entering doHandshake().
     *
     * @return -1 on end of stream or closure, otherwise the number of decrypted
     *         bytes available
     * @throws IOException on channel errors
     */
    private int readAndUnwrap(boolean blockingthrows IOException {
        int bytesRead = getSocketChannel().read();
        if (bytesRead == -1) {
            if (!.hasRemaining() || ( == ..)) {
                closeInbound();
                return -1;
            }
            // inbound channel has been already closed but closeInbound() must
            // be deferred till the last engine.unwrap() call.
            // peerNetData could not be empty.
        }
        .clear();
        .flip();
        SSLEngineResult res;
        // NOTE(review): lines appear to be missing in this copy. The loop
        // condition is truncated, and the "if (...) {" that should guard
        // finishInitialHandshake() is gone, leaving the "}" below unmatched.
        do {
            res = .unwrap();
        } while (res.getStatus() == .. &&
				res.bytesProduced() == 0);
            finishInitialHandshake();
        }
        // Nothing produced yet but input remains: try one more unwrap.
        if(.position() == 0 && 
            res.getStatus() == .. &&
            .hasRemaining()) {
            res = .unwrap();
        }
         = res.getStatus();
         = res.getHandshakeStatus();
        if (bytesRead == -1 && !.hasRemaining()) {
            // now it's safe to call closeInbound().
            closeInbound();
        }
        if( == ..) {
            doShutdown();
            return -1;
        }
        .compact();
        .flip();
        // NOTE(review): the start of the following if-condition is missing in this copy.
                                  == .. ||
                                  == ..)) {
            doHandshake(blocking);
        }
        return .remaining();
    }
    /** Closes the inbound side of the engine (receiver missing here), ignoring SSLException. */
    private void closeInbound() {
        try {
            .closeInbound();
        } catch (SSLException ssle) {
            // ignore any error on close. possibly an error like this;
            // Inbound closed before receiving peer's close_notify: possible truncation attack?
        }
    }
    /**
     * Unless the engine's outbound side is already done, wraps into the cleared
     * outbound buffer and flushes it in blocking mode. Presumably this emits the
     * TLS close_notify produced after closeOutbound(). Any exception from
     * wrap() silently aborts the shutdown.
     *
     * @throws IOException from flushing the outbound data
     */
    private void doShutdown() throws IOException {
        if (.isOutboundDone()) {
            return;
        }
        .clear();
        try {
            .wrap();
        } catch(Exception e1) {
            return;
        }
        .flip();
        flushData(true);
    }
    /**
     * Shared implementation of sysread / sysread_nonblock.
     *
     * Arguments:
     * <ul>
     *   <li>{@code args[0]} is the maximum number of bytes to read;</li>
     *   <li>optional {@code args[1]} is a String buffer whose contents are
     *       replaced with the bytes read.</li>
     * </ul>
     *
     * Behaviour:
     * <ul>
     *   <li>a length of 0 returns the cleared buffer, or a new empty string;</li>
     *   <li>a negative length raises ArgumentError;</li>
     *   <li>end of stream raises EOFError;</li>
     *   <li>I/O failures raise IOError.</li>
     * </ul>
     * When the checked field is null (presumably no SSLEngine yet), the raw
     * socket channel is read directly.
     */
    private IRubyObject do_sysread(ThreadContext contextIRubyObject[] argsboolean blocking) {
        Ruby runtime = context.runtime;
        int len = RubyNumeric.fix2int(args[0]);
        RubyString str = null;
        
        if (args.length == 2 && !args[1].isNil()) {
            str = args[1].convertToString();
        } else {
            str = getRuntime().newString("");
        }
        if(len == 0) {
            str.clear();
            return str;
        }
        if (len < 0) {
            throw runtime.newArgumentError("negative string size (or size too big)");
        }
        try {
            // So we need to make sure to only block when there is no data left to process
            if ( == null || !(.hasRemaining() || .position() > 0)) {
                waitSelect(.blocking);
            }
            ByteBuffer dst = ByteBuffer.allocate(len);
            int rr = -1;
            // ensure >0 bytes read; sysread is blocking read.
            // NOTE(review): in non-blocking mode, if read() keeps returning 0 this
            // loop spins instead of raising a "would block" error; confirm.
            while (rr <= 0) {
                if ( == null) {
                    rr = getSocketChannel().read(dst);
                } else {
                    rr = read(dstblocking);
                }
                if (rr == -1) {
                    throw getRuntime().newEOFError();
                }
            }
            byte[] bss = new byte[rr];
            dst.position(dst.position() - rr);
            dst.get(bss);
            str.setValue(new ByteList(bss));
            return str;
        } catch (IOException ioe) {
            throw getRuntime().newIOError(ioe.getMessage());
        }
    }
    /** Ruby {@code sysread(length [, buffer])}: reads at least one byte, blocking as needed. */
    @JRubyMethod(rest = true, required = 1, optional = 1)
    public IRubyObject sysread(ThreadContext contextIRubyObject[] args) {
        return do_sysread(contextargstrue);
    }
    
    /** Ruby {@code sysread_nonblock(length [, buffer])}: non-blocking variant of sysread. */
    @JRubyMethod(rest = true, required = 1, optional = 1)
    public IRubyObject sysread_nonblock(ThreadContext contextIRubyObject[] args) {
        return do_sysread(contextargsfalse);
    }
    /**
     * Shared implementation of syswrite / syswrite_nonblock.
     *
     * Steps:
     * <ul>
     *   <li>raise IOError if the socket is closed;</li>
     *   <li>wait for writability;</li>
     *   <li>write the string's bytes, either raw to the channel (when the
     *       checked field is null, presumably before an SSLEngine exists) or
     *       encrypted via write();</li>
     *   <li>flush the Ruby io.</li>
     * </ul>
     * Returns the number of bytes written or consumed. I/O failures raise IOError.
     */
    private IRubyObject do_syswrite(ThreadContext contextIRubyObject argboolean blocking)  {
        Ruby runtime = context.runtime;
        try {
            checkClosed();
            waitSelect(.blocking);
            ByteList bls = arg.convertToString().getByteList();
            ByteBuffer b1 = ByteBuffer.wrap(bls.getUnsafeBytes(), bls.getBegin(), bls.getRealSize());
            int written;
            if( == null) {
                written = writeToChannel(b1blocking);
            } else {
                written = write(b1blocking);
            }
            ((RubyIO).callMethod(this,"io")).flush();
            return getRuntime().newFixnum(written);
        } catch (IOException ioe) {
            throw runtime.newIOError(ioe.getMessage());
        }
    }
    /** Ruby {@code syswrite(string)}: blocking write. */
    public IRubyObject syswrite(ThreadContext contextIRubyObject arg) {
        return do_syswrite(contextargtrue);
    }
    /** Ruby {@code syswrite_nonblock(string)}: non-blocking write. */
    public IRubyObject syswrite_nonblock(ThreadContext contextIRubyObject arg) {
        return do_syswrite(contextargfalse);
    }
    /** Raises IOError("closed stream") if the underlying socket channel is no longer open. */
    private void checkClosed() {
        if (!getSocketChannel().isOpen()) {
            throw getRuntime().newIOError("closed stream");
        }
    }
    // do shutdown even if we have remaining data to be sent.
    // call this when you get an exception from client side.
    private void forceClose() {
        close(true);
    }
    /**
     * Closes the engine's outbound side. Unless {@code force} is false and
     * outbound data is still pending, it then sends the shutdown data via
     * doShutdown(), ignoring I/O errors. Raises EOFError if the checked field
     * (presumably the engine) is null.
     */
    private void close(boolean force)  {
        if ( == nullthrow getRuntime().newEOFError();
        .closeOutbound();
        if (!force && .hasRemaining()) {
            return;
        } else {
            try {
                doShutdown();
            } catch (IOException ex) {
                // ignore?
            }
        }
    }
    /**
     * Ruby {@code sysclose}: closes the underlying io when sync_close is true and
     * returns nil. It does not call close()/doShutdown().
     */
    public IRubyObject sysclose()  {
        // no need to try shutdown when it's a server
        ThreadContext tc = getRuntime().getCurrentContext();
        if(callMethod(tc,"sync_close").isTrue()) {
            callMethod(tc,"io").callMethod(tc,"close");
        }
        return getRuntime().getNil();
    }
    /**
     * Ruby {@code cert}: the first local certificate of the session as an X509
     * certificate. Returns nil if there is no engine (presumably the null check
     * below) or no local certificate.
     */
    public IRubyObject cert() {
        if ( == null) {
            return getRuntime().getNil();
        }
        try {
            Certificate[] cert = .getSession().getLocalCertificates();
            if (cert != null && cert.length > 0) {
                return X509Cert.wrap(getRuntime(), cert[0]);
            }
        } catch (CertificateEncodingException ex) {
            throw X509Cert.newCertificateError(getRuntime(), ex);
        }
        return getRuntime().getNil();
    }
    /**
     * Ruby {@code peer_cert}: the peer's first certificate. Returns nil if there
     * is no engine or the peer is unverified; the latter is reported as a
     * warning in verbose mode.
     */
    public IRubyObject peer_cert() {
        if ( == null) {
            return getRuntime().getNil();
        }
        try {
            Certificate[] cert = .getSession().getPeerCertificates();
            if (cert.length > 0) {
                return X509Cert.wrap(getRuntime(), cert[0]);
            }
        } catch (CertificateEncodingException ex) {
            throw X509Cert.newCertificateError(getRuntime(), ex);
        } catch (SSLPeerUnverifiedException ex) {
            if (getRuntime().isVerbose()) {
                getRuntime().getWarnings().warning(String.format("%s: %s"ex.getClass().getName(), ex.getMessage()));
            }
        }
        return getRuntime().getNil();
    }
    /**
     * Ruby {@code peer_cert_chain}: the peer's certificate chain as an Array of
     * X509 certificates. Returns nil if there is no engine or the peer is
     * unverified.
     *
     * NOTE(review): SSLSession.getPeerCertificateChain() and the
     * javax.security.cert types are deprecated, and recent JDKs may throw
     * UnsupportedOperationException here. Consider getPeerCertificates() instead.
     */
    public IRubyObject peer_cert_chain() {
        if ( == null) {
            return getRuntime().getNil();
        }
        try {
            javax.security.cert.Certificate[] certs = .getSession().getPeerCertificateChain();
            RubyArray arr = getRuntime().newArray(certs.length);
            for (int i = 0; i < certs.lengthi++) {
                arr.add(X509Cert.wrap(getRuntime(), certs[i]));
            }
            return arr;
        } catch (javax.security.cert.CertificateEncodingException e) {
            throw X509Cert.newCertificateError(getRuntime(), e);
        } catch (SSLPeerUnverifiedException ex) {
            if (getRuntime().isVerbose()) {
                getRuntime().getWarnings().warning(String.format("%s: %s"ex.getClass().getName(), ex.getMessage()));
            }
        }
        return getRuntime().getNil();
    }
    /**
     * Ruby {@code cipher}: name of the session's negotiated cipher suite.
     * NOTE(review): unlike cert/peer_cert there is no null-engine guard, so this
     * presumably throws NullPointerException before the engine exists.
     */
    public IRubyObject cipher() {
        return getRuntime().newString(.getSession().getCipherSuite());
    }
    /** Unimplemented: prints a warning (receiver missing; presumably System.out) and returns nil. */
    public IRubyObject state() {
        ..println("WARNING: unimplemented method called: SSLSocket#state");
        return getRuntime().getNil();
    }
    /** Unimplemented: prints a warning (receiver missing; presumably System.out) and returns nil. */
    public IRubyObject pending() {
        ..println("WARNING: unimplemented method called: SSLSocket#pending");
        return getRuntime().getNil();
    }
    /** Not implemented: always throws UnsupportedOperationException. */
    public IRubyObject session_reused_p() {
        throw new UnsupportedOperationException();
    }
    /** Not implemented: always throws UnsupportedOperationException. */
    public synchronized IRubyObject session_set(IRubyObject aSession) {
        throw new UnsupportedOperationException();
    }
    
    /**
     * Returns the wrapped io's channel (receiver missing; presumably io) cast to
     * SocketChannel. Throws ClassCastException if it is not a SocketChannel.
     */
    private SocketChannel getSocketChannel() {
        return (SocketChannel.getChannel();
    }
}// SSLSocket