# embedchain
# 77 lines · 2.5 KB
1import hashlib2from unittest.mock import MagicMock3
4import pytest5
6from embedchain.loaders.mysql import MySQLLoader7
8
@pytest.fixture
def mysql_loader(mocker):
    """Yield a ``MySQLLoader`` built while the MySQL connection class is mocked.

    ``mocker.patch`` activates the patch as soon as it is called and
    pytest-mock undoes it automatically at test teardown, so the patch stays
    in effect for the whole test using this fixture. Wrapping the returned
    mock in a ``with`` statement is unnecessary (pytest-mock warns about it),
    so a plain call is used.
    """
    # Presumably MySQLLoader instantiates this class to connect; mocking it
    # keeps the fixture from opening a real connection — confirm in the loader.
    mocker.patch("mysql.connector.connection.MySQLConnection")
    config = {
        "host": "localhost",
        "port": "3306",
        "user": "your_username",
        "password": "your_password",
        "database": "your_database",
    }
    loader = MySQLLoader(config=config)
    yield loader
22
def test_mysql_loader_initialization(mysql_loader):
    """A freshly built loader exposes a config, a connection and a cursor."""
    for attr_name in ("config", "connection", "cursor"):
        assert getattr(mysql_loader, attr_name) is not None
28
def test_mysql_loader_invalid_config():
    """Constructing a loader with ``config=None`` must raise ValueError."""
    expected_message = "Invalid sql config: None"
    with pytest.raises(ValueError, match=expected_message):
        MySQLLoader(config=None)
33
def test_mysql_loader_setup_loader_successful(mysql_loader):
    """After setup (with the driver mocked) a connection and cursor exist."""
    conn = mysql_loader.connection
    assert conn is not None
    cur = mysql_loader.cursor
    assert cur is not None
38
def test_mysql_loader_setup_loader_connection_error(mysql_loader, mocker):
    """An IOError raised by the connection class surfaces as ValueError."""
    connect_failure = IOError("Mocked connection error")
    mocker.patch(
        "mysql.connector.connection.MySQLConnection",
        side_effect=connect_failure,
    )
    expected_message = "Unable to connect with the given config:"
    with pytest.raises(ValueError, match=expected_message):
        mysql_loader._setup_loader(config={})
44
def test_mysql_loader_check_query_successful(mysql_loader):
    """A string query passes validation; the test fails only if it raises."""
    sql = "SELECT * FROM table"
    mysql_loader._check_query(query=sql)
49
def test_mysql_loader_check_query_invalid(mysql_loader):
    """A non-string query (here the int 123) is rejected with ValueError."""
    bad_query = 123
    expected_message = "Invalid mysql query: 123"
    with pytest.raises(ValueError, match=expected_message):
        mysql_loader._check_query(query=bad_query)
54
def test_mysql_loader_load_data_successful(mysql_loader, mocker):
    """``load_data`` executes the query and returns one entry per fetched row.

    The loader's cursor is replaced by a ``MagicMock`` whose ``fetchall``
    returns two rows. The test checks the result keys, the number of entries,
    the per-entry ``meta_data["url"]``, the ``doc_id`` (sha256 of the query
    followed by the ", "-joined contents) and that the query was executed.
    """
    mock_cursor = MagicMock()
    mocker.patch.object(mysql_loader, "cursor", mock_cursor)
    mock_cursor.fetchall.return_value = [(1, "data1"), (2, "data2")]

    query = "SELECT * FROM table"
    result = mysql_loader.load_data(query)

    assert "doc_id" in result
    assert "data" in result
    assert len(result["data"]) == 2
    assert result["data"][0]["meta_data"]["url"] == query
    assert result["data"][1]["meta_data"]["url"] == query

    # Recompute the expected id from the contents the loader returned.
    doc_id = hashlib.sha256((query + ", ".join([d["content"] for d in result["data"]])).encode()).hexdigest()

    assert result["doc_id"] == doc_id
    # `execute.called_with(...)` is not an assertion: before Python 3.12 it
    # returns an auto-created (always truthy) child mock, and on 3.12+ it
    # raises AttributeError. Use the real assertion helper instead.
    mock_cursor.execute.assert_called_once_with(query)
74
def test_mysql_loader_load_data_invalid_query(mysql_loader):
    """``load_data`` rejects a non-string query (the int 123) with ValueError."""
    non_string_query = 123
    expected_message = "Invalid mysql query: 123"
    with pytest.raises(ValueError, match=expected_message):
        mysql_loader.load_data(query=non_string_query)