milvus-io_bootcamp
import io.milvus.client.{MilvusClient, MilvusServiceClient}
import io.milvus.grpc.{DataType, ImportResponse}
import io.milvus.param.bulkinsert.{BulkInsertParam, GetBulkInsertStateParam}
import io.milvus.param.collection.{CreateCollectionParam, FieldType}
import io.milvus.param.{ConnectParam, R, RpcStatus}
import org.apache.spark.SparkConf
import org.apache.spark.sql.types._
import org.apache.spark.sql.{SaveMode, SparkSession}
import org.slf4j.LoggerFactory
import zilliztech.spark.milvus.MilvusOptions.{MILVUS_COLLECTION_NAME, MILVUS_HOST, MILVUS_PORT, MILVUS_TOKEN, MILVUS_URI}
import org.apache.hadoop.fs.{FileSystem, Path}
import java.net.URI
import org.apache.log4j.Logger

import scala.collection.JavaConverters._

import java.util

val logger = Logger.getLogger(this.getClass)

val sparkConf = new SparkConf().setMaster("local")
val spark = SparkSession.builder().config(sparkConf).getOrCreate()
// Fill in the user's Milvus instance credentials.
val host = "127.0.0.1"
val port = 19530
val username = "root"
val password = "Milvus"
// Specify the target Milvus collection name.
val collectionName = "spark_milvus_test"
// This file simulates a dataframe from the user's vector-generation job or a Delta table that contains vectors.
val filePath = "/Volumes/zilliz_test/default/sample_vectors/dim32_1k.json"
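// As an illustration only (hypothetical values, not taken from the sample file), each record in
// that JSON file is expected to match the three fields defined below: an Int64 id, a string, and
// a 32-dimensional float vector, e.g.
// {"id_field": 1, "str_field": "some text", "float_vector_field": [0.12, 0.34, ...]}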
// The S3 bucket is an internal bucket of the Milvus instance, which the user has full control of.
// The user needs to set up this bucket as a "storage credential" by following
// the instructions at https://docs.databricks.com/en/connect/unity-catalog/storage-credentials.html#step-2-give-databricks-the-iam-role-details
// Here the user can specify the directory in the bucket to store vector data.
// The vectors will be written to the S3 bucket in a format that can be loaded into Zilliz Cloud efficiently.
val outputPath = "s3://your-s3-bucket-name/filesaa/spark_output"

// 1. Create Milvus collection through Milvus SDK
val connectParam: ConnectParam = ConnectParam.newBuilder
  .withHost(host)
  .withPort(port)
  .withAuthorization(username, password)
  .build

val client: MilvusClient = new MilvusServiceClient(connectParam)

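// A hedged aside: when the target is a Zilliz Cloud or otherwise remote instance, recent versions
// of the Java SDK also let the connection be built from an endpoint URI and API token, e.g.
// ConnectParam.newBuilder.withUri(...).withToken(...).build. Shown here only as a pointer, since
// the exact builder methods available depend on the SDK version in use.
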
val field1Name: String = "id_field"
val field2Name: String = "str_field"
val field3Name: String = "float_vector_field"
val fieldsSchema: util.List[FieldType] = new util.ArrayList[FieldType]

fieldsSchema.add(FieldType.newBuilder
  .withPrimaryKey(true)
  .withAutoID(false)
  .withDataType(DataType.Int64)
  .withName(field1Name)
  .build
)
fieldsSchema.add(FieldType.newBuilder
  .withDataType(DataType.VarChar)
  .withName(field2Name)
  .withMaxLength(65535)
  .build
)
fieldsSchema.add(FieldType.newBuilder
  .withDataType(DataType.FloatVector)
  .withName(field3Name)
  .withDimension(32)
  .build
)

// create collection
val createParam: CreateCollectionParam = CreateCollectionParam.newBuilder
  .withCollectionName(collectionName)
  .withFieldTypes(fieldsSchema)
  .build

val createR: R[RpcStatus] = client.createCollection(createParam)

logger.info(s"create collection ${collectionName} resp: ${createR.toString}")

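// A small, optional verification sketch (not part of the original script): confirm the collection
// now exists before building the dataframe. HasCollectionParam lives in io.milvus.param.collection
// and is referenced by its full path so the imports above stay unchanged.
val hasR: R[java.lang.Boolean] = client.hasCollection(
  io.milvus.param.collection.HasCollectionParam.newBuilder
    .withCollectionName(collectionName)
    .build
)
logger.info(s"collection ${collectionName} exists: ${hasR.getData}")
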
// 2. Read data from the file to build the vector dataframe. The schema of the dataframe must logically match
// the schema of the target collection: Milvus Int64 corresponds to Spark LongType, VarChar to StringType,
// and FloatVector to ArrayType(FloatType).
val df = spark.read
  .schema(new StructType()
    .add(field1Name, LongType)
    .add(field2Name, StringType)
    .add(field3Name, ArrayType(FloatType), false))
  .json(filePath)

// 3. Store all vector data in the S3 bucket to prepare for loading.
// The "mjson" format is a custom data source provided by the spark-milvus connector
// (see the zilliztech.spark.milvus import above).
df.repartition(1)
  .write
  .format("mjson")
  .mode("overwrite")
  .save(outputPath)

// 4. As the vector data has been stored in the S3 bucket as files, list the directory and collect the file paths
// to prepare the input of the Zilliz Cloud Import Data API call.
val hadoopConfig = spark.sparkContext.hadoopConfiguration
val directory = new Path(outputPath)
val fs = FileSystem.get(directory.toUri, hadoopConfig)
val files = fs.listStatus(directory)
val outputFile = files.filter(file => {
  file.getPath.getName.endsWith(".json")
})(0)

def extractPathWithoutBucket(s3Path: String): String = {
  val uri = new URI(s3Path)
  val pathWithoutBucket = uri.getPath.drop(1) // Drop the leading '/'
  pathWithoutBucket
}

val outputFilePathWithoutBucket = extractPathWithoutBucket(outputFile.getPath.toString)

// 5. Make a call to the Milvus bulkinsert API.
val bulkInsertFiles: List[String] = List(outputFilePathWithoutBucket)
val bulkInsertParam: BulkInsertParam = BulkInsertParam.newBuilder
  .withCollectionName(collectionName)
  .withFiles(bulkInsertFiles.asJava)
  .build

val bulkInsertR: R[ImportResponse] = client.bulkInsert(bulkInsertParam)
logger.info(s"bulkinsert ${collectionName} resp: ${bulkInsertR.toString}")
val taskId: Long = bulkInsertR.getData.getTasksList.get(0)

// Poll the import task until it reaches a terminal ImportState:
// 1 = ImportFailed, 6 = ImportCompleted, 7 = ImportFailedAndCleaned.
var bulkloadState = client.getBulkInsertState(GetBulkInsertStateParam.newBuilder.withTask(taskId).build)
while (bulkloadState.getData.getState.getNumber != 1 &&
  bulkloadState.getData.getState.getNumber != 6 &&
  bulkloadState.getData.getState.getNumber != 7) {
  bulkloadState = client.getBulkInsertState(GetBulkInsertStateParam.newBuilder.withTask(taskId).build)
  logger.info(s"bulkinsert ${collectionName} resp: ${bulkInsertR.toString} state: ${bulkloadState}")
  Thread.sleep(3000)
}
if (bulkloadState.getData.getState.getNumber != 6) {
  logger.error(s"bulkinsert failed ${collectionName} state: ${bulkloadState}")
}
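
// A brief follow-up sketch (not part of the original flow, purely illustrative): once the import
// has completed, the collection statistics can be logged and resources released.
// GetCollectionStatisticsParam lives in io.milvus.param.collection and is referenced by its full
// path so the imports above stay unchanged.
if (bulkloadState.getData.getState.getNumber == 6) {
  val statsR = client.getCollectionStatistics(
    io.milvus.param.collection.GetCollectionStatisticsParam.newBuilder
      .withCollectionName(collectionName)
      .build
  )
  logger.info(s"collection ${collectionName} stats after bulkinsert: ${statsR.toString}")
}
client.close()
spark.stop()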