package com.educoder.bridge.service; import com.alibaba.fastjson.JSONObject; import com.educoder.bridge.model.SSHInfo; import com.educoder.bridge.model.SSHSession; import com.educoder.bridge.utils.Base64Util; import com.jcraft.jsch.ChannelShell; import com.jcraft.jsch.JSch; import com.jcraft.jsch.Session; import com.jcraft.jsch.UserInfo; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.stereotype.Service; import org.springframework.web.socket.TextMessage; import org.springframework.web.socket.WebSocketSession; import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.util.Arrays; import java.util.List; import java.util.Optional; import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; @Service public class JchService { private static List sshSessionQueue = new CopyOnWriteArrayList<>(); private ExecutorService executorService = Executors.newCachedThreadPool(); private Logger logger = LoggerFactory.getLogger(getClass()); com.jcraft.jsch.Logger jschLogger = new com.jcraft.jsch.Logger() { @Override public boolean isEnabled(int arg0) { return true; } @Override public void log(int arg0, String arg1) { if (logger.isTraceEnabled()) { logger.trace("JSch Log [Level " + arg0 + "]: " + arg1); } } }; /** * 在webSocket连接时,初始化一个ssh连接 * * @param webSocketSession webSocket连接 */ public void add(WebSocketSession webSocketSession) { SSHSession sshSession = new SSHSession(); sshSession.setWebSocketSession(webSocketSession); sshSessionQueue.add(sshSession); } /** * 处理客户端发过来的数据 * @param buffer 数据 * @param webSocketSession webSocket连接 */ public void recv(String buffer, WebSocketSession webSocketSession) { SSHSession sshSession = null; try { logger.debug("webSocketSessionID: {}, 信息: {}", webSocketSession.getId(), buffer); JSONObject info = JSONObject.parseObject(buffer); String tp = info.getString("tp"); sshSession = findByWebSocketSession(webSocketSession); //初始化连接 if ("init".equals(tp)) { // {"tp":"init","data":{"host":"127.0.0.1","port":"41080","username":"root","password":"123123"}} SSHInfo sshInfo = info.getObject("data", SSHInfo.class); sshSession.setSSHInfo(sshInfo); if (sshSession != null) { SSHSession finalSSHSession = sshSession; // 新开一个线程建立连接,连接开启之后以一直监听来自客户端的输入 executorService.execute(() -> { connectTossh(finalSSHSession); }); } } else if ("client".equals(tp)) { String data = info.getString("data"); // 将网页输入的数据传送给后端服务器 if (sshSession != null) { transTossh(sshSession.getOutputStream(), data); } } } catch (Exception e) { logger.error("转发命令到ssh出错: {}", e); close(sshSession); } } /** * 将数据传送给服务端作为SSH的输入 * * @param outputStream * @param data * @throws IOException */ private void transTossh(OutputStream outputStream, String data) throws IOException { if (outputStream != null) { outputStream.write(data.getBytes()); outputStream.flush(); } } /** * 连接ssh * * @param sshSession ssh连接需要的信息 */ private void connectTossh(SSHSession sshSession){ Session jschSession = null; SSHInfo SSHInfo = sshSession.getSSHInfo(); try { JSch jsch = new JSch(); JSch.setLogger(jschLogger); //启动线程 java.util.Properties config = new java.util.Properties(); config.put("StrictHostKeyChecking", "no"); jschSession = jsch.getSession(SSHInfo.getUsername(), SSHInfo.getHost(), SSHInfo.getPort()); jschSession.setConfig(config); jschSession.setPassword(SSHInfo.getPassword()); jschSession.setUserInfo(new UserInfo() { @Override public String getPassphrase() { return null; } @Override public String getPassword() { return null; } @Override public boolean promptPassword(String s) { return false; } @Override public boolean promptPassphrase(String s) { return false; } @Override public boolean promptYesNo(String s) { return true; } // Accept all server keys @Override public void showMessage(String s) { } }); jschSession.connect(); ChannelShell channel = (ChannelShell) jschSession.openChannel("shell"); channel.setPtyType("xterm"); channel.connect(); sshSession.setChannel(channel); InputStream inputStream = channel.getInputStream(); sshSession.setOutputStream(channel.getOutputStream()); sshSession.setSSHInfo(SSHInfo); logger.debug("主机: {} 连接成功!", SSHInfo.getHost()); // 循环读取,jsch的输入为服务器执行命令之后的返回数据 byte[] buf = new byte[1024]; while (true) { int length = inputStream.read(buf); if (length < 0) { close(sshSession); throw new Exception("读取出错,数据长度:" + length); } sendMsg(sshSession.getWebSocketSession(), Arrays.copyOfRange(buf, 0, length)); } } catch (Exception e) { logger.error("ssh连接出错, e: {}", e); } finally { logger.info("连接关闭, {}", SSHInfo.getHost()); if (jschSession != null) { jschSession.disconnect(); } close(sshSession); } } /** * 发送数据回websocket * * @param webSocketSession webSocket连接 * @param buffer 数据 * @throws IOException */ public void sendMsg(WebSocketSession webSocketSession, byte[] buffer) throws IOException { logger.debug("服务端返回的数据: {}", new String(buffer, "UTF-8")); webSocketSession.sendMessage(new TextMessage(Base64Util.encodeBytes(buffer))); } /** * 通过webSocket连接在队列中找到对应的SSH连接 * * @param webSocketSession webSocket连接 */ public SSHSession findByWebSocketSession(WebSocketSession webSocketSession) { Optional optional = sshSessionQueue.stream().filter(webscoketObj -> webscoketObj.getWebSocketSession() == webSocketSession).findFirst(); if (optional.isPresent()) { return optional.get(); } return null; } /** * 关闭ssh和websocket连接 * * @param sshSession ssh连接 */ private void close(SSHSession sshSession) { if (sshSession != null) { sshSession.getChannel().disconnect(); try { sshSession.getWebSocketSession().close(); sshSession.getOutputStream().close(); } catch (IOException e) { logger.error("连接关闭失败!e: {}", e); } sshSessionQueue.remove(sshSession); } } /** * 通过webSocketSession关闭ssh与webSocket连接 * * @param webSocketSession */ public void closeByWebSocket(WebSocketSession webSocketSession) { close(findByWebSocketSession(webSocketSession)); } }