import csv
import json
import os
import threading
import time
from datetime import datetime

import paramiko
from flask import Flask, jsonify
from flask_cors import CORS

#region Globals
app = Flask(__name__)
CORS(app)
port = 15002
server_list_path = 'serverList.json'
# Guards reads/writes of the per-server entries in ``data_dict``; they are
# written by the SSH polling threads and read by Flask request handlers.
data_list_lock = threading.Lock()
# Seconds between two consecutive nvidia-smi polls of one server.
check_interval = 2
# Shared state: server title -> per-server entry dict.  Entries hold
# 'server_data' (the config) and, once polling runs, 'err_info',
# 'info_list', 'updated' and 'maxGPU'.
data_dict = dict()
#endregion

#region Endpoints
# Test endpoint.
@app.route('/')
def hello():
    return 'hi. —— CheckGPUsWeb'


@app.route('/all_data', methods=['GET'])
def get_data():
    """Return the latest data of every configured server as JSON."""
    return jsonify(get_all_data())


# Start connecting to the servers (placeholder, not implemented).
def connect_server():
    pass
#endregion

# nvidia-smi query; without header and units every line is one GPU:
# index, name, memory.total, memory.used, memory.free,
# utilization.gpu, utilization.memory, temperature.gpu
_GPU_QUERY_CMD = ('nvidia-smi --query-gpu=index,name,memory.total,memory.used,memory.free,'
                  'utilization.gpu,utilization.memory,temperature.gpu '
                  '--format=csv,noheader,nounits')


def _simplify_gpu_name(gpu_name: str) -> str:
    """Drop a leading 'NVIDIA ' and then a leading 'GeForce ' from a GPU name."""
    if gpu_name.startswith('NVIDIA '):
        gpu_name = gpu_name[len('NVIDIA '):]
    if gpu_name.startswith('GeForce '):
        gpu_name = gpu_name[len('GeForce '):]
    return gpu_name


def _parse_gpu_csv(output: str) -> list:
    """Parse ``nvidia-smi --format=csv,noheader,nounits`` output into GPU dicts.

    Blank lines (including a trailing newline or ``\\r\\n`` endings) are
    ignored.  A malformed row raises ValueError/IndexError, which the caller
    treats as a failed poll.
    """
    lines = [line for line in output.splitlines() if line.strip()]
    result = []
    for row in csv.reader(lines, skipinitialspace=True):
        result.append({
            'idx': int(row[0]),
            'gpu_name': _simplify_gpu_name(row[1].strip()),
            'total_mem': int(row[2]),
            'used_mem': int(row[3]),
            'free_mem': int(row[4]),
            'util_gpu': int(row[5]),
            'util_mem': int(row[6]),
            'temperature': int(row[7]),
        })
    return result


def _query_gpus(client, timeout: float) -> list:
    """Run the nvidia-smi query over the connected ``client`` and parse it.

    Raises RuntimeError when the remote command exits non-zero (e.g.
    nvidia-smi is not installed), so that case is reported as an error
    instead of looking like a server with zero GPUs.
    """
    _stdin, stdout, stderr = client.exec_command(_GPU_QUERY_CMD, timeout=timeout)
    output = stdout.read().decode(errors='replace')
    exit_status = stdout.channel.recv_exit_status()
    if exit_status != 0:
        err_text = stderr.read().decode(errors='replace').strip()
        raise RuntimeError(f'nvidia-smi exited with status {exit_status}: {err_text}')
    return _parse_gpu_csv(output)


def _poll_gpus(client, shared_data_list: dict, server_title: str, interval: float):
    """Poll GPUs every ``interval`` seconds over ``client`` until a poll fails.

    Each successful poll replaces 'info_list'/'updated'/'maxGPU' of the
    server's entry.  On the first failure the error text is stored in
    'err_info', the stale 'info_list' is removed, and the function returns
    (after one more ``interval`` sleep) so the caller can reconnect.
    """
    while True:
        try:
            result = _query_gpus(client, timeout=interval * 3)
        except Exception as e:  # any failure ends this SSH session
            with data_list_lock:
                entry = shared_data_list[server_title]
                entry['err_info'] = f'{e}'
                entry.pop('info_list', None)
            time.sleep(interval)
            return
        with data_list_lock:
            entry = shared_data_list[server_title]
            entry['info_list'] = result
            entry['updated'] = True
            entry['maxGPU'] = len(result)
        time.sleep(interval)


def keep_check_one(server: dict, shared_data_list: dict, server_title: str, interval: float, re_connect_time: float=5):
    """Poll one server's GPUs over SSH forever, publishing into ``shared_data_list``.

    Intended as a thread target (see ``start_connect``).  Each outer
    iteration opens a new SSH connection and polls until a query fails; a
    failed connection attempt is recorded in 'err_info' as
    ``'retry:<n>, <error>'`` and retried after ``re_connect_time`` seconds.

    Args:
        server: connection settings; reads 'ip', 'port', 'username' and the
            optional 'password' and 'key_filename'.
        shared_data_list: shared state; ``shared_data_list[server_title]``
            must already exist.
        server_title: key of this server's entry.
        interval: seconds between polls; SSH and command timeouts are
            ``interval * 3``.
        re_connect_time: seconds to wait after a failed connection attempt.
    """
    re_try_count = 0
    while True:
        client = paramiko.SSHClient()
        try:
            client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
            client.connect(server['ip'], port=server['port'], username=server['username'],
                           password=server.get('password', None),
                           key_filename=server.get('key_filename', None),
                           timeout=interval * 3)
            with data_list_lock:
                shared_data_list[server_title]['err_info'] = ''
            re_try_count = 0
            _poll_gpus(client, shared_data_list, server_title, interval)
            # The session ended after a failed poll (already recorded and
            # slept); reconnect right away.  ``finally`` still closes the client.
            continue
        except Exception as e:
            with data_list_lock:
                shared_data_list[server_title]['err_info'] = f'retry:{re_try_count}, {e}'
        finally:
            # Always release the transport, including after a failed connect().
            client.close()
        time.sleep(re_connect_time)
        re_try_count += 1


# Get the data of all servers.
def get_all_data():
    return filter_data(list(data_dict))


def filter_data(title_list: list):
    """Build the API payload for the servers whose titles are in ``title_list``.

    Returns ``{'time': 'YYYY-mm-dd HH:MM:SS', 'server_data': {title: ...}}``.
    Each per-title dict holds 'info_list', 'updated' and 'err_info' when
    data is available.  Otherwise it holds only 'err_info': an unknown-title
    message, the last recorded error, or a "still empty" placeholder.
    """
    server_data = dict()
    # Snapshot under the lock so a polling thread cannot change an entry
    # halfway through being read.
    with data_list_lock:
        for title in title_list:
            entry = data_dict.get(title)
            if entry is None:
                server_data[title] = {'err_info': f'title \'{title}\' not exist!'}
                continue
            info_list = entry.get('info_list', None)
            if info_list is None:
                # No data yet: an empty err_info (set right after a successful
                # connect) is not an error, so fall back to the placeholder.
                server_data[title] = {'err_info': entry.get('err_info') or f'\'{title}\' still empty.'}
                continue
            server_data[title] = {
                'info_list': info_list,
                'updated': entry.get('updated', False),
                'err_info': entry.get('err_info', ''),
            }
    return {
        'time': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
        'server_data': server_data,
    }


def start_connect():
    """Load the server list and start one daemon polling thread per server.

    ``server_list_path`` must contain a JSON list of server dicts, each with
    a 'title' plus the connection keys read by ``keep_check_one``.
    """
    # utf-8-sig: JSON is UTF-8 by spec; also tolerate a BOM (e.g. Windows editors).
    with open(server_list_path, 'r', encoding='utf-8-sig') as f:
        server_list = json.load(f)
    for server_data in server_list:
        title = server_data['title']
        with data_list_lock:
            data_dict[title] = {'server_data': server_data}
        thread = threading.Thread(target=keep_check_one,
                                  args=(server_data, data_dict, title, check_interval),
                                  name=f'check-{title}',
                                  daemon=True)
        thread.start()
    print('start connect')


def test():
    """Start polling and run the Flask development server in debug mode."""
    # debug=True enables the Werkzeug reloader, which re-executes this script
    # in a child process with WERKZEUG_RUN_MAIN='true' and serves from there.
    # Start the SSH threads only in that process; otherwise the idle reloader
    # parent would poll every server a second time.
    if os.environ.get('WERKZEUG_RUN_MAIN') == 'true':
        start_connect()
    app.run(debug=True, host='127.0.0.1', port=port)


if __name__ == '__main__':
    test()