Browse Source

Refactor: remove streamlit, use command line directly

dk 1 year ago
parent
commit
072dad3f14

+ 12 - 9
.vscode/launch.json

@@ -5,24 +5,27 @@
     "version": "0.2.0",
     "configurations": [
         {
-            "name": "Python: 当前文件",
+            "name": "General training paradigm",
             "type": "python",
             "request": "launch",
-            "program": "${file}",
+            "program": "general_grasp_training.py",
             "console": "integratedTerminal",
             "cwd": "${workspaceFolder}/backend",
-            "justMyCode": true
+            "justMyCode": true,
+            "args": ["--subj", "ylj", 
+            "--n-trials", "15", 
+            "--com", "COM3", 
+            "-fm", "flex", 
+            "-vfr", "1.0", 
+            "--model-path", "./static/models/ylj/baseline_rest+cylinder_11-16-2023-16-38-32.pkl"]
         },
         {
-            "name": "Python: Streamlit",
+            "name": "Python: 当前文件",
             "type": "python",
             "request": "launch",
-            "module": "streamlit",
+            "program": "${file}",
+            "console": "integratedTerminal",
             "cwd": "${workspaceFolder}/backend",
-            "args": [
-                "run",
-                "main.py"
-            ],
             "justMyCode": true
         },
         {

+ 0 - 8
backend/.streamlit/config.toml

@@ -1,8 +0,0 @@
-[general]
-email = ""
-
-[server]
-maxUploadSize = 3000
-
-[client]
-toolbarMode = "minimal"

+ 0 - 2
backend/.streamlit/secrets.toml

@@ -1,2 +0,0 @@
-[connections.sql_app]
-url = "sqlite:///sql_app.db"

+ 0 - 2
backend/bci_core/model.py

@@ -5,10 +5,8 @@ from pyriemann.estimation import Covariances, BlockCovariances
 from pyriemann.tangentspace import TangentSpace
 
 from sklearn.ensemble import StackingClassifier
-from sklearn.preprocessing import FunctionTransformer
 from sklearn.pipeline import make_pipeline
 from sklearn.base import BaseEstimator, TransformerMixin
-from sklearn.preprocessing import StandardScaler
 
 from mne.decoding import Vectorizer
 

+ 0 - 11
backend/components/remove_style.py

@@ -1,11 +0,0 @@
-"""remove some streamlit style"""
-import streamlit as st
-
-
-def hide_footer():
-    hide_st_style = """
-                        <style>
-                        footer {visibility: hidden;}
-                        </style>
-                    """
-    st.markdown(hide_st_style, unsafe_allow_html=True)

+ 0 - 23
backend/db/subject.py

@@ -1,23 +0,0 @@
-"""subject model"""
-import streamlit as st
-from sqlalchemy.sql import text
-
-
-def create_table(conn):
-    with conn.session as s:
-        s.execute(text('CREATE TABLE IF NOT EXISTS subject (name TEXT, gender TEXT, birthday DATE, create_time DATETIME);'))
-        s.commit()
-
-
-def get_subjects(conn):
-    subjects = conn.query('select * from subject', ttl=0.05)
-    return subjects
-
-
-def create_subject(conn, subject_form):
-    with conn.session as s:
-        s.execute(
-            text('INSERT INTO subject (name, gender, birthday, create_time) VALUES (:name, :gender, :birthday, :create_time);'),
-            params=dict(name=subject_form['name'], gender=subject_form['gender'], birthday=subject_form['birthday'], create_time=subject_form['create_time'])
-        )
-        s.commit()

+ 0 - 27
backend/db/test.py

@@ -1,27 +0,0 @@
-"""train model"""
-import streamlit as st
-from sqlalchemy.sql import text
-
-
-def create_table(conn):
-    with conn.session as s:
-        s.execute(text('CREATE TABLE IF NOT EXISTS test (position TEXT, finger_model TEXT, start_time DATETIME, owner_name TEXT, model_path TEXT);'))
-        s.commit()
-
-
-def get_tests(conn, sub_name):
-    tests = conn.query('select * from test where owner_name = :owner', ttl=0.05, params={'owner': sub_name})
-    return tests 
-
-
-def create_test(conn, test_form):
-    with conn.session as s:
-        s.execute(
-            text('INSERT INTO test (position, finger_model, start_time, owner_name, model_path) VALUES (:position, :finger_model, :start_time, :owner_name, :model_path);'),
-            params=dict(position=test_form['position'],
-                        finger_model=test_form['finger_model'], 
-                        start_time=test_form['start_time'], 
-                        owner_name=test_form['owner_name'],
-                        model_path=test_form['model_path'])
-        )
-        s.commit()

+ 0 - 29
backend/db/train.py

@@ -1,29 +0,0 @@
-"""train model"""
-import streamlit as st
-from sqlalchemy.sql import text
-
-
-def create_table(conn):
-    with conn.session as s:
-        s.execute(text('CREATE TABLE IF NOT EXISTS train (position TEXT, finger_model TEXT, trial_num INTEGER, start_time DATETIME, owner_name TEXT, virtual_feedback_rate FLOAT, model_path TEXT);'))
-        s.commit()
-
-
-def get_trains(conn, sub_name):
-    trains = conn.query('select * from train where owner_name = :owner', ttl=0.05, params={'owner': sub_name})
-    return trains 
-
-
-def create_train(conn, train_form):
-    with conn.session as s:
-        s.execute(
-            text('INSERT INTO train (position, finger_model, trial_num, start_time, owner_name, virtual_feedback_rate, model_path) VALUES (:position, :finger_model, :trial_num, :start_time, :owner_name, :virtual_feedback_rate, :model_path);'),
-            params=dict(position=train_form['position'],
-                        finger_model=train_form['finger_model'], 
-                        trial_num=train_form['trial_num'], 
-                        start_time=train_form['start_time'], 
-                        owner_name=train_form['owner_name'],
-                        virtual_feedback_rate=train_form['virtual_feedback_rate'],
-                        model_path=train_form['model_path'])
-        )
-        s.commit()

+ 0 - 25
backend/page_utils.py

@@ -1,25 +0,0 @@
-import streamlit as st
-import os
-
-
-fingermodel_trans = {
-    '一般抓握': 'flex',
-    '柱状抓握': 'cylinder',
-    '球状抓握': 'ball',
-    '两指对捏': 'double',
-    '三指对捏': 'treble'
-}
-
-
-def file_selector(folder_path='.'):
-    try:
-        os.mkdir(folder_path)
-    except FileExistsError:
-        pass
-    filenames = os.listdir(folder_path)
-    filenames = filter(lambda x: x.endwith('.pkl'), filenames)
-    selected_filename = st.selectbox('Select a file', filenames)
-    if selected_filename is not None:
-        return os.path.join(folder_path, selected_filename)
-    else:
-        return None

+ 0 - 57
backend/pages/2_train.py

@@ -1,57 +0,0 @@
-"""train"""
-from datetime import datetime
-import os
-
-import streamlit as st
-
-from db import subject
-from db import train
-from components.remove_style import hide_footer
-import page_utils
-from device.peripheral.manager import get_serial_ports
-
-
-def _create_train(conn, subjects):
-    with st.form("train_form"):
-        st.write("创建训练")
-        position = st.selectbox("训练部位", ['右手', '左手'])
-        hand_com = st.selectbox("气动手COM口", get_serial_ports())
-        finger_model = st.selectbox("气动手手势", list(page_utils.fingermodel_trans.keys()))
-        trial_num = st.number_input("训练次数", value=10, step=5)
-        owner_name = st.selectbox("用户", subjects.name.to_list())
-        virtual_feedback_rate = st.number_input("假反馈比例", min_value=0., max_value=1., value=1., step=0.2)
-        model_path = page_utils.file_selector(f'./static/models/{owner_name}')
-        submitted = st.form_submit_button("开始训练")
-        if submitted:
-            start_time = datetime.strptime(datetime.now().strftime("%Y-%m-%d %H:%M:%S"), "%Y-%m-%d %H:%M:%S")
-            train_new = {"position": position, 
-                         "finger_model": page_utils.fingermodel_trans[finger_model], 
-                         "trial_num": int(trial_num), 
-                         "start_time": start_time, 
-                         "owner_name": owner_name,
-                         "virtual_feedback_rate": float(virtual_feedback_rate),
-                         "model_path": model_path}
-            train.create_train(conn, train_new)
-            # run psychopy process
-            os.system(f'python general_grasp_training.py --subj {owner_name} --n-trials {trial_num} --com {hand_com} --finger-model {page_utils.fingermodel_trans[finger_model]} --virtual-feedback-rate {virtual_feedback_rate} --model-path {str(model_path)}')
-            return owner_name
-
-def render():
-    st.set_page_config(
-        page_title="手势训练", page_icon=":chart_with_upwards_trend:"
-    )
-    hide_footer()
-
-    st.markdown("# 手势训练")
-    st.sidebar.success("训练")
-    conn = st.connection("sql_app", type="sql")
-    train.create_table(conn)
-    subjects = subject.get_subjects(conn)
-    sub_name = _create_train(conn, subjects)
-    if sub_name:
-        trains = train.get_trains(conn, sub_name)
-        st.write("# 训练列表")
-        st.dataframe(trains)
-
-
-render()

+ 0 - 51
backend/pages/3_test.py

@@ -1,51 +0,0 @@
-"""IMU放置于人体模型后的motion capture分析,包括离线、在线。读入或在线接入数据进行波形绘制、数据分析及分析结果的人体模型渲染和结果图表"""
-import streamlit as st
-import os
-from datetime import datetime
-from db import subject
-from db import test
-from components.remove_style import hide_footer
-from device.peripheral.manager import get_serial_ports
-
-import page_utils
-
-
-def _create_test(conn, subjects):
-    with st.form("test_form"):
-        st.write("创建训练")
-        position = st.selectbox("训练部位", ['右手', '左手'])
-        hand_com = st.selectbox("气动手COM口", get_serial_ports())
-        finger_model_names = st.multiselect("气动手手势", list(page_utils.fingermodel_trans.keys()))
-        owner_name = st.selectbox("用户", subjects.name.to_list())
-        model_path = page_utils.file_selector(os.path.join(f'./static/models/{owner_name}'))
-        submitted = st.form_submit_button("开始训练")
-        if submitted:
-            start_time = datetime.strptime(datetime.now().strftime("%Y-%m-%d %H:%M:%S"), "%Y-%m-%d %H:%M:%S")
-            test_new = {"position": position, 
-                         "finger_model": ','.join([page_utils.fingermodel_trans[f] for f in finger_model_names]), 
-                         "start_time": start_time, 
-                         "owner_name": owner_name,
-                         "model_path": model_path}
-            test.create_test(conn, test_new)
-            # TODO: run a psychopy process
-            return owner_name
-
-
-def render():
-    st.set_page_config(
-        page_title="自由手势训练", page_icon=":chart_with_upwards_trend:"
-    )
-    hide_footer()
-
-    st.markdown("# 自由手势训练")
-    st.sidebar.success("训练")
-    conn = st.connection("sql_app", type="sql")
-    test.create_table(conn)
-    subjects = subject.get_subjects(conn)
-    sub_name = _create_test(conn, subjects)
-    if sub_name:
-        tests = test.get_tests(conn, sub_name)
-        st.write("# 训练列表")
-        st.dataframe(tests)
-
-render()

+ 0 - 6
environment.yml

@@ -6,18 +6,12 @@ dependencies:
   - pip
   - pip:
       - av==10.0.0
-      - fastapi==0.104.1
-      - func_timeout==4.3.5
       - mne==1.5.1
       - pydantic==2.4.2
       - pyedflib==0.1.36
       - pyserial==3.5
       - pytest==7.4.3
       - pyriemann==0.5
-      - service==0.6.0
-      - SQLAlchemy==2.0.23
-      - starlette==0.27.0
-      - streamlit==1.28.1
       - joblib==1.3.2
       - opencv_python~=4.8.1
       - numpy~=1.26

+ 0 - 6
requirements.txt

@@ -1,6 +1,4 @@
 av==10.0.0
-fastapi==0.104.1
-func_timeout==4.3.5
 joblib==1.3.2
 matplotlib==3.8.1
 mne==1.5.1
@@ -12,9 +10,5 @@ pytest==7.4.3
 pyriemann==0.5
 scikit_learn==1.3.2
 scipy==1.11.3
-service==0.6.0
-SQLAlchemy==2.0.23
-starlette==0.27.0
-streamlit==1.28.1
 opencv_python~=4.8.1
 pyyaml~=6.0.1