You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
166 lines
5.8 KiB
166 lines
5.8 KiB
from flask import Flask, jsonify
|
|
from flask_cors import CORS
|
|
import threading
|
|
import paramiko
|
|
import json
|
|
import time
|
|
|
|
#region Globals

# Flask application; CORS is enabled on it for every route.
app = Flask(__name__)
CORS(app)

# TCP port the web server listens on.
port = 15002

# JSON file listing the servers to monitor.
server_list_path = 'serverList.json'

# Lock taken by the monitor threads when they publish a new reading.
data_list_lock = threading.Lock()

# Seconds between two GPU queries on the same server.
check_interval = 2

# Shared state: server title -> dict holding that server's latest data.
data_dict = {}

#endregion
|
|
|
|
#region API endpoints
|
|
|
|
# For testing only.
@app.route('/')
def hello():
    """Return a fixed greeting string for the root URL."""
    greeting = 'hi. —— CheckGPUsWeb'
    return greeting
|
|
|
|
@app.route('/data', methods=['GET'])
def get_data():
    """Serve the data of every known server as a JSON response."""
    snapshot = get_all_data()
    return jsonify(snapshot)
|
|
|
|
# Start connecting to the servers.
def connect_server():
    """Placeholder: not implemented yet, does nothing and returns None."""
    return None
|
|
|
|
#endregion
|
|
|
|
# Command run on the remote host. The order of the queried fields must match
# the column order expected by _parse_gpu_csv.
_GPU_QUERY_CMD = (
    'nvidia-smi --query-gpu=index,name,memory.total,memory.used,memory.free,'
    'utilization.gpu,utilization.memory,temperature.gpu --format=csv'
)

# Header line that precedes the data rows in the output of _GPU_QUERY_CMD.
_GPU_CSV_HEADER = (
    'index, name, memory.total [MiB], memory.used [MiB], memory.free [MiB], '
    'utilization.gpu [%], utilization.memory [%], temperature.gpu'
)

# Vendor prefixes removed (in this order) to shorten the displayed GPU name.
_GPU_NAME_PREFIXES = ('NVIDIA ', 'GeForce ')


def _parse_gpu_csv(raw: str) -> list:
    """Turn the text printed by ``_GPU_QUERY_CMD`` into one dict per GPU.

    Raises ``ValueError``/``IndexError`` when a row does not have the
    expected shape; the caller treats that as a failed query.
    """
    lines = [line.strip() for line in raw.splitlines()]
    # Ignore anything printed before the header (e.g. a login banner). When
    # the header is missing, parse from the first line.
    start = lines.index(_GPU_CSV_HEADER) + 1 if _GPU_CSV_HEADER in lines else 0

    gpus = []
    for line in lines[start:]:
        # Skip blank lines instead of blindly dropping the last element, so
        # the final GPU row survives when there is no trailing newline.
        if not line:
            continue
        fields = line.split(', ')
        gpu_name = fields[1]
        for prefix in _GPU_NAME_PREFIXES:
            if gpu_name.startswith(prefix):
                gpu_name = gpu_name[len(prefix):]
        gpus.append({
            'idx': int(fields[0]),
            'gpu_name': gpu_name,
            # Memory and utilisation values carry a unit suffix ("123 MiB",
            # "45 %"); keep only the number.
            'total_mem': int(fields[2].split(' ')[0]),
            'used_mem': int(fields[3].split(' ')[0]),
            'free_mem': int(fields[4].split(' ')[0]),
            'util_gpu': int(fields[5].split(' ')[0]),
            'util_mem': int(fields[6].split(' ')[0]),
            'temperature': int(fields[7]),
        })
    return gpus


def _poll_until_error(client, shared_data_list: dict, server_title: str, interval: float) -> None:
    """Publish GPU readings every ``interval`` seconds until one query fails."""
    while True:
        try:
            _, stdout, _ = client.exec_command(_GPU_QUERY_CMD, timeout=interval * 3)
            gpus = _parse_gpu_csv(stdout.read().decode())
        except Exception as e:
            with data_list_lock:
                shared_data_list[server_title]['err_info'] = f'{e}'
                # Drop the stale reading so readers do not show outdated data.
                shared_data_list[server_title].pop('info_list', None)
            time.sleep(interval)
            return

        with data_list_lock:
            shared_data_list[server_title]['info_list'] = gpus
            shared_data_list[server_title]['updated'] = True
            shared_data_list[server_title]['maxGPU'] = len(gpus)
        time.sleep(interval)


def keep_check_one(server: dict, shared_data_list: dict, server_title: str, interval: float, re_connect_time: float = 5):
    """Monitor one server's GPUs over SSH forever, reconnecting on failure.

    Results are written into ``shared_data_list[server_title]``, which must
    already exist: ``info_list`` (list of per-GPU dicts), ``updated``,
    ``maxGPU`` (number of GPUs in the last reading) and ``err_info``.

    Args:
        server: Connection settings. Reads ``ip``, ``port`` and ``username``,
            plus the optional ``password`` and ``key_filename``.
        shared_data_list: Dict shared with the web handlers, keyed by title.
        server_title: Key of this server inside ``shared_data_list``.
        interval: Seconds between two queries. Three times this value is
            used as the SSH connect and command timeout.
        re_connect_time: Seconds to wait after a failed connection attempt.

    This function never returns; run it in a daemon thread.
    """
    re_try_count = 0
    while True:
        try:
            # The context manager closes the client on every path, including
            # a failed connect(), so no transport is leaked while retrying.
            with paramiko.SSHClient() as client:
                # NOTE(review): AutoAddPolicy accepts unknown host keys
                # without verification; confirm this is acceptable here.
                client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
                client.connect(
                    server['ip'],
                    port=server['port'],
                    username=server['username'],
                    password=server.get('password'),
                    key_filename=server.get('key_filename'),
                    timeout=interval * 3,
                )

                with data_list_lock:
                    shared_data_list[server_title]['err_info'] = ''
                re_try_count = 0

                # Returns after the first failed query; the loop then
                # reconnects right away.
                _poll_until_error(client, shared_data_list, server_title, interval)
        except Exception as e:
            with data_list_lock:
                shared_data_list[server_title]['err_info'] = f'retry:{re_try_count}, {e}'
            time.sleep(re_connect_time)
            re_try_count += 1
|
|
|
|
# Fetch the data of every server.
def get_all_data():
    """Return the filtered data for every title currently in ``data_dict``."""
    all_titles = list(data_dict)
    return filter_data(all_titles)
|
|
|
|
# Select the data of the requested servers by title.
def filter_data(title_list: list):
    """Build the client-facing view of the servers named in ``title_list``.

    Only ``info_list``, ``updated`` and ``err_info`` are copied out of
    ``data_dict``; the stored ``server_data`` is never included.

    Args:
        title_list: Server titles to look up in ``data_dict``.

    Returns:
        Dict keyed by title. Each value holds ``info_list``, ``updated`` and
        ``err_info`` when a reading exists, otherwise only ``err_info``.
    """
    result = dict()
    for title in title_list:
        # Unknown title.
        if title not in data_dict:
            result[title] = {'err_info': f"title '{title}' not exist!"}
            continue

        entry = data_dict[title]
        info_list = entry.get('info_list', None)
        err_info = entry.get('err_info', '')

        # No reading available. Report the stored error if there is one;
        # otherwise no data has arrived yet. The generic message alone would
        # hide every connection or query failure from the client.
        if info_list is None:
            result[title] = {'err_info': err_info or f"'{title}' still empty."}
            continue

        result[title] = {
            'info_list': info_list,
            'updated': entry.get('updated', False),
            'err_info': err_info,
        }
    return result
|
|
|
|
def start_connect():
    """Load the server list and start one daemon monitor thread per server.

    Reads ``server_list_path``, which must contain a JSON list of server
    dicts, each with a ``title``. Every server gets an entry in ``data_dict``
    holding its settings under ``server_data``, and a thread running
    ``keep_check_one`` on it.
    """
    # Read as bytes so the json module detects the encoding itself. The
    # platform default encoding can fail on non-ASCII titles.
    with open(server_list_path, 'rb') as f:
        server_list = json.load(f)

    # Start the threads.
    for server_data in server_list:
        title = server_data['title']
        data_dict[title] = {'server_data': server_data}
        thread = threading.Thread(
            target=keep_check_one,
            args=(server_data, data_dict, title, check_interval),
        )
        # Daemon threads do not keep the process alive on shutdown.
        thread.daemon = True
        thread.start()

    print('start connect')
|
|
|
|
# Test / development entry point.
def test():
    """Start the monitor threads and run the Flask development server."""
    import os

    # debug=True enables the auto-reloader, which runs this script in two
    # processes: a watcher and a child that serves requests. Only the child
    # has WERKZEUG_RUN_MAIN set to 'true'. Starting the monitors only there
    # avoids opening every SSH connection twice.
    if os.environ.get('WERKZEUG_RUN_MAIN') == 'true':
        start_connect()
    app.run(debug=True, host='127.0.0.1', port=port)
|
|
|
|
# Run the test entry point only when executed as a script, not on import.
if __name__ == "__main__":
    test()
|