From 8b268438511b7d79ea3b6bd39d9ee04f45c0aa70 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=B1=BC=E9=AA=A8=E5=89=AA?= <1580622474@qq.com> Date: Thu, 8 Aug 2024 21:58:28 +0800 Subject: [PATCH] =?UTF-8?q?=E5=88=9D=E6=AD=A5=E5=AE=9E=E7=8E=B0=E8=8E=B7?= =?UTF-8?q?=E5=8F=96=E6=9C=8D=E5=8A=A1=E5=99=A8GPU=E6=95=B0=E6=8D=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 1 + app.py | 134 +++++++++++++++++++++++++++++++++++++++- index.html | 2 +- serverList_examlpe.json | 23 +++++++ 4 files changed, 156 insertions(+), 4 deletions(-) create mode 100644 .gitignore create mode 100644 serverList_examlpe.json diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..597871a --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +serverList.json \ No newline at end of file diff --git a/app.py b/app.py index 512042a..d4852ee 100644 --- a/app.py +++ b/app.py @@ -2,6 +2,7 @@ from flask import Flask, jsonify from flask_cors import CORS import threading import paramiko +import json import time #region 全局 @@ -9,6 +10,11 @@ import time app = Flask(__name__) CORS(app) port = 15002 +server_list_path = 'serverList.json' +data_list_lock = threading.Lock() +check_interval = 2 +# 共享list +data_dict = dict() #endregion @@ -21,8 +27,7 @@ def hello(): @app.route('/data', methods=['GET']) def get_data(): - data = {'name': 'John', 'age': 25, 'city': 'New York'} - return jsonify(data) + return jsonify(get_all_data()) # 开始连接服务器 def connect_server(): @@ -30,9 +35,132 @@ def connect_server(): #endregion +def keep_check_one(server: dict, shared_data_list: dict, server_title: str, interval: float, re_connect_time: float=5): + re_try_count = 0 + # 循环连接 + while True: + try: + # 建立SSH连接 + client = paramiko.SSHClient() + client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + client.connect(server['ip'], port=server['port'], username=server['username'], password=server.get('password', None), key_filename=server.get('key_filename', None), timeout=interval*3) + cmd = 'nvidia-smi --query-gpu=index,name,memory.total,memory.used,memory.free,utilization.gpu,utilization.memory,temperature.gpu --format=csv' + + shared_data_list[server_title]['err_info'] = '' + re_try_count = 0 + + # 循环检测 + keep_run = True + while keep_run: + try: + stdin, stdout, stderr = client.exec_command(cmd, timeout=interval*3) + output = stdout.read().decode() + output = output.split('\n') + start_idx = 0 + for i in range(len(output)): + if output[i] == 'index, name, memory.total [MiB], memory.used [MiB], memory.free [MiB], utilization.gpu [%], utilization.memory [%], temperature.gpu': + start_idx = i + 1 + break + output = output[start_idx:-1] + # 解析数据 ----------------------------- + result = [] + for data in output: + data_list = data.split(', ') + idx = int(data_list[0]) + gpu_name = data_list[1] + total_mem = int(data_list[2].split(' ')[0]) + used_mem = int(data_list[3].split(' ')[0]) + free_mem = int(data_list[4].split(' ')[0]) + util_gpu = int(data_list[5].split(' ')[0]) + util_mem = int(data_list[6].split(' ')[0]) + temperature = int(data_list[7]) + + # 简化GPU名称 + if gpu_name.startswith('NVIDIA '): + gpu_name = gpu_name[7:] + if gpu_name.startswith('GeForce '): + gpu_name = gpu_name[8:] + + result.append({ + 'idx': idx, + 'gpu_name': gpu_name, + 'total_mem': total_mem, + 'used_mem': used_mem, + 'free_mem': free_mem, + 'util_gpu': util_gpu, + 'util_mem': util_mem, + 'temperature': temperature + }) + + # locked = False + with data_list_lock: + # locked = True + shared_data_list[server_title]['info_list'] = result + shared_data_list[server_title]['updated'] = True + shared_data_list[server_title]['maxGPU'] = len(output) + # locked = False + except Exception as e: + keep_run = False + shared_data_list[server_title]['err_info'] = f'{e}' + if 'info_list' in shared_data_list[server_title]: + shared_data_list[server_title].pop('info_list') + + time.sleep(interval) + + # 关闭连接 + client.close() + except Exception as e: + shared_data_list[server_title]['err_info'] = f'retry:{re_try_count}, {e}' + time.sleep(re_connect_time) + re_try_count += 1 + +# 获取所有的服务器数据 +def get_all_data(): + return filter_data(list(data_dict.keys())) + +# 根据key过滤所需的服务器数据 +def filter_data(title_list: list): + result = dict() + for title in title_list: + result[title] = {} + # 不存在该title的数据 + if title not in data_dict: + result[title]['err_info'] = f'title \'{title}\' not exist!' + continue + # 还没获取到数据 + info_list = data_dict[title].get('info_list', None) + if info_list is None: + result[title]['err_info'] = f'\'{title}\' still empty.' + continue + + # 记录数据 + data_updated = data_dict[title].get('updated', False) + err_info = data_dict[title].get('err_info', '') + result[title]['info_list'] = info_list + result[title]['updated'] = data_updated + result[title]['err_info'] = err_info + return result + +def start_connect(): + # 加载json + with open(server_list_path, 'r') as f: + server_list = json.load(f) + + global data_dict + # 开启线程 + for i, server_data in enumerate(server_list): + data_dict[server_data['title']] = {} + data_dict[server_data['title']]['server_data'] = server_data + thread = threading.Thread(target=keep_check_one, args=(server_data, data_dict, server_data['title'], check_interval)) + thread.daemon = True + thread.start() + + print('start connect') + # 测试 def test(): - app.run(debug=True, port=port) + start_connect() + app.run(debug=True, host='127.0.0.1', port=port) if __name__ == '__main__': test() \ No newline at end of file diff --git a/index.html b/index.html index 3312985..7e20b05 100644 --- a/index.html +++ b/index.html @@ -12,7 +12,7 @@