|
|
@ -2,6 +2,7 @@ from flask import Flask, jsonify |
|
|
|
from flask_cors import CORS |
|
|
|
import threading |
|
|
|
import paramiko |
|
|
|
import json |
|
|
|
import time |
|
|
|
|
|
|
|
#region 全局 |
|
|
@ -9,6 +10,11 @@ import time |
|
|
|
app = Flask(__name__) |
|
|
|
CORS(app) |
|
|
|
port = 15002 |
|
|
|
server_list_path = 'serverList.json' |
|
|
|
data_list_lock = threading.Lock() |
|
|
|
check_interval = 2 |
|
|
|
# 共享list |
|
|
|
data_dict = dict() |
|
|
|
|
|
|
|
#endregion |
|
|
|
|
|
|
@ -21,8 +27,7 @@ def hello(): |
|
|
|
|
|
|
|
@app.route('/data', methods=['GET']) |
|
|
|
def get_data(): |
|
|
|
data = {'name': 'John', 'age': 25, 'city': 'New York'} |
|
|
|
return jsonify(data) |
|
|
|
return jsonify(get_all_data()) |
|
|
|
|
|
|
|
# 开始连接服务器 |
|
|
|
def connect_server(): |
|
|
@ -30,9 +35,132 @@ def connect_server(): |
|
|
|
|
|
|
|
#endregion |
|
|
|
|
|
|
|
def keep_check_one(server: dict, shared_data_list: dict, server_title: str, interval: float, re_connect_time: float=5): |
|
|
|
re_try_count = 0 |
|
|
|
# 循环连接 |
|
|
|
while True: |
|
|
|
try: |
|
|
|
# 建立SSH连接 |
|
|
|
client = paramiko.SSHClient() |
|
|
|
client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) |
|
|
|
client.connect(server['ip'], port=server['port'], username=server['username'], password=server.get('password', None), key_filename=server.get('key_filename', None), timeout=interval*3) |
|
|
|
cmd = 'nvidia-smi --query-gpu=index,name,memory.total,memory.used,memory.free,utilization.gpu,utilization.memory,temperature.gpu --format=csv' |
|
|
|
|
|
|
|
shared_data_list[server_title]['err_info'] = '' |
|
|
|
re_try_count = 0 |
|
|
|
|
|
|
|
# 循环检测 |
|
|
|
keep_run = True |
|
|
|
while keep_run: |
|
|
|
try: |
|
|
|
stdin, stdout, stderr = client.exec_command(cmd, timeout=interval*3) |
|
|
|
output = stdout.read().decode() |
|
|
|
output = output.split('\n') |
|
|
|
start_idx = 0 |
|
|
|
for i in range(len(output)): |
|
|
|
if output[i] == 'index, name, memory.total [MiB], memory.used [MiB], memory.free [MiB], utilization.gpu [%], utilization.memory [%], temperature.gpu': |
|
|
|
start_idx = i + 1 |
|
|
|
break |
|
|
|
output = output[start_idx:-1] |
|
|
|
# 解析数据 ----------------------------- |
|
|
|
result = [] |
|
|
|
for data in output: |
|
|
|
data_list = data.split(', ') |
|
|
|
idx = int(data_list[0]) |
|
|
|
gpu_name = data_list[1] |
|
|
|
total_mem = int(data_list[2].split(' ')[0]) |
|
|
|
used_mem = int(data_list[3].split(' ')[0]) |
|
|
|
free_mem = int(data_list[4].split(' ')[0]) |
|
|
|
util_gpu = int(data_list[5].split(' ')[0]) |
|
|
|
util_mem = int(data_list[6].split(' ')[0]) |
|
|
|
temperature = int(data_list[7]) |
|
|
|
|
|
|
|
# 简化GPU名称 |
|
|
|
if gpu_name.startswith('NVIDIA '): |
|
|
|
gpu_name = gpu_name[7:] |
|
|
|
if gpu_name.startswith('GeForce '): |
|
|
|
gpu_name = gpu_name[8:] |
|
|
|
|
|
|
|
result.append({ |
|
|
|
'idx': idx, |
|
|
|
'gpu_name': gpu_name, |
|
|
|
'total_mem': total_mem, |
|
|
|
'used_mem': used_mem, |
|
|
|
'free_mem': free_mem, |
|
|
|
'util_gpu': util_gpu, |
|
|
|
'util_mem': util_mem, |
|
|
|
'temperature': temperature |
|
|
|
}) |
|
|
|
|
|
|
|
# locked = False |
|
|
|
with data_list_lock: |
|
|
|
# locked = True |
|
|
|
shared_data_list[server_title]['info_list'] = result |
|
|
|
shared_data_list[server_title]['updated'] = True |
|
|
|
shared_data_list[server_title]['maxGPU'] = len(output) |
|
|
|
# locked = False |
|
|
|
except Exception as e: |
|
|
|
keep_run = False |
|
|
|
shared_data_list[server_title]['err_info'] = f'{e}' |
|
|
|
if 'info_list' in shared_data_list[server_title]: |
|
|
|
shared_data_list[server_title].pop('info_list') |
|
|
|
|
|
|
|
time.sleep(interval) |
|
|
|
|
|
|
|
# 关闭连接 |
|
|
|
client.close() |
|
|
|
except Exception as e: |
|
|
|
shared_data_list[server_title]['err_info'] = f'retry:{re_try_count}, {e}' |
|
|
|
time.sleep(re_connect_time) |
|
|
|
re_try_count += 1 |
|
|
|
|
|
|
|
# 获取所有的服务器数据 |
|
|
|
def get_all_data(): |
|
|
|
return filter_data(list(data_dict.keys())) |
|
|
|
|
|
|
|
# 根据key过滤所需的服务器数据 |
|
|
|
def filter_data(title_list: list): |
|
|
|
result = dict() |
|
|
|
for title in title_list: |
|
|
|
result[title] = {} |
|
|
|
# 不存在该title的数据 |
|
|
|
if title not in data_dict: |
|
|
|
result[title]['err_info'] = f'title \'{title}\' not exist!' |
|
|
|
continue |
|
|
|
# 还没获取到数据 |
|
|
|
info_list = data_dict[title].get('info_list', None) |
|
|
|
if info_list is None: |
|
|
|
result[title]['err_info'] = f'\'{title}\' still empty.' |
|
|
|
continue |
|
|
|
|
|
|
|
# 记录数据 |
|
|
|
data_updated = data_dict[title].get('updated', False) |
|
|
|
err_info = data_dict[title].get('err_info', '') |
|
|
|
result[title]['info_list'] = info_list |
|
|
|
result[title]['updated'] = data_updated |
|
|
|
result[title]['err_info'] = err_info |
|
|
|
return result |
|
|
|
|
|
|
|
def start_connect(): |
|
|
|
# 加载json |
|
|
|
with open(server_list_path, 'r') as f: |
|
|
|
server_list = json.load(f) |
|
|
|
|
|
|
|
global data_dict |
|
|
|
# 开启线程 |
|
|
|
for i, server_data in enumerate(server_list): |
|
|
|
data_dict[server_data['title']] = {} |
|
|
|
data_dict[server_data['title']]['server_data'] = server_data |
|
|
|
thread = threading.Thread(target=keep_check_one, args=(server_data, data_dict, server_data['title'], check_interval)) |
|
|
|
thread.daemon = True |
|
|
|
thread.start() |
|
|
|
|
|
|
|
print('start connect') |
|
|
|
|
|
|
|
# 测试 |
|
|
|
def test(): |
|
|
|
app.run(debug=True, port=port) |
|
|
|
start_connect() |
|
|
|
app.run(debug=True, host='127.0.0.1', port=port) |
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
|
test() |