Browse Source

Feat: 创建训练和测试页面

dk 1 year ago
parent
commit
149f98024f

+ 4 - 1
.gitignore

@@ -167,4 +167,7 @@ cython_debug/
 #  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
 #  and can be added to the global gitignore or merged into this file.  For a more nuclear
 #  option (not recommended) you can uncomment the following to ignore the entire idea folder.
-#.idea/
+#.idea/
+
+model/*
+!model/.gitkeep

+ 27 - 0
backend/db/models/test.py

@@ -0,0 +1,27 @@
+"""train model"""
+import streamlit as st
+from sqlalchemy.sql import text
+
+
+def create_table(conn):
+    with conn.session as s:
+        s.execute(text('CREATE TABLE IF NOT EXISTS test (position TEXT, finger_model TEXT, start_time DATETIME, owner_name TEXT, model_path TEXT);'))
+        s.commit()
+
+
+def get_tests(conn, sub_name):
+    tests = conn.query('select * from test where owner_name = :owner', ttl=0.05, params={'owner': sub_name})
+    return tests 
+
+
+def create_test(conn, test_form):
+    with conn.session as s:
+        s.execute(
+            text('INSERT INTO test (position, finger_model, start_time, owner_name, model_path) VALUES (:position, :finger_model, :start_time, :owner_name, :model_path);'),
+            params=dict(position=test_form['position'],
+                        finger_model=test_form['finger_model'], 
+                        start_time=test_form['start_time'], 
+                        owner_name=test_form['owner_name'],
+                        model_path=test_form['model_path'])
+        )
+        s.commit()

+ 9 - 4
backend/db/models/train.py

@@ -4,9 +4,8 @@ from sqlalchemy.sql import text
 
 
 def create_table(conn):
-    # TODO: set up parameters for train
     with conn.session as s:
-        s.execute(text('CREATE TABLE IF NOT EXISTS train (position TEXT, trial_num INTEGER, start_time DATETIME, owner_name TEXT);'))
+        s.execute(text('CREATE TABLE IF NOT EXISTS train (position TEXT, finger_model TEXT, trial_num INTEGER, start_time DATETIME, owner_name TEXT, virtual_feedback_rate FLOAT, model_path TEXT);'))
         s.commit()
 
 
@@ -18,7 +17,13 @@ def get_trains(conn, sub_name):
 def create_train(conn, train_form):
     with conn.session as s:
         s.execute(
-            text('INSERT INTO train (position, trial_num, start_time, owner_name) VALUES (:position, :trial_num, :start_time, :owner_name);'),
-            params=dict(position=train_form['position'], trial_num=train_form['trial_num'], start_time=train_form['start_time'], owner_name=train_form['owner_name'])
+            text('INSERT INTO train (position, finger_model, trial_num, start_time, owner_name, virtual_feedback_rate, model_path) VALUES (:position, :finger_model, :trial_num, :start_time, :owner_name, :virtual_feedback_rate, :model_path);'),
+            params=dict(position=train_form['position'],
+                        finger_model=train_form['finger_model'], 
+                        trial_num=train_form['trial_num'], 
+                        start_time=train_form['start_time'], 
+                        owner_name=train_form['owner_name'],
+                        virtual_feedback_rate=train_form['virtual_feedback_rate'],
+                        model_path=train_form['model_path'])
         )
         s.commit()

+ 0 - 0
backend/model/.gitkeep


+ 24 - 0
backend/page_utils.py

@@ -0,0 +1,24 @@
+import streamlit as st
+import os
+
+
+fingermodel_trans = {
+    '一般抓握': 'flex',
+    '柱状抓握': 'cylinder',
+    '球状抓握': 'ball',
+    '两指对捏': 'double',
+    '三指对捏': 'treble'
+}
+
+
+def file_selector(folder_path='.'):
+    try:
+        os.mkdir(folder_path)
+    except FileExistsError:
+        pass
+    filenames = os.listdir(folder_path)
+    selected_filename = st.selectbox('Select a file', filenames)
+    if selected_filename is not None:
+        return os.path.join(folder_path, selected_filename)
+    else:
+        return None

+ 15 - 7
backend/pages/2_train.py

@@ -7,31 +7,39 @@ import streamlit as st
 from db.models import subject
 from db.models import train
 from components.remove_style import hide_footer
+import page_utils
 
 
 def _create_train(conn, subjects):
     with st.form("train_form"):
         st.write("创建训练")
-        # TODO: redefine parameters
-        position = st.text_input("训练部位")
-        trial_num = st.number_input("训练次数", value=1, step=1)
+        position = st.selectbox("训练部位", ['左手', '右手'])
+        finger_model = st.selectbox("气动手手势", list(page_utils.fingermodel_trans.keys()))
+        trial_num = st.number_input("训练次数", value=10, step=5)
         owner_name = st.selectbox("用户", subjects.name.to_list())
+        virtual_feedback_rate = st.number_input("假反馈比例", value=0., step=0.2)
+        model_path = page_utils.file_selector(os.path.join(f'./model/{owner_name}'))
         submitted = st.form_submit_button("开始训练")
         if submitted:
             start_time = datetime.strptime(datetime.now().strftime("%Y-%m-%d %H:%M:%S"), "%Y-%m-%d %H:%M:%S")
-            train_new = {"position": position, "trial_num": int(trial_num), "start_time": start_time, "owner_name": owner_name}
+            train_new = {"position": position, 
+                         "finger_model": page_utils.fingermodel_trans[finger_model], 
+                         "trial_num": int(trial_num), 
+                         "start_time": start_time, 
+                         "owner_name": owner_name,
+                         "virtual_feedback_rate": float(virtual_feedback_rate),
+                         "model_path": model_path.name if model_path is not None else None}
             train.create_train(conn, train_new)
             os.system("python train_1.py")
             return owner_name
 
-
 def render():
     st.set_page_config(
-        page_title="train", page_icon=":chart_with_upwards_trend:"
+        page_title="手势训练", page_icon=":chart_with_upwards_trend:"
     )
     hide_footer()
 
-    st.markdown("# Train")
+    st.markdown("# 手势训练")
     st.sidebar.success("训练")
     conn = st.connection("sql_app", type="sql")
     train.create_table(conn)

+ 34 - 16
backend/pages/3_test.py

@@ -1,29 +1,47 @@
 """IMU放置于人体模型后的motion capture分析,包括离线、在线。读入或在线接入数据进行波形绘制、数据分析及分析结果的人体模型渲染和结果图表"""
 import streamlit as st
-
+import os
+from datetime import datetime
+from db.models import test, subject
 from components.remove_style import hide_footer
 
-
-def on_line():
-    st.button("在线")
-    st.sidebar.success("在线")
+import page_utils
 
 
-def off_line():
-    st.button("离线")
-    st.sidebar.success("离线")
+def _create_test(conn, subjects):
+    with st.form("test_form"):
+        st.write("创建训练")
+        position = st.selectbox("训练部位", ['左手', '右手'])
+        finger_model_names = st.multiselect("气动手手势", list(page_utils.fingermodel_trans.keys()))
+        owner_name = st.selectbox("用户", subjects.name.to_list())
+        model_path = page_utils.file_selector(os.path.join(f'./model/{owner_name}'))
+        submitted = st.form_submit_button("开始训练")
+        if submitted:
+            start_time = datetime.strptime(datetime.now().strftime("%Y-%m-%d %H:%M:%S"), "%Y-%m-%d %H:%M:%S")
+            test_new = {"position": position, 
+                         "finger_model": ','.join([page_utils.fingermodel_trans[f] for f in finger_model_names]), 
+                         "start_time": start_time, 
+                         "owner_name": owner_name,
+                         "model_path": model_path.name if model_path is not None else None}
+            test.create_test(conn, test_new)
+            return owner_name
 
 
 def render():
-    st.set_page_config(page_title="test", page_icon=":running:")
+    st.set_page_config(
+        page_title="自由手势训练", page_icon=":chart_with_upwards_trend:"
+    )
     hide_footer()
-    st.markdown("# Test")
-
-    on_off_switch = st.toggle("离线/在线")
-    if on_off_switch:
-        on_line()
-    else:
-        off_line()
 
+    st.markdown("# 手势训练")
+    st.sidebar.success("训练")
+    conn = st.connection("sql_app", type="sql")
+    test.create_table(conn)
+    subjects = subject.get_subjects(conn)
+    sub_name = _create_test(conn, subjects)
+    if sub_name:
+        tests = test.get_tests(conn, sub_name)
+        st.write("# 训练列表")
+        st.dataframe(tests)
 
 render()

+ 10 - 0
backend/schemas/hand_peripheral.py

@@ -5,6 +5,16 @@ from pydantic import BaseModel
 from pydantic import Field
 
 
+FINGERMODEL_IDS = {
+    'rest': 0,
+    'cylinder': 1,
+    'ball': 2,
+    'flex': 3,
+    'double': 4,
+    'treble': 5
+}
+
+
 class ChannelName(int, Enum):
     CHANNEL_A = 0x01
     CHANNEL_B = 0x02