基于Spring Boot 2.x扩展WebSocket,支持细粒度控制
支持消息拦截的WebSocket,主要用于动态权限控制
项目地址
项目地址,欢迎提出宝贵意见。
spring security文档
Spring Security只支持在应用启动时加载WebSocket权限信息,修改权限后必须重启应用才能生效,无法按照用户的权限动态授权
Spring提供的HandshakeInterceptor接口可以自定义拦截器,但只能在握手时进行一次拦截,无法细粒度地控制权限
StompSubProtocolHandler源码
package org.springframework.web.socket.messaging;
/**
* Excerpt of Spring's StompSubProtocolHandler showing only its two message-handling methods.
* <p>Fields and helpers used below (decoders, stats, logger, eventPublisher,
* stompAuthentications, getUser, handleError, publishEvent, getStompHeaderAccessor,
* toMutableAccessor, afterStompSessionConnected, getErrorHandler, sendToClient,
* detectImmutableMessageInterceptor) are declared in the full class and omitted here.
*/
public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationEventPublisherAware {
/**
* Handle incoming WebSocket messages from clients.
* <p>Decodes the raw WebSocket payload into STOMP frames, copies session context
* (id, attributes, user, heart-beat) onto each frame's headers, sends each frame to
* {@code outputChannel} and, if the send succeeds, publishes the matching session event.
* <p>NOTE(review): inside this method the only hand-off point is
* {@code outputChannel.send(message)}; there is no separate hook here for custom
* per-session interceptors.
* @param session the WebSocket session the message arrived on
* @param webSocketMessage the raw message; only TextMessage and BinaryMessage are processed
* @param outputChannel the channel the decoded STOMP messages are sent to
*/
public void handleMessageFromClient(WebSocketSession session,
WebSocketMessage<?> webSocketMessage, MessageChannel outputChannel) {
List<Message<byte[]>> messages;
try {
// Extract the payload bytes; any other WebSocket message type is silently ignored.
ByteBuffer byteBuffer;
if (webSocketMessage instanceof TextMessage) {
byteBuffer = ByteBuffer.wrap(((TextMessage) webSocketMessage).asBytes());
}
else if (webSocketMessage instanceof BinaryMessage) {
byteBuffer = ((BinaryMessage) webSocketMessage).getPayload();
}
else {
return;
}
// Each session has its own decoder, which buffers partial frames between messages.
BufferingStompDecoder decoder = this.decoders.get(session.getId());
if (decoder == null) {
throw new IllegalStateException("No decoder for session id '" + session.getId() + "'");
}
messages = decoder.decode(byteBuffer);
// No complete STOMP frame yet: wait for more data.
if (messages.isEmpty()) {
if (logger.isTraceEnabled()) {
logger.trace("Incomplete STOMP frame content received in session " +
session + ", bufferSize=" + decoder.getBufferSize() +
", bufferSizeLimit=" + decoder.getBufferSizeLimit() + ".");
}
return;
}
}
catch (Throwable ex) {
// Decoding failures are reported to the client as a STOMP ERROR frame (no message available).
if (logger.isErrorEnabled()) {
logger.error("Failed to parse " + webSocketMessage +
" in session " + session.getId() + ". Sending STOMP ERROR to client.", ex);
}
handleError(session, ex, null);
return;
}
for (Message<byte[]> message : messages) {
try {
// Copy session-level context onto the message headers.
StompHeaderAccessor headerAccessor =
MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
Assert.state(headerAccessor != null, "No StompHeaderAccessor");
headerAccessor.setSessionId(session.getId());
headerAccessor.setSessionAttributes(session.getAttributes());
headerAccessor.setUser(getUser(session));
headerAccessor.setHeader(SimpMessageHeaderAccessor.HEART_BEAT_HEADER, headerAccessor.getHeartbeat());
// Headers are locked here unless detectImmutableMessageInterceptor(outputChannel) returns true
// (presumably an interceptor on the channel makes them immutable later - confirm).
if (!detectImmutableMessageInterceptor(outputChannel)) {
headerAccessor.setImmutable();
}
if (logger.isTraceEnabled()) {
logger.trace("From client: " + headerAccessor.getShortLogMessage(message.getPayload()));
}
// CONNECT and STOMP both count as connection attempts; DISCONNECT is counted separately.
StompCommand command = headerAccessor.getCommand();
boolean isConnect = StompCommand.CONNECT.equals(command) || StompCommand.STOMP.equals(command);
if (isConnect) {
this.stats.incrementConnectCount();
}
else if (StompCommand.DISCONNECT.equals(command)) {
this.stats.incrementDisconnectCount();
}
try {
// Expose the session attributes to code running on this thread during send; reset in finally.
SimpAttributesContextHolder.setAttributesFromMessage(message);
boolean sent = outputChannel.send(message);
// Follow-up work happens only when the channel accepted the message.
if (sent) {
if (isConnect) {
// Remember a user set during CONNECT if it differs from the session's own principal.
Principal user = headerAccessor.getUser();
if (user != null && user != session.getPrincipal()) {
this.stompAuthentications.put(session.getId(), user);
}
}
if (this.eventPublisher != null) {
Principal user = getUser(session);
if (isConnect) {
publishEvent(this.eventPublisher, new SessionConnectEvent(this, message, user));
}
else if (StompCommand.SUBSCRIBE.equals(command)) {
publishEvent(this.eventPublisher, new SessionSubscribeEvent(this, message, user));
}
else if (StompCommand.UNSUBSCRIBE.equals(command)) {
publishEvent(this.eventPublisher, new SessionUnsubscribeEvent(this, message, user));
}
}
}
}
finally {
SimpAttributesContextHolder.resetAttributes();
}
}
catch (Throwable ex) {
// Any failure while handing a frame off is reported to the client as a STOMP ERROR frame.
if (logger.isErrorEnabled()) {
logger.error("Failed to send client message to application via MessageChannel" +
" in session " + session.getId() + ". Sending STOMP ERROR to client.", ex);
}
handleError(session, ex, message);
}
}
}
/**
* Handle STOMP messages going back out to WebSocket clients.
* <p>Ignores messages whose payload is not {@code byte[]}, post-processes MESSAGE,
* CONNECTED and ERROR frames, and then writes the frame to the session via sendToClient.
* @param session the WebSocket session to write to
* @param message the outbound STOMP message; its payload is expected to be {@code byte[]}
*/
@Override
@SuppressWarnings("unchecked")
public void handleMessageToClient(WebSocketSession session, Message<?> message) {
if (!(message.getPayload() instanceof byte[])) {
if (logger.isErrorEnabled()) {
logger.error("Expected byte[] payload. Ignoring " + message + ".");
}
return;
}
StompHeaderAccessor accessor = getStompHeaderAccessor(message);
StompCommand command = accessor.getCommand();
// MESSAGE: warn when there is no subscription header, and restore the original destination if one was recorded.
if (StompCommand.MESSAGE.equals(command)) {
if (accessor.getSubscriptionId() == null && logger.isWarnEnabled()) {
logger.warn("No STOMP \"subscription\" header in " + message);
}
String origDestination = accessor.getFirstNativeHeader(SimpMessageHeaderAccessor.ORIGINAL_DESTINATION);
if (origDestination != null) {
accessor = toMutableAccessor(accessor, message);
accessor.removeNativeHeader(SimpMessageHeaderAccessor.ORIGINAL_DESTINATION);
accessor.setDestination(origDestination);
}
}
// CONNECTED: update stats, let afterStompSessionConnected adjust the headers, then publish SessionConnectedEvent.
else if (StompCommand.CONNECTED.equals(command)) {
this.stats.incrementConnectedCount();
accessor = afterStompSessionConnected(message, accessor, session);
if (this.eventPublisher != null) {
try {
// Make the session attributes available to event listeners on this thread; reset in finally.
SimpAttributes simpAttributes = new SimpAttributes(session.getId(), session.getAttributes());
SimpAttributesContextHolder.setAttributes(simpAttributes);
Principal user = getUser(session);
publishEvent(this.eventPublisher, new SessionConnectedEvent(this, (Message<byte[]>) message, user));
}
finally {
SimpAttributesContextHolder.resetAttributes();
}
}
}
byte[] payload = (byte[]) message.getPayload();
// ERROR: an installed error handler may replace the outgoing error frame (headers and payload).
if (StompCommand.ERROR.equals(command) && getErrorHandler() != null) {
Message<byte[]> errorMessage = getErrorHandler().handleErrorMessageToClient((Message<byte[]>) message);
if (errorMessage != null) {
accessor = MessageHeaderAccessor.getAccessor(errorMessage, StompHeaderAccessor.class);
Assert.state(accessor != null, "No StompHeaderAccessor");
payload = errorMessage.getPayload();
}
}
sendToClient(session, accessor, payload);
}
}
阅读源码可以发现,Spring原生的消息处理类不支持自定义拦截器
interceptable-websocket
做这个项目是为了细粒度地动态控制WebSocket的权限。项目对StompSubProtocolHandler类和其他相关的类做了扩展,增加了对自定义拦截器的支持
具体实现在extension包,拦截器实现在interceptor包
使用方法
直接使用
项目已经发布到maven中央仓库,直接在pom.xml中引用即可
<!-- Dependency on the interceptable-websocket artifact published to Maven Central. -->
<dependencies>
<dependency>
<groupId>com.xzixi</groupId>
<artifactId>interceptable-websocket</artifactId>
<version>1.0</version>
</dependency>
</dependencies>
修改后使用
- 下载项目
打开git bash窗口,执行命令git clone git@gitee.com:xuelingkang/websocket.git
- 编译并安装到本地maven仓库
进入工程目录,打开cmd窗口,执行命令mvn clean install
- 在自己的项目中引用
<!-- Dependency on a locally built interceptable-websocket: after running mvn clean install,
     Maven resolves these coordinates from the local repository. Keep the version in sync
     with the version declared in the cloned project's pom.xml if you change it. -->
<dependencies>
<dependency>
<groupId>com.xzixi</groupId>
<artifactId>interceptable-websocket</artifactId>
<version>1.0</version>
</dependency>
</dependencies>
配置类:
package com.xzixi.websocket.interceptablewebsocketdemo.config;
/**
* Example configuration for the interceptable WebSocket extension.
* <p>Uses {@code @EnableInterceptableWebSocketMessageBroker} and extends
* AbstractInterceptableSecurityWebSocketMessageBrokerConfigurer, then registers
* custom interceptors on the InterceptableWebMvcStompEndpointRegistry.
*/
@Configuration
@EnableInterceptableWebSocketMessageBroker // the added annotation
public class InterceptableSecurityWebSocketConfig extends AbstractInterceptableSecurityWebSocketMessageBrokerConfigurer {
@Override
public void registerStompEndpoints(InterceptableWebMvcStompEndpointRegistry registry) {
// Register the interceptors: addFromClientInterceptor for messages received from the client,
// addToClientInterceptor for messages sent to the client.
registry.addFromClientInterceptor(accessDecisionFromClientInterceptor()) // message authorization decision
.addFromClientInterceptor(sessionIdUnRegistryInterceptor()) // sessionId removal (NOTE(review): inferred from the UnRegistry name; confirm)
.addToClientInterceptor(sessionIdRegistryInterceptor()); // sessionId registration (NOTE(review): inferred from the Registry name; confirm)
}
}
具体使用方法请参考案例工程:interceptable-websocket-demo