diff --git a/src/main/java/io/vertx/core/http/WebSocketBase.java b/src/main/java/io/vertx/core/http/WebSocketBase.java index ce68983c9..a0245e1e9 100644 --- a/src/main/java/io/vertx/core/http/WebSocketBase.java +++ b/src/main/java/io/vertx/core/http/WebSocketBase.java @@ -15,6 +15,7 @@ import io.vertx.codegen.annotations.*; import io.vertx.core.AsyncResult; import io.vertx.core.Future; import io.vertx.core.Handler; +import io.vertx.core.MultiMap; import io.vertx.core.buffer.Buffer; import io.vertx.core.net.SocketAddress; import io.vertx.core.streams.ReadStream; @@ -99,6 +100,15 @@ public interface WebSocketBase extends ReadStream, WriteStream { */ String closeReason(); + /** + * Returns the HTTP response headers during the websocket connection handler. + *

+ * After the completion handler callback has completed the response headers will be {@code null} + * + * @return the response headers + */ + MultiMap headers(); + /** * Write a WebSocket frame to the connection * diff --git a/src/main/java/io/vertx/core/http/impl/Http1xClientConnection.java b/src/main/java/io/vertx/core/http/impl/Http1xClientConnection.java index 3ca9f0c06..6e05f5ab8 100644 --- a/src/main/java/io/vertx/core/http/impl/Http1xClientConnection.java +++ b/src/main/java/io/vertx/core/http/impl/Http1xClientConnection.java @@ -743,8 +743,10 @@ class Http1xClientConnection extends Http1xConnectionBase impleme if (metrics != null) { ws.setMetric(metrics.connected(endpointMetric, metric(), ws)); } + ws.headers(new HeadersAdaptor(ar.result().headers())); } wsHandler.handle(res); + ws.headers(null); }); }); p.addBefore("handler", "handshakeCompleter", handshakeInboundHandler); diff --git a/src/main/java/io/vertx/core/http/impl/WebSocketHandshakeInboundHandler.java b/src/main/java/io/vertx/core/http/impl/WebSocketHandshakeInboundHandler.java index f572065d3..88160efa1 100644 --- a/src/main/java/io/vertx/core/http/impl/WebSocketHandshakeInboundHandler.java +++ b/src/main/java/io/vertx/core/http/impl/WebSocketHandshakeInboundHandler.java @@ -33,12 +33,12 @@ import io.vertx.core.http.WebsocketRejectedException; */ class WebSocketHandshakeInboundHandler extends ChannelInboundHandlerAdapter { - private final Handler> wsHandler; + private final Handler> wsHandler; private final WebSocketClientHandshaker handshaker; private ChannelHandlerContext chctx; private FullHttpResponse response; - WebSocketHandshakeInboundHandler(WebSocketClientHandshaker handshaker, Handler> wsHandler) { + WebSocketHandshakeInboundHandler(WebSocketClientHandshaker handshaker, Handler> wsHandler) { this.handshaker = handshaker; this.wsHandler = wsHandler; } @@ -76,20 +76,20 @@ class WebSocketHandshakeInboundHandler extends ChannelInboundHandlerAdapter { // remove decompressor as its not needed anymore once connection was upgraded to websockets ctx.pipeline().remove(handler); } - Future fut = handshakeComplete(response); + Future fut = handshakeComplete(response); wsHandler.handle(fut); } } } } - private Future handshakeComplete(FullHttpResponse response) { + private Future handshakeComplete(FullHttpResponse response) { if (response.status().code() != 101) { return Future.failedFuture(new WebsocketRejectedException(response.status().code())); } else { try { handshaker.finishHandshake(chctx.channel(), response); - return Future.succeededFuture(); + return Future.succeededFuture(response); } catch (WebSocketHandshakeException e) { return Future.failedFuture(e); } diff --git a/src/main/java/io/vertx/core/http/impl/WebSocketImplBase.java b/src/main/java/io/vertx/core/http/impl/WebSocketImplBase.java index d0a9907cb..fe9d88f24 100644 --- a/src/main/java/io/vertx/core/http/impl/WebSocketImplBase.java +++ b/src/main/java/io/vertx/core/http/impl/WebSocketImplBase.java @@ -14,9 +14,10 @@ package io.vertx.core.http.impl; import io.netty.buffer.ByteBuf; import io.vertx.codegen.annotations.Nullable; import io.vertx.core.AsyncResult; +import io.vertx.core.Promise; import io.vertx.core.Future; import io.vertx.core.Handler; -import io.vertx.core.Promise; +import io.vertx.core.MultiMap; import io.vertx.core.buffer.Buffer; import io.vertx.core.eventbus.EventBus; import io.vertx.core.eventbus.Message; @@ -66,6 +67,7 @@ public abstract class WebSocketImplBase implements WebS protected boolean closed; private Short closeStatusCode; private String closeReason; + private MultiMap headers; WebSocketImplBase(ContextInternal context, Http1xConnectionBase conn, boolean supportsContinuation, int maxWebSocketFrameSize, int maxWebSocketMessageSize) { @@ -224,6 +226,19 @@ public abstract class WebSocketImplBase implements WebS } } + @Override + public MultiMap headers() { + synchronized(conn) { + return headers; + } + } + + void headers(MultiMap responseHeaders) { + synchronized(conn) { + this.headers = responseHeaders; + } + } + @Override public Future writeBinaryMessage(Buffer data) { Promise promise = Promise.promise(); diff --git a/src/test/java/io/vertx/core/http/WebSocketTest.java b/src/test/java/io/vertx/core/http/WebSocketTest.java index 6db94072e..3ea9e7267 100644 --- a/src/test/java/io/vertx/core/http/WebSocketTest.java +++ b/src/test/java/io/vertx/core/http/WebSocketTest.java @@ -16,15 +16,8 @@ import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame; import io.netty.handler.codec.http.websocketx.WebSocket13FrameDecoder; import io.netty.handler.codec.http.websocketx.WebSocket13FrameEncoder; import io.netty.handler.codec.http.websocketx.WebSocketHandshakeException; -import io.vertx.core.AbstractVerticle; -import io.vertx.core.AsyncResult; -import io.vertx.core.Context; -import io.vertx.core.DeploymentOptions; -import io.vertx.core.Future; -import io.vertx.core.Handler; +import io.vertx.core.*; import io.vertx.core.Promise; -import io.vertx.core.Vertx; -import io.vertx.core.VertxOptions; import io.vertx.core.buffer.Buffer; import io.vertx.core.http.impl.FrameType; import io.vertx.core.http.impl.ws.WebSocketFrameImpl; @@ -1769,6 +1762,34 @@ public class WebSocketTest extends VertxTestBase { await(); } + @Test + public void testReceiveHttpResponseHeadersOnClient() { + server = vertx.createHttpServer(new HttpServerOptions().setPort(DEFAULT_HTTP_PORT)).requestHandler(req -> { + handshakeWithCookie(req); + }); + AtomicReference websocketRef = new AtomicReference(); + server.listen(ar -> { + assertTrue(ar.succeeded()); + client.webSocket(DEFAULT_HTTP_PORT, DEFAULT_HTTP_HOST, "/some/path", onSuccess(ws -> { + MultiMap entries = ws.headers(); + assertNotNull(entries); + assertFalse(entries.isEmpty()); + assertEquals("websocket".toLowerCase(), entries.get("Upgrade").toLowerCase()); + assertEquals("upgrade".toLowerCase(), entries.get("Connection").toLowerCase()); + Set cookiesToSet = new HashSet(entries.getAll("Set-Cookie")); + assertEquals(2, cookiesToSet.size()); + assertTrue(cookiesToSet.contains("SERVERID=test-server-id")); + assertTrue(cookiesToSet.contains("JSONID=test-json-id")); + websocketRef.set(ws); + vertx.runOnContext(v -> { + assertNull(ws.headers()); + testComplete(); + }); + })); + }); + await(); + } + @Test public void testUpgrade() { testUpgrade(false); @@ -1913,6 +1934,31 @@ public class WebSocketTest extends VertxTestBase { testRaceConditionWithWebsocketClient(fut.get()); } + private NetSocket handshakeWithCookie(HttpServerRequest req) { + NetSocket so = req.netSocket(); + try { + MessageDigest digest = MessageDigest.getInstance("SHA-1"); + byte[] inputBytes = (req.getHeader("Sec-WebSocket-Key") + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11").getBytes(); + digest.update(inputBytes); + byte[] hashedBytes = digest.digest(); + byte[] accept = Base64.getEncoder().encode(hashedBytes); + Buffer data = Buffer.buffer(); + data.appendString("HTTP/1.1 101 Switching Protocols\r\n"); + data.appendString("Upgrade: websocket\r\n"); + data.appendString("Connection: upgrade\r\n"); + data.appendString("Sec-WebSocket-Accept: " + new String(accept) + "\r\n"); + data.appendString("Set-Cookie: SERVERID=test-server-id\r\n"); + data.appendString("Set-Cookie: JSONID=test-json-id\r\n"); + data.appendString("\r\n"); + so.write(data); + return so; + } catch (NoSuchAlgorithmException e) { + req.response().setStatusCode(500).end(); + fail(e.getMessage()); + return null; + } + } + private NetSocket handshake(HttpServerRequest req) { NetSocket so = req.netSocket(); try {