You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
237 lines
8.0 KiB
237 lines
8.0 KiB
from flask import Flask, jsonify
|
|
from datetime import datetime
|
|
from flask_cors import CORS
|
|
import threading
|
|
import paramiko
|
|
import json
|
|
import time
|
|
|
|
#region Globals

# Flask application with CORS enabled.
app = Flask(__name__)
CORS(app)

# TCP port the HTTP server listens on.
port = 15002
# JSON file describing the servers to monitor.
server_list_path = 'serverList.json'
# Seconds between two consecutive polls of one server.
check_interval = 2

# Lock used by the polling threads when they publish fresh readings.
data_list_lock = threading.Lock()
# Shared state: one entry per server title, filled in by the polling threads.
data_dict = {}

#endregion
|
|
|
|
#region API endpoints
|
|
|
|
# For testing only
@app.route('/')
def hello():
    """Return a fixed greeting so a client can check that the service is up."""
    greeting = 'hi. —— CheckGPUsWeb'
    return greeting
|
|
|
|
@app.route('/all_data', methods=['GET'])
def get_data():
    """Serve the current readings of every known server as a JSON response."""
    payload = get_all_data()
    return jsonify(payload)
|
|
|
|
# Start connecting to the servers
def connect_server():
    """Placeholder with no implementation yet; does nothing and returns None."""
    return None
|
|
|
|
#endregion
|
|
|
|
def _parse_gpu_row(row):
    """Convert one CSV data row printed by nvidia-smi into a dict of GPU readings."""
    fields = [field.strip() for field in row.split(',')]

    # Shorten the GPU name by dropping the vendor / product-line prefixes.
    gpu_name = fields[1]
    if gpu_name.startswith('NVIDIA '):
        gpu_name = gpu_name[7:]
    if gpu_name.startswith('GeForce '):
        gpu_name = gpu_name[8:]

    # Memory and utilisation columns carry a unit suffix ("24576 MiB", "5 %");
    # only the leading number is kept.
    return {
        'idx': int(fields[0]),
        'gpu_name': gpu_name,
        'total_mem': int(fields[2].split()[0]),
        'used_mem': int(fields[3].split()[0]),
        'free_mem': int(fields[4].split()[0]),
        'util_gpu': int(fields[5].split()[0]),
        'util_mem': int(fields[6].split()[0]),
        'temperature': int(fields[7]),
    }


def get_gpus_info(client, timeout):
    """Query GPU status on a remote host by running nvidia-smi.

    Args:
        client: object exposing ``exec_command(cmd, timeout=...)`` that returns
            a ``(stdin, stdout, stderr)`` triple, where ``stdout.read()``
            yields bytes (the file passes a paramiko ``SSHClient``).
        timeout: timeout forwarded to ``exec_command``.

    Returns:
        A list with one dict per GPU (keys ``idx``, ``gpu_name``,
        ``total_mem``, ``used_mem``, ``free_mem``, ``util_gpu``, ``util_mem``,
        ``temperature``). The list is empty when the command printed no data
        rows. ``None`` is returned when the command could not be run or its
        output could not be parsed.
    """
    try:
        cmd = 'nvidia-smi --query-gpu=index,name,memory.total,memory.used,memory.free,utilization.gpu,utilization.memory,temperature.gpu --format=csv'

        _, stdout, _ = client.exec_command(cmd, timeout=timeout)
        lines = stdout.read().decode().splitlines()

        # Data rows begin right after the CSV header; anything printed before
        # the header is ignored. Without a header, every line is treated as data.
        start_idx = 0
        for i, line in enumerate(lines):
            if line.strip().startswith('index, name'):
                start_idx = i + 1
                break

        # Blank lines are skipped, so the result does not depend on whether
        # the output ends with a newline.
        return [_parse_gpu_row(line) for line in lines[start_idx:] if line.strip()]
    except Exception:
        # Best effort: any transport or parsing failure is reported as None.
        return None
|
|
|
|
def get_storage_info(client, timeout, path_list):
    """Query disk usage of the given paths on a remote host via ``df``.

    Each path is shell-quoted and passed to ``df -Pk``: ``-P`` keeps every
    file system on a single output line and ``-k`` fixes the unit to
    1024-byte blocks. The last output line is parsed, so a path that lies
    inside a mounted file system is reported as well.

    Args:
        client: object exposing ``exec_command(cmd, timeout=...)`` that returns
            a ``(stdin, stdout, stderr)`` triple, where ``stdout.read()``
            yields bytes (the file passes a paramiko ``SSHClient``).
        timeout: timeout forwarded to ``exec_command``.
        path_list: iterable of path strings to inspect.

    Returns:
        A list of ``{"path", "total", "available"}`` dicts, sizes in
        1024-byte blocks. Paths for which ``df`` printed no data row are
        skipped. ``None`` is returned when a command could not be run or its
        output could not be parsed.
    """
    # Local import so this block does not rely on the file's import section.
    import shlex

    try:
        result = []
        for target_path in path_list:
            cmd = f'df -Pk {shlex.quote(target_path)}'
            _, stdout, _ = client.exec_command(cmd, timeout=timeout)
            lines = [line for line in stdout.read().decode().splitlines() if line.strip()]
            # A header plus at least one data row is expected; anything less
            # means df produced no usable row for this path.
            if len(lines) < 2:
                continue
            fields = lines[-1].split()
            result.append({
                "path": target_path,
                "total": int(fields[1]),
                "available": int(fields[3]),
            })
        return result
    except Exception:
        # Best effort: any transport or parsing failure is reported as None.
        return None
|
|
|
|
def get_memory_info(client, timeout):
    """Read RAM usage of a remote host from the second line of ``free``.

    Args:
        client: object exposing ``exec_command(cmd, timeout=...)`` that returns
            a ``(stdin, stdout, stderr)`` triple, where ``stdout.read()``
            yields bytes.
        timeout: timeout forwarded to ``exec_command``.

    Returns:
        ``{"total": int, "used": int}`` taken from the second and third
        columns of that line, or ``None`` when the line is empty or anything
        goes wrong.
    """
    try:
        _, remote_out, _ = client.exec_command('free', timeout=timeout)
        text = remote_out.read().decode()
        lines = text.split('\n')
        # Line 0 is the column header; line 1 holds the memory figures.
        mem_line = lines[1]
        if not mem_line:
            return None
        columns = mem_line.split()
        return {"total": int(columns[1]), "used": int(columns[2])}
    except Exception as e:
        return None
|
|
|
|
# Continuously collect the status of a single server
def keep_check_one(server: dict, shared_data_list: dict, server_title: str, interval: float, re_connect_time: float=5):
    """Poll one server over SSH forever and publish its readings.

    The function never returns; it is meant to run in its own thread. It
    connects to the server, then repeatedly gathers GPU, storage and memory
    information and stores it in ``shared_data_list[server_title]``. When a
    poll fails the connection is closed and a new one is opened.

    Args:
        server: connection settings; reads ``ip``, ``port``, ``username`` and
            the optional ``password``, ``key_filename`` and ``storage_list``.
            ``storage_list`` is created if missing and ``'/'`` is inserted at
            its front if absent (the dict is modified in place).
        shared_data_list: mapping of server title to that server's state
            dict; the entry for ``server_title`` must already exist.
        server_title: key of this server in ``shared_data_list``.
        interval: seconds to sleep between polls; three times this value is
            used as the SSH connect and command timeout.
        re_connect_time: seconds to wait after a failed connection attempt.
    """
    # Normalise the list of paths whose disk usage is checked.
    storage_list = server.setdefault('storage_list', [])
    if '/' not in storage_list:
        storage_list.insert(0, '/')

    timeout = interval * 3
    re_try_count = 0
    # Connection loop
    while True:
        client = None
        try:
            # Open the SSH connection.
            client = paramiko.SSHClient()
            # NOTE(review): AutoAddPolicy accepts unknown host keys without
            # verification; confirm this is acceptable for the deployment.
            client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
            client.connect(server['ip'], port=server['port'], username=server['username'], password=server.get('password', None), key_filename=server.get('key_filename', None), timeout=timeout)

            with data_list_lock:
                shared_data_list[server_title]['err_info'] = None
            re_try_count = 0

            # Polling loop
            keep_run = True
            while keep_run:
                try:
                    gpu_info = get_gpus_info(client, timeout)
                    storage_info = get_storage_info(client, timeout, server['storage_list'])
                    memory_info = get_memory_info(client, timeout)

                    # Publish the readings.
                    with data_list_lock:
                        entry = shared_data_list[server_title]
                        entry['gpu_info_list'] = gpu_info
                        entry['storage_info_list'] = storage_info
                        entry['memory_info'] = memory_info
                        entry['updated'] = True
                        if gpu_info is not None:
                            entry['maxGPU'] = len(gpu_info)

                    # The helpers swallow their own errors, so a None GPU
                    # result is the signal that this connection is unusable;
                    # raise so the handler below records it and reconnects.
                    if gpu_info is None:
                        raise RuntimeError('failed to query GPU info')

                except Exception as e:
                    keep_run = False
                    with data_list_lock:
                        entry = shared_data_list[server_title]
                        entry['err_info'] = f'{e}'
                        entry.pop('gpu_info_list', None)

                time.sleep(interval)

        except Exception as e:
            with data_list_lock:
                shared_data_list[server_title]['err_info'] = f'retry:{re_try_count}, {e}'
            time.sleep(re_connect_time)
            re_try_count += 1
        finally:
            # Always release the connection, including when connect() failed.
            if client is not None:
                try:
                    client.close()
                except Exception:
                    # Cleanup is best effort; a failing close must not stop the loop.
                    pass
|
|
|
|
# Get the data of every server
def get_all_data():
    """Return the filtered view of every server currently present in data_dict."""
    every_title = list(data_dict)
    return filter_data(every_title)
|
|
|
|
# Select the data of the requested servers by key
def filter_data(title_list: list):
    """Build the response payload for the given server titles.

    Args:
        title_list: titles to look up in ``data_dict``.

    Returns:
        A dict with ``'time'`` (current local time formatted as
        ``YYYY-mm-dd HH:MM:SS``) and ``'server_data'``, which maps each title
        to its exported fields. A title missing from ``data_dict`` only gets
        an ``err_info`` message. A known title gets ``updated`` plus
        whichever of ``err_info``, ``gpu_info_list``, ``storage_info_list``
        and ``memory_info`` are present and not None.
    """
    # Fields copied only when they hold a value.
    optional_keys = ('err_info', 'gpu_info_list', 'storage_info_list', 'memory_info')

    server_data = {}
    for title in title_list:
        exported = {}
        server_data[title] = exported

        # No data exists for this title.
        if title not in data_dict:
            exported['err_info'] = f'title \'{title}\' not exist!'
            continue

        record = data_dict[title]
        # Whether the server has reported at least once.
        exported['updated'] = record.get('updated', False)
        for key in optional_keys:
            value = record.get(key, None)
            if value is not None:
                exported[key] = value

    return {
        'time': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
        'server_data': server_data,
    }
|
|
|
|
def start_connect():
    """Load the server list and start one daemon polling thread per server.

    Reads ``server_list_path`` as JSON. For every entry it registers a state
    dict in ``data_dict`` under the entry's ``title`` and starts a thread
    running ``keep_check_one`` for it.
    """
    # Load the JSON configuration.
    with open(server_list_path, 'r') as config_file:
        servers = json.load(config_file)

    # Start the threads.
    for server in servers:
        title = server['title']
        data_dict[title] = {'server_data': server}
        worker = threading.Thread(
            target=keep_check_one,
            args=(server, data_dict, title, check_interval),
            daemon=True,
        )
        worker.start()

    print('start connect')
|
|
|
|
# Test / development entry point
def test():
    """Start the polling threads, then serve the API on 127.0.0.1 in debug mode."""
    start_connect()
    # debug=True would otherwise enable the auto-reloader, which runs this
    # script in a second process and so starts a duplicate set of polling
    # threads. The reloader is switched off; the debugger stays on.
    app.run(debug=True, host='127.0.0.1', port=port, use_reloader=False)
|
|
|
|
# Launch the development server only when the file is executed as a script.
if __name__ == '__main__':
    test()
|