|
@@ -1,57 +0,0 @@
|
|
|
-"""train"""
|
|
|
-from datetime import datetime
|
|
|
-import os
|
|
|
-
|
|
|
-import streamlit as st
|
|
|
-
|
|
|
-from db import subject
|
|
|
-from db import train
|
|
|
-from components.remove_style import hide_footer
|
|
|
-import page_utils
|
|
|
-from device.peripheral.manager import get_serial_ports
|
|
|
-
|
|
|
-
|
|
|
-def _create_train(conn, subjects):
|
|
|
- with st.form("train_form"):
|
|
|
- st.write("创建训练")
|
|
|
- position = st.selectbox("训练部位", ['右手', '左手'])
|
|
|
- hand_com = st.selectbox("气动手COM口", get_serial_ports())
|
|
|
- finger_model = st.selectbox("气动手手势", list(page_utils.fingermodel_trans.keys()))
|
|
|
- trial_num = st.number_input("训练次数", value=10, step=5)
|
|
|
- owner_name = st.selectbox("用户", subjects.name.to_list())
|
|
|
- virtual_feedback_rate = st.number_input("假反馈比例", min_value=0., max_value=1., value=1., step=0.2)
|
|
|
- model_path = page_utils.file_selector(f'./static/models/{owner_name}')
|
|
|
- submitted = st.form_submit_button("开始训练")
|
|
|
- if submitted:
|
|
|
- start_time = datetime.strptime(datetime.now().strftime("%Y-%m-%d %H:%M:%S"), "%Y-%m-%d %H:%M:%S")
|
|
|
- train_new = {"position": position,
|
|
|
- "finger_model": page_utils.fingermodel_trans[finger_model],
|
|
|
- "trial_num": int(trial_num),
|
|
|
- "start_time": start_time,
|
|
|
- "owner_name": owner_name,
|
|
|
- "virtual_feedback_rate": float(virtual_feedback_rate),
|
|
|
- "model_path": model_path}
|
|
|
- train.create_train(conn, train_new)
|
|
|
- # run psychopy process
|
|
|
- os.system(f'python general_grasp_training.py --subj {owner_name} --n-trials {trial_num} --com {hand_com} --finger-model {page_utils.fingermodel_trans[finger_model]} --virtual-feedback-rate {virtual_feedback_rate} --model-path {str(model_path)}')
|
|
|
- return owner_name
|
|
|
-
|
|
|
-def render():
|
|
|
- st.set_page_config(
|
|
|
- page_title="手势训练", page_icon=":chart_with_upwards_trend:"
|
|
|
- )
|
|
|
- hide_footer()
|
|
|
-
|
|
|
- st.markdown("# 手势训练")
|
|
|
- st.sidebar.success("训练")
|
|
|
- conn = st.connection("sql_app", type="sql")
|
|
|
- train.create_table(conn)
|
|
|
- subjects = subject.get_subjects(conn)
|
|
|
- sub_name = _create_train(conn, subjects)
|
|
|
- if sub_name:
|
|
|
- trains = train.get_trains(conn, sub_name)
|
|
|
- st.write("# 训练列表")
|
|
|
- st.dataframe(trains)
|
|
|
-
|
|
|
-
|
|
|
-render()
|