1
0

2_train.py 1.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546
  1. """train"""
  2. from datetime import datetime
  3. import os
  4. import streamlit as st
  5. from db.models import subject
  6. from db.models import train
  7. from components.remove_style import hide_footer
  8. def _create_train(conn, subjects):
  9. with st.form("train_form"):
  10. st.write("创建训练")
  11. position = st.text_input("训练部位")
  12. trial_num = st.number_input("训练次数", value=1, step=1)
  13. owner_name = st.selectbox("用户", subjects.name.to_list())
  14. submitted = st.form_submit_button("开始训练")
  15. if submitted:
  16. start_time = datetime.strptime(datetime.now().strftime("%Y-%m-%d %H:%M:%S"), "%Y-%m-%d %H:%M:%S")
  17. train_new = {"position": position, "trial_num": int(trial_num), "start_time": start_time, "owner_name": owner_name}
  18. train.create_train(conn, train_new)
  19. # TODO: set absolute path
  20. os.system("python train_1.py")
  21. return owner_name
  22. def render():
  23. st.set_page_config(
  24. page_title="train", page_icon=":chart_with_upwards_trend:"
  25. )
  26. hide_footer()
  27. st.markdown("# Train")
  28. st.sidebar.success("训练")
  29. conn = st.connection("sql_app", type="sql")
  30. train.create_table(conn)
  31. subjects = subject.get_subjects(conn)
  32. sub_name = _create_train(conn, subjects)
  33. if sub_name:
  34. trains = train.get_trains(conn, sub_name)
  35. st.write("# 训练列表")
  36. st.dataframe(trains)
  37. render()