123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354 |
- """train"""
import os
import subprocess
import sys
from datetime import datetime

import streamlit as st

import page_utils
from components.remove_style import hide_footer
from db.models import subject
from db.models import train
def _create_train(conn, subjects):
    """Render the "create training" form; on submit, record the run and launch training.

    Args:
        conn: Streamlit SQL connection, passed through to ``train.create_train``.
        subjects: subject table whose ``name`` column populates the user
            selector (assumed to be a pandas DataFrame, as returned by
            ``subject.get_subjects`` -- confirm).

    Returns:
        The selected user's name after the form was submitted and the
        training script has finished; otherwise ``None``.
    """
    owner_names = subjects.name.to_list()
    if not owner_names:
        # Without any user there is no model directory to choose from; the old
        # code silently fell back to "./model/None".
        st.warning("暂无用户，请先创建用户")
        return None

    # Deliberately placed OUTSIDE the form: widgets inside ``st.form`` do not
    # trigger a rerun until the submit button is pressed, so a user selector
    # inside the form would leave the model list below pointing at the
    # previously selected user's directory.
    owner_name = st.selectbox("用户", owner_names)

    with st.form("train_form"):
        st.write("创建训练")
        position = st.selectbox("训练部位", ['左手', '右手'])
        finger_model = st.selectbox("气动手手势", list(page_utils.fingermodel_trans.keys()))
        # A training run needs at least one trial.
        trial_num = st.number_input("训练次数", value=10, step=5, min_value=1)
        # A ratio of fake feedback, so it is bounded to [0, 1].
        virtual_feedback_rate = st.number_input(
            "假反馈比例", value=0., step=0.2, min_value=0., max_value=1.
        )
        model_path = page_utils.file_selector(os.path.join("./model", owner_name))
        submitted = st.form_submit_button("开始训练")

    if not submitted:
        return None

    # Current local time truncated to whole seconds (same value the previous
    # strftime/strptime round trip produced).
    start_time = datetime.now().replace(microsecond=0)
    train_new = {
        "position": position,
        "finger_model": page_utils.fingermodel_trans[finger_model],
        "trial_num": int(trial_num),
        "start_time": start_time,
        "owner_name": owner_name,
        "virtual_feedback_rate": float(virtual_feedback_rate),
        "model_path": model_path.name if model_path is not None else None,
    }
    train.create_train(conn, train_new)

    # Launch the training script with the same interpreter that runs this app
    # (not whatever ``python`` is first on PATH) and without a shell. Like the
    # previous ``os.system`` call this blocks until the script exits; a
    # non-zero exit status is now reported instead of being ignored.
    result = subprocess.run([sys.executable, "train_1.py"], check=False)
    if result.returncode != 0:
        st.error(f"训练脚本异常退出（退出码 {result.returncode}）")
    return owner_name
def render():
    """Build the gesture-training page.

    Sets the page config, hides the footer, draws the title and sidebar
    badge, calls ``train.create_table`` on the "sql_app" connection, shows the
    training-creation form and, when that form yields a user name, lists the
    trainings stored for that user.
    """
    st.set_page_config(page_title="手势训练", page_icon=":chart_with_upwards_trend:")
    hide_footer()
    st.markdown("# 手势训练")
    st.sidebar.success("训练")

    db = st.connection("sql_app", type="sql")
    train.create_table(db)
    subject_table = subject.get_subjects(db)

    owner = _create_train(db, subject_table)
    if not owner:
        return
    owner_trains = train.get_trains(db, owner)
    st.write("# 训练列表")
    st.dataframe(owner_trains)


# Module-level call: the page is rendered every time this script is executed.
render()
|