Add the text_embeddings API
This commit is contained in:
@@ -8,6 +8,7 @@ use crate::dialogue::{Message, Role};
|
||||
use crate::error::{Error, Result};
|
||||
use crate::prelude::{
|
||||
Candidate, Content, GenerateContentRequest, GenerateContentResponse, GenerationConfig,
|
||||
TextEmbeddingRequest, TextEmbeddingResponse,
|
||||
};
|
||||
use crate::{prelude::Part, token_provider::TokenProvider};
|
||||
|
||||
@@ -171,4 +172,26 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn text_embeddings(
|
||||
&self,
|
||||
request: &TextEmbeddingRequest,
|
||||
) -> Result<TextEmbeddingResponse> {
|
||||
let model = "textembedding-gecko@003";
|
||||
let endpoint_url = format!(
|
||||
"https://{}/v1/projects/{}/locations/{}/publishers/google/models/{}:predict",
|
||||
self.api_endpoint, self.project_id, self.location_id, model,
|
||||
);
|
||||
let access_token = self.token_provider.get_token(AUTH_SCOPE).await?;
|
||||
let resp = self
|
||||
.client
|
||||
.post(&endpoint_url)
|
||||
.bearer_auth(access_token)
|
||||
.json(&request)
|
||||
.send()
|
||||
.await?;
|
||||
let txt_json = resp.text().await?;
|
||||
println!("{}", txt_json);
|
||||
Ok(serde_json::from_str::<TextEmbeddingResponse>(&txt_json)?)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,8 +2,10 @@ mod common;
|
||||
mod count_tokens;
|
||||
mod error;
|
||||
mod generate_content;
|
||||
mod text_embeddings;
|
||||
|
||||
pub use common::*;
|
||||
pub use count_tokens::*;
|
||||
pub use error::*;
|
||||
pub use generate_content::*;
|
||||
pub use text_embeddings::*;
|
||||
|
||||
43
src/types/text_embeddings.rs
Normal file
43
src/types/text_embeddings.rs
Normal file
@@ -0,0 +1,43 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::Error;
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct TextEmbeddingRequest {
|
||||
pub instances: Vec<TextEmbeddingRequestInstance>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct TextEmbeddingRequestInstance {
|
||||
pub content: String,
|
||||
pub task_type: String,
|
||||
pub title: String,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum TextEmbeddingResponse {
|
||||
Ok {
|
||||
predictions: Vec<TextEmbeddingPrediction>,
|
||||
},
|
||||
Error {
|
||||
error: Error,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct TextEmbeddingPrediction {
|
||||
pub embeddings: TextEmbeddingResult,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct TextEmbeddingResult {
|
||||
statistics: TextEmbeddingStatistics,
|
||||
values: Vec<f32>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct TextEmbeddingStatistics {
|
||||
truncated: bool,
|
||||
token_count: u32,
|
||||
}
|
||||
Reference in New Issue
Block a user