diff options
author | Brian Carlstrom <bdc@google.com> | 2012-09-21 16:29:30 -0700 |
---|---|---|
committer | Brian Carlstrom <bdc@google.com> | 2012-09-22 17:01:00 -0700 |
commit | 615225a35dbd838210270b282d1196deff643b51 (patch) | |
tree | 5b0334a9e47cfd61d6b9dedb7beae14391ce8b31 | |
parent | 14728671441cc9a1c328b63f3cd4ec939d475000 (diff) |
Add OpenSSLSocketImpl.setSoWriteTimeout to allow SO_SNDTIMEO to be specified
Bug: 6693087
Change-Id: Ie6903168ca0ada4516c55dfab5f7194baf965b4c
4 files changed, 100 insertions, 24 deletions
diff --git a/luni/src/main/java/org/apache/harmony/xnet/provider/jsse/NativeCrypto.java b/luni/src/main/java/org/apache/harmony/xnet/provider/jsse/NativeCrypto.java index 12f87fafa1..65373ff0da 100644 --- a/luni/src/main/java/org/apache/harmony/xnet/provider/jsse/NativeCrypto.java +++ b/luni/src/main/java/org/apache/harmony/xnet/provider/jsse/NativeCrypto.java @@ -645,7 +645,7 @@ public final class NativeCrypto { public static native int SSL_read(int sslNativePointer, FileDescriptor fd, SSLHandshakeCallbacks shc, - byte[] b, int off, int len, int timeoutMillis) + byte[] b, int off, int len, int readTimeoutMillis) throws IOException; /** @@ -654,7 +654,7 @@ public final class NativeCrypto { public static native void SSL_write(int sslNativePointer, FileDescriptor fd, SSLHandshakeCallbacks shc, - byte[] b, int off, int len) + byte[] b, int off, int len, int writeTimeoutMillis) throws IOException; public static native void SSL_interrupt(int sslNativePointer); diff --git a/luni/src/main/java/org/apache/harmony/xnet/provider/jsse/OpenSSLSocketImpl.java b/luni/src/main/java/org/apache/harmony/xnet/provider/jsse/OpenSSLSocketImpl.java index 5ab5aa7cbd..4cc16e6cde 100644 --- a/luni/src/main/java/org/apache/harmony/xnet/provider/jsse/OpenSSLSocketImpl.java +++ b/luni/src/main/java/org/apache/harmony/xnet/provider/jsse/OpenSSLSocketImpl.java @@ -42,7 +42,11 @@ import javax.net.ssl.SSLProtocolException; import javax.net.ssl.SSLSession; import javax.net.ssl.X509TrustManager; import javax.security.auth.x500.X500Principal; +import static libcore.io.OsConstants.*; +import libcore.io.ErrnoException; +import libcore.io.Libcore; import libcore.io.Streams; +import libcore.io.StructTimeval; import org.apache.harmony.security.provider.cert.X509CertImpl; /** @@ -93,7 +97,8 @@ public class OpenSSLSocketImpl * OpenSSLSocketImplWrapper overrides setSoTimeout and * getSoTimeout to delegate to the wrapped socket. */ - private int timeoutMilliseconds = 0; + private int readTimeoutMilliseconds = 0; + private int writeTimeoutMilliseconds = 0; private int handshakeTimeoutMilliseconds = -1; // -1 = same as timeout; 0 = infinite private String wrappedHost; @@ -361,9 +366,11 @@ public class OpenSSLSocketImpl } // Temporarily use a different timeout for the handshake process - int savedTimeoutMilliseconds = getSoTimeout(); + int savedReadTimeoutMilliseconds = getSoTimeout(); + int savedWriteTimeoutMilliseconds = getSoWriteTimeout(); if (handshakeTimeoutMilliseconds >= 0) { setSoTimeout(handshakeTimeoutMilliseconds); + setSoWriteTimeout(handshakeTimeoutMilliseconds); } int sslSessionNativePointer; @@ -399,7 +406,8 @@ public class OpenSSLSocketImpl // Restore the original timeout now that the handshake is complete if (handshakeTimeoutMilliseconds >= 0) { - setSoTimeout(savedTimeoutMilliseconds); + setSoTimeout(savedReadTimeoutMilliseconds); + setSoWriteTimeout(savedWriteTimeoutMilliseconds); } // if not, notifyHandshakeCompletedListeners later in handshakeCompleted() callback @@ -696,7 +704,7 @@ public class OpenSSLSocketImpl return; } NativeCrypto.SSL_write(sslNativePointer, socket.getFileDescriptor$(), - OpenSSLSocketImpl.this, buf, offset, byteCount); + OpenSSLSocketImpl.this, buf, offset, byteCount, writeTimeoutMilliseconds); } } } @@ -827,21 +835,42 @@ public class OpenSSLSocketImpl throw new SocketException("Methods sendUrgentData, setOOBInline are not supported."); } - @Override public void setSoTimeout(int timeoutMilliseconds) throws SocketException { - super.setSoTimeout(timeoutMilliseconds); - this.timeoutMilliseconds = timeoutMilliseconds; + @Override public void setSoTimeout(int readTimeoutMilliseconds) throws SocketException { + super.setSoTimeout(readTimeoutMilliseconds); + this.readTimeoutMilliseconds = readTimeoutMilliseconds; } @Override public int getSoTimeout() throws SocketException { - return timeoutMilliseconds; + return readTimeoutMilliseconds; + } + + /** + * Note write timeouts are not part of the javax.net.ssl.SSLSocket API + */ + public void setSoWriteTimeout(int writeTimeoutMilliseconds) throws SocketException { + this.writeTimeoutMilliseconds = writeTimeoutMilliseconds; + + StructTimeval tv = StructTimeval.fromMillis(writeTimeoutMilliseconds); + try { + Libcore.os.setsockoptTimeval(getFileDescriptor$(), SOL_SOCKET, SO_SNDTIMEO, tv); + } catch (ErrnoException errnoException) { + throw errnoException.rethrowAsSocketException(); + } + } + + /** + * Note write timeouts are not part of the javax.net.ssl.SSLSocket API + */ + public int getSoWriteTimeout() throws SocketException { + return writeTimeoutMilliseconds; } /** * Set the handshake timeout on this socket. This timeout is specified in * milliseconds and will be used only during the handshake process. */ - public void setHandshakeTimeout(int timeoutMilliseconds) throws SocketException { - this.handshakeTimeoutMilliseconds = timeoutMilliseconds; + public void setHandshakeTimeout(int handshakeTimeoutMilliseconds) throws SocketException { + this.handshakeTimeoutMilliseconds = handshakeTimeoutMilliseconds; } @Override public void close() throws IOException { diff --git a/luni/src/main/native/org_apache_harmony_xnet_provider_jsse_NativeCrypto.cpp b/luni/src/main/native/org_apache_harmony_xnet_provider_jsse_NativeCrypto.cpp index 15bfab8fb0..83f28d2c12 100644 --- a/luni/src/main/native/org_apache_harmony_xnet_provider_jsse_NativeCrypto.cpp +++ b/luni/src/main/native/org_apache_harmony_xnet_provider_jsse_NativeCrypto.cpp @@ -3755,7 +3755,7 @@ static jobjectArray NativeCrypto_SSL_get_peer_cert_chain(JNIEnv* env, jclass, ji * cleanly shut down, or THROW_SSLEXCEPTION if an exception should be thrown. */ static int sslRead(JNIEnv* env, SSL* ssl, jobject fdObject, jobject shc, char* buf, jint len, - int* sslReturnCode, int* sslErrorCode, int timeout_millis) { + int* sslReturnCode, int* sslErrorCode, int read_timeout_millis) { JNI_TRACE("ssl=%p sslRead buf=%p len=%d", ssl, buf, len); if (len == 0) { @@ -3834,7 +3834,7 @@ static int sslRead(JNIEnv* env, SSL* ssl, jobject fdObject, jobject shc, char* b // Need to wait for availability of underlying layer, then retry. case SSL_ERROR_WANT_READ: case SSL_ERROR_WANT_WRITE: { - int selectResult = sslSelect(env, sslError, fdObject, appData, timeout_millis); + int selectResult = sslSelect(env, sslError, fdObject, appData, read_timeout_millis); if (selectResult == THROWN_EXCEPTION) { return THROWN_EXCEPTION; } @@ -3886,11 +3886,11 @@ static int sslRead(JNIEnv* env, SSL* ssl, jobject fdObject, jobject shc, char* b */ static jint NativeCrypto_SSL_read(JNIEnv* env, jclass, jint ssl_address, jobject fdObject, jobject shc, jbyteArray b, jint offset, jint len, - jint timeout_millis) + jint read_timeout_millis) { SSL* ssl = to_SSL(env, ssl_address, true); - JNI_TRACE("ssl=%p NativeCrypto_SSL_read fd=%p shc=%p b=%p offset=%d len=%d timeout_millis=%d", - ssl, fdObject, shc, b, offset, len, timeout_millis); + JNI_TRACE("ssl=%p NativeCrypto_SSL_read fd=%p shc=%p b=%p offset=%d len=%d read_timeout_millis=%d", + ssl, fdObject, shc, b, offset, len, read_timeout_millis); if (ssl == NULL) { return 0; } @@ -3914,7 +3914,7 @@ static jint NativeCrypto_SSL_read(JNIEnv* env, jclass, jint ssl_address, jobject int sslErrorCode = SSL_ERROR_NONE;; int ret = sslRead(env, ssl, fdObject, shc, reinterpret_cast<char*>(bytes.get() + offset), len, - &returnCode, &sslErrorCode, timeout_millis); + &returnCode, &sslErrorCode, read_timeout_millis); int result; switch (ret) { @@ -3954,7 +3954,7 @@ static jint NativeCrypto_SSL_read(JNIEnv* env, jclass, jint ssl_address, jobject * cleanly shut down, or THROW_SSLEXCEPTION if an exception should be thrown. */ static int sslWrite(JNIEnv* env, SSL* ssl, jobject fdObject, jobject shc, const char* buf, jint len, - int* sslReturnCode, int* sslErrorCode) { + int* sslReturnCode, int* sslErrorCode, int write_timeout_millis) { JNI_TRACE("ssl=%p sslWrite buf=%p len=%d", ssl, buf, len); if (len == 0) { @@ -4040,7 +4040,7 @@ static int sslWrite(JNIEnv* env, SSL* ssl, jobject fdObject, jobject shc, const // it's also not standard Java behavior, so we wait forever here. case SSL_ERROR_WANT_READ: case SSL_ERROR_WANT_WRITE: { - int selectResult = sslSelect(env, sslError, fdObject, appData, 0); + int selectResult = sslSelect(env, sslError, fdObject, appData, write_timeout_millis); if (selectResult == THROWN_EXCEPTION) { return THROWN_EXCEPTION; } @@ -4091,7 +4091,7 @@ static int sslWrite(JNIEnv* env, SSL* ssl, jobject fdObject, jobject shc, const * OpenSSL write function (2): write into buffer at offset n chunks. */ static void NativeCrypto_SSL_write(JNIEnv* env, jclass, jint ssl_address, jobject fdObject, - jobject shc, jbyteArray b, jint offset, jint len) + jobject shc, jbyteArray b, jint offset, jint len, jint write_timeout_millis) { SSL* ssl = to_SSL(env, ssl_address, true); JNI_TRACE("ssl=%p NativeCrypto_SSL_write fd=%p shc=%p b=%p offset=%d len=%d", @@ -4118,7 +4118,7 @@ static void NativeCrypto_SSL_write(JNIEnv* env, jclass, jint ssl_address, jobjec int returnCode = 0; int sslErrorCode = SSL_ERROR_NONE; int ret = sslWrite(env, ssl, fdObject, shc, reinterpret_cast<const char*>(bytes.get() + offset), - len, &returnCode, &sslErrorCode); + len, &returnCode, &sslErrorCode, write_timeout_millis); switch (ret) { case THROW_SSLEXCEPTION: @@ -4487,7 +4487,7 @@ static JNINativeMethod sNativeCryptoMethods[] = { NATIVE_METHOD(NativeCrypto, SSL_get_certificate, "(I)[[B"), NATIVE_METHOD(NativeCrypto, SSL_get_peer_cert_chain, "(I)[[B"), NATIVE_METHOD(NativeCrypto, SSL_read, "(I" FILE_DESCRIPTOR SSL_CALLBACKS "[BIII)I"), - NATIVE_METHOD(NativeCrypto, SSL_write, "(I" FILE_DESCRIPTOR SSL_CALLBACKS "[BII)V"), + NATIVE_METHOD(NativeCrypto, SSL_write, "(I" FILE_DESCRIPTOR SSL_CALLBACKS "[BIII)V"), NATIVE_METHOD(NativeCrypto, SSL_interrupt, "(I)V"), NATIVE_METHOD(NativeCrypto, SSL_shutdown, "(I" FILE_DESCRIPTOR SSL_CALLBACKS ")V"), NATIVE_METHOD(NativeCrypto, SSL_free, "(I)V"), diff --git a/luni/src/test/java/libcore/javax/net/ssl/SSLSocketTest.java b/luni/src/test/java/libcore/javax/net/ssl/SSLSocketTest.java index 8c9239eb05..4095081fa0 100644 --- a/luni/src/test/java/libcore/javax/net/ssl/SSLSocketTest.java +++ b/luni/src/test/java/libcore/javax/net/ssl/SSLSocketTest.java @@ -19,6 +19,7 @@ package libcore.javax.net.ssl; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; +import java.lang.reflect.Method; import java.net.ServerSocket; import java.net.Socket; import java.net.SocketException; @@ -1012,7 +1013,7 @@ public class SSLSocketTest extends TestCase { assertEquals(0, wrapping.getSoTimeout()); // setting wrapper sets underlying and ... - int expectedTimeoutMillis = 1000; // Using a small value such as 10 was affected by rounding + int expectedTimeoutMillis = 1000; // 10 was too small because it was affected by rounding wrapping.setSoTimeout(expectedTimeoutMillis); assertEquals(expectedTimeoutMillis, wrapping.getSoTimeout()); assertEquals(expectedTimeoutMillis, underlying.getSoTimeout()); @@ -1050,6 +1051,52 @@ public class SSLSocketTest extends TestCase { listening.close(); } + public void test_SSLSocket_setSoWriteTimeout() throws Exception { + if (StandardNames.IS_RI) { + // RI does not support write timeout on sockets + return; + } + + final TestSSLContext c = TestSSLContext.create(); + SSLSocket client = (SSLSocket) c.clientContext.getSocketFactory().createSocket(c.host, + c.port); + final SSLSocket server = (SSLSocket) c.serverSocket.accept(); + ExecutorService executor = Executors.newSingleThreadExecutor(); + Future<Void> future = executor.submit(new Callable<Void>() { + @Override public Void call() throws Exception { + server.startHandshake(); + return null; + } + }); + executor.shutdown(); + client.startHandshake(); + + // Reflection is used so this can compile on the RI + String expectedClassName = "org.apache.harmony.xnet.provider.jsse.OpenSSLSocketImpl"; + Class actualClass = client.getClass(); + assertEquals(expectedClassName, actualClass.getName()); + Method setSoWriteTimeout = actualClass.getMethod("setSoWriteTimeout", + new Class[] { Integer.TYPE }); + setSoWriteTimeout.invoke(client, 1); + + // Try to make the size smaller (it can be 512k or even megabytes). + // Note that it may not respect your request, so read back the actual value. + int sendBufferSize = 1024; + client.setSendBufferSize(sendBufferSize); + sendBufferSize = client.getSendBufferSize(); + + try { + client.getOutputStream().write(new byte[sendBufferSize + 1]); + fail(); + } catch (SocketTimeoutException expected) { + } + + future.get(); + client.close(); + server.close(); + c.close(); + } + public void test_SSLSocket_interrupt() throws Exception { ServerSocket listening = new ServerSocket(0); |