From c8ea777a2c3d123cf69494aaf8f8e6ee044fde7b Mon Sep 17 00:00:00 2001 From: cheng zhen Date: Sat, 4 Jan 2025 15:23:37 +0800 Subject: [PATCH] =?UTF-8?q?test:=20=E6=B5=8B=E8=AF=95=E7=94=A8=E4=BE=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- {test => tests}/get_veri_code_test.py | 0 tests/test_cursor_auth_manager.py | 93 +++++++++++++++++++++++++++ 2 files changed, 93 insertions(+) rename {test => tests}/get_veri_code_test.py (100%) create mode 100644 tests/test_cursor_auth_manager.py diff --git a/test/get_veri_code_test.py b/tests/get_veri_code_test.py similarity index 100% rename from test/get_veri_code_test.py rename to tests/get_veri_code_test.py diff --git a/tests/test_cursor_auth_manager.py b/tests/test_cursor_auth_manager.py new file mode 100644 index 00000000..70b1bfa6 --- /dev/null +++ b/tests/test_cursor_auth_manager.py @@ -0,0 +1,93 @@ +import unittest +import os +import sqlite3 +from unittest.mock import patch, MagicMock +from cursor_auth_manager import CursorAuthManager + + +class TestCursorAuthManager(unittest.TestCase): + def setUp(self): + # 创建临时测试数据库 + self.test_db_path = "test_state.vscdb" + + # 创建测试数据库和表 + self.conn = sqlite3.connect(self.test_db_path) + self.cursor = self.conn.cursor() + self.cursor.execute( + """ + CREATE TABLE IF NOT EXISTS itemTable ( + key TEXT PRIMARY KEY, + value TEXT + ) + """ + ) + self.conn.commit() + + # 模拟 CursorAuthManager 的数据库路径 + with patch.object(CursorAuthManager, "__init__", return_value=None): + self.auth_manager = CursorAuthManager() + self.auth_manager.db_path = self.test_db_path + + def tearDown(self): + # 清理测试数据库 + self.conn.close() + if os.path.exists(self.test_db_path): + os.remove(self.test_db_path) + + def test_update_auth_new_values(self): + """测试更新新的认证信息""" + result = self.auth_manager.update_auth( + email="test@example.com", + access_token="test_access_token", + refresh_token="test_refresh_token", + ) + + self.assertTrue(result) + + # 验证数据是否正确写入 + cursor = self.conn.cursor() + + # 检查邮箱 + cursor.execute( + "SELECT value FROM itemTable WHERE key=?", ("cursorAuth/cachedEmail",) + ) + self.assertEqual(cursor.fetchone()[0], "test@example.com") + + # 检查访问令牌 + cursor.execute( + "SELECT value FROM itemTable WHERE key=?", ("cursorAuth/accessToken",) + ) + self.assertEqual(cursor.fetchone()[0], "test_access_token") + + # 检查刷新令牌 + cursor.execute( + "SELECT value FROM itemTable WHERE key=?", ("cursorAuth/refreshToken",) + ) + self.assertEqual(cursor.fetchone()[0], "test_refresh_token") + + # 检查注册类型 + cursor.execute( + "SELECT value FROM itemTable WHERE key=?", ("cursorAuth/cachedSignUpType",) + ) + self.assertEqual(cursor.fetchone()[0], "Auth_0") + + def test_update_auth_partial_update(self): + """测试只更新部分认证信息""" + result = self.auth_manager.update_auth(email="test@example.com") + + self.assertTrue(result) + + cursor = self.conn.cursor() + cursor.execute( + "SELECT value FROM itemTable WHERE key=?", ("cursorAuth/cachedEmail",) + ) + self.assertEqual(cursor.fetchone()[0], "test@example.com") + + def test_update_auth_no_values(self): + """测试不提供任何更新值的情况""" + result = self.auth_manager.update_auth() + self.assertFalse(result) + + +if __name__ == "__main__": + unittest.main()