|
@@ -7,31 +7,39 @@ import streamlit as st
|
|
|
from db.models import subject
|
|
|
from db.models import train
|
|
|
from components.remove_style import hide_footer
|
|
|
+import page_utils
|
|
|
|
|
|
|
|
|
def _create_train(conn, subjects):
|
|
|
with st.form("train_form"):
|
|
|
st.write("创建训练")
|
|
|
- # TODO: redefine parameters
|
|
|
- position = st.text_input("训练部位")
|
|
|
- trial_num = st.number_input("训练次数", value=1, step=1)
|
|
|
+ position = st.selectbox("训练部位", ['左手', '右手'])
|
|
|
+ finger_model = st.selectbox("气动手手势", list(page_utils.fingermodel_trans.keys()))
|
|
|
+ trial_num = st.number_input("训练次数", value=10, step=5)
|
|
|
owner_name = st.selectbox("用户", subjects.name.to_list())
|
|
|
+ virtual_feedback_rate = st.number_input("假反馈比例", value=0., step=0.2)
|
|
|
+ model_path = page_utils.file_selector(os.path.join(f'./model/{owner_name}'))
|
|
|
submitted = st.form_submit_button("开始训练")
|
|
|
if submitted:
|
|
|
start_time = datetime.strptime(datetime.now().strftime("%Y-%m-%d %H:%M:%S"), "%Y-%m-%d %H:%M:%S")
|
|
|
- train_new = {"position": position, "trial_num": int(trial_num), "start_time": start_time, "owner_name": owner_name}
|
|
|
+ train_new = {"position": position,
|
|
|
+ "finger_model": page_utils.fingermodel_trans[finger_model],
|
|
|
+ "trial_num": int(trial_num),
|
|
|
+ "start_time": start_time,
|
|
|
+ "owner_name": owner_name,
|
|
|
+ "virtual_feedback_rate": float(virtual_feedback_rate),
|
|
|
+ "model_path": model_path.name if model_path is not None else None}
|
|
|
train.create_train(conn, train_new)
|
|
|
os.system("python train_1.py")
|
|
|
return owner_name
|
|
|
|
|
|
-
|
|
|
def render():
|
|
|
st.set_page_config(
|
|
|
- page_title="train", page_icon=":chart_with_upwards_trend:"
|
|
|
+ page_title="手势训练", page_icon=":chart_with_upwards_trend:"
|
|
|
)
|
|
|
hide_footer()
|
|
|
|
|
|
- st.markdown("# Train")
|
|
|
+ st.markdown("# 手势训练")
|
|
|
st.sidebar.success("训练")
|
|
|
conn = st.connection("sql_app", type="sql")
|
|
|
train.create_table(conn)
|