Adds a count_token methods to the Gemini Clent
This commit is contained in:
@@ -7,8 +7,8 @@ use reqwest_eventsource::{Event, EventSource};
|
||||
use crate::dialogue::{Message, Role};
|
||||
use crate::error::{Error, Result};
|
||||
use crate::prelude::{
|
||||
Candidate, Content, GenerateContentRequest, GenerateContentResponse, GenerationConfig,
|
||||
TextEmbeddingRequest, TextEmbeddingResponse,
|
||||
Candidate, Content, CountTokensRequest, CountTokensResponse, GenerateContentRequest,
|
||||
GenerateContentResponse, GenerationConfig, TextEmbeddingRequest, TextEmbeddingResponse,
|
||||
};
|
||||
use crate::{prelude::Part, token_provider::TokenProvider};
|
||||
|
||||
@@ -191,6 +191,30 @@ impl<T: TokenProvider + Clone> GeminiClient<T> {
|
||||
.send()
|
||||
.await?;
|
||||
let txt_json = resp.text().await?;
|
||||
tracing::debug!("text_embeddings response: {:?}", txt_json);
|
||||
Ok(serde_json::from_str::<TextEmbeddingResponse>(&txt_json)?)
|
||||
}
|
||||
|
||||
pub async fn count_tokens(
|
||||
&self,
|
||||
request: &CountTokensRequest,
|
||||
model: &str,
|
||||
) -> Result<CountTokensResponse> {
|
||||
let endpoint_url = format!(
|
||||
"https://{}/v1beta1/projects/{}/locations/{}/publishers/google/models/{}:countTokens",
|
||||
self.api_endpoint, self.project_id, self.location_id, model,
|
||||
);
|
||||
let access_token = self.token_provider.get_token(AUTH_SCOPE).await?;
|
||||
let resp = self
|
||||
.client
|
||||
.post(&endpoint_url)
|
||||
.bearer_auth(access_token)
|
||||
.json(&request)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let txt_json = resp.text().await?;
|
||||
tracing::debug!("count_tokens response: {:?}", txt_json);
|
||||
Ok(serde_json::from_str(&txt_json)?)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user