2_train.py 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. """train"""
from datetime import datetime
import os
import subprocess
import sys

import streamlit as st

from db import subject
from db import train
from components.remove_style import hide_footer
import page_utils
from device.peripheral.manager import get_serial_ports
  10. def _create_train(conn, subjects):
  11. with st.form("train_form"):
  12. st.write("创建训练")
  13. position = st.selectbox("训练部位", ['右手', '左手'])
  14. hand_com = st.selectbox("气动手COM口", get_serial_ports())
  15. finger_model = st.selectbox("气动手手势", list(page_utils.fingermodel_trans.keys()))
  16. trial_num = st.number_input("训练次数", value=10, step=5)
  17. owner_name = st.selectbox("用户", subjects.name.to_list())
  18. virtual_feedback_rate = st.number_input("假反馈比例", min_value=0., max_value=1., value=1., step=0.2)
  19. model_path = page_utils.file_selector(f'./static/models/{owner_name}')
  20. submitted = st.form_submit_button("开始训练")
  21. if submitted:
  22. start_time = datetime.strptime(datetime.now().strftime("%Y-%m-%d %H:%M:%S"), "%Y-%m-%d %H:%M:%S")
  23. train_new = {"position": position,
  24. "finger_model": page_utils.fingermodel_trans[finger_model],
  25. "trial_num": int(trial_num),
  26. "start_time": start_time,
  27. "owner_name": owner_name,
  28. "virtual_feedback_rate": float(virtual_feedback_rate),
  29. "model_path": model_path}
  30. train.create_train(conn, train_new)
  31. # run psychopy process
  32. os.system(f'python general_grasp_training.py --subj {owner_name} --n-trials {trial_num} --com {hand_com} --finger-model {page_utils.fingermodel_trans[finger_model]} --virtual-feedback-rate {virtual_feedback_rate} --model-path {str(model_path)}')
  33. return owner_name
  34. def render():
  35. st.set_page_config(
  36. page_title="手势训练", page_icon=":chart_with_upwards_trend:"
  37. )
  38. hide_footer()
  39. st.markdown("# 手势训练")
  40. st.sidebar.success("训练")
  41. conn = st.connection("sql_app", type="sql")
  42. train.create_table(conn)
  43. subjects = subject.get_subjects(conn)
  44. sub_name = _create_train(conn, subjects)
  45. if sub_name:
  46. trains = train.get_trains(conn, sub_name)
  47. st.write("# 训练列表")
  48. st.dataframe(trains)
# Build the page at module level; Streamlit runs this page file as a script.
render()