davanstrien HF Staff committed on
Commit
8b71d33
·
1 Parent(s): 213c06e

tidy cache

Browse files
Files changed (1) hide show
  1. app.py +48 -33
app.py CHANGED
@@ -1,4 +1,3 @@
1
- import datetime
2
  import os
3
  import copy
4
  from dataclasses import asdict, dataclass
@@ -29,8 +28,15 @@ from httpx import Client
29
  from httpx_caching import CachingClient
30
  from httpx_caching import OneDayCacheHeuristic
31
 
 
 
 
 
 
 
32
  client = Client()
33
 
 
34
  client = CachingClient(client, heuristic=OneDayCacheHeuristic())
35
 
36
 
@@ -57,7 +63,7 @@ def get_model_labels(model):
57
  class EngagementStats:
58
  likes: int
59
  downloads: int
60
- created_at: datetime.datetime
61
 
62
 
63
  def _get_engagement_stats(hub_id):
@@ -298,6 +304,7 @@ GENERIC_SCORES = generate_common_scores()
298
 
299
 
300
  # @cache.memoize(expire=60 * 60 * 24 * 3) # expires after 3 days
 
301
  def _basic_check(hub_id):
302
  data = ModelMetadata.from_hub(hub_id)
303
  score = 0
@@ -358,7 +365,7 @@ def create_query_url(query, skip=0):
358
  return f"https://huggingface.co/api/search/full-text?q={query}&limit=100&skip={skip}&type=model"
359
 
360
 
361
- # @cache.memoize(expire=60 * 60 * 24 * 3) # expires after 3 days
362
  def get_results(query) -> Dict[Any, Any]:
363
  url = create_query_url(query)
364
  r = client.get(url)
@@ -390,8 +397,22 @@ def parse_single_result(result):
390
  }
391
 
392
 
 
 
 
 
 
 
 
 
 
 
 
 
393
  def filter_search_results(
394
- results: List[Dict[Any, Any]], min_score=None, min_model_card_length=None
 
 
395
  ): # TODO make code more intuitive
396
  results = thread_map(parse_single_result, results)
397
  for i, parsed_result in tqdm(enumerate(results)):
@@ -418,10 +439,14 @@ def filter_search_results(
418
  yield parsed_result
419
 
420
 
421
- def sort_search_results(filtered_search_results):
 
 
 
 
422
  return sorted(
423
  list(filtered_search_results),
424
- key=lambda x: (x["metadata_score"], x["original_position"]),
425
  reverse=True,
426
  )
427
 
@@ -435,20 +460,12 @@ def find_context(text, query, window_size):
435
  # Get the start and end indices of the context window
436
  start = max(0, index - window_size)
437
  end = min(len(words), index + window_size + 1)
438
-
439
  return " ".join(words[start:end])
440
  except ValueError:
441
  return " ".join(words[:window_size])
442
 
443
 
444
- # single_result[
445
- # "text"
446
- # ] = "lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua."
447
-
448
- # results = [single_result] * 3
449
-
450
-
451
- def create_markdown(results):
452
  rows = []
453
  for result in results:
454
  row = f"""# [{result['name']}]({result['repo_hub_url']})
@@ -490,7 +507,6 @@ def _search_hub(
490
  # for result in filtered_results:
491
  # result_text = httpx.get(result["search_result_file_url"]).text
492
  # result["text"] = find_context(result_text, query, 100)
493
-
494
  # final_results.append(result)
495
  final_results = thread_map(get_result_card_snippet, filtered_results)
496
  percent_of_original = round(
@@ -532,23 +548,22 @@ with gr.Blocks() as demo:
532
  [query, min_metadata_score, mim_model_card_length],
533
  [filter_results, results_markdown],
534
  )
535
- with gr.Tab("Scoring metadata quality"):
536
- with gr.Row():
537
- gr.Markdown(
538
- f"""
539
- # Metadata quality scoring
540
- ```
541
- {COMMON_SCORES}
542
- ```
543
-
544
- For example, `TASK_TYPES_WITH_LANGUAGES` defines all the tasks for which it
545
- is expected to have language metadata associated with the model.
546
- ```
547
- {TASK_TYPES_WITH_LANGUAGES}
548
- ```
549
- """
550
- )
551
-
552
 
553
  demo.launch()
554
 
 
 
1
  import os
2
  import copy
3
  from dataclasses import asdict, dataclass
 
28
  from httpx_caching import CachingClient
29
  from httpx_caching import OneDayCacheHeuristic
30
 
31
+ from cachetools import cached, TTLCache
32
+ from datetime import timedelta
33
+ from datetime import datetime
34
+
35
+ cache = TTLCache(maxsize=500_000, ttl=timedelta(hours=24), timer=datetime.now)
36
+
37
  client = Client()
38
 
39
+
40
  client = CachingClient(client, heuristic=OneDayCacheHeuristic())
41
 
42
 
 
63
  class EngagementStats:
64
  likes: int
65
  downloads: int
66
+ created_at: datetime
67
 
68
 
69
  def _get_engagement_stats(hub_id):
 
304
 
305
 
306
  # @cache.memoize(expire=60 * 60 * 24 * 3) # expires after 3 days
307
+ @cached(cache)
308
  def _basic_check(hub_id):
309
  data = ModelMetadata.from_hub(hub_id)
310
  score = 0
 
365
  return f"https://huggingface.co/api/search/full-text?q={query}&limit=100&skip={skip}&type=model"
366
 
367
 
368
+ @cached(cache)
369
  def get_results(query) -> Dict[Any, Any]:
370
  url = create_query_url(query)
371
  r = client.get(url)
 
397
  }
398
 
399
 
400
+ def filter_for_license(results):
401
+ for result in results:
402
+ if result["is_licensed"]:
403
+ yield result
404
+
405
+
406
+ def filter_for_min_model_card_length(results, min_model_card_length):
407
+ for result in results:
408
+ if result["model_card_length"] > min_model_card_length:
409
+ yield result
410
+
411
+
412
  def filter_search_results(
413
+ results: List[Dict[Any, Any]],
414
+ min_score=None,
415
+ min_model_card_length=None,
416
  ): # TODO make code more intuitive
417
  results = thread_map(parse_single_result, results)
418
  for i, parsed_result in tqdm(enumerate(results)):
 
439
  yield parsed_result
440
 
441
 
442
+ def sort_search_results(
443
+ filtered_search_results,
444
+ first_sort="metadata_score",
445
+ second_sort="original_position", # TODO expose these in results
446
+ ):
447
  return sorted(
448
  list(filtered_search_results),
449
+ key=lambda x: (x[first_sort], x[second_sort]),
450
  reverse=True,
451
  )
452
 
 
460
  # Get the start and end indices of the context window
461
  start = max(0, index - window_size)
462
  end = min(len(words), index + window_size + 1)
 
463
  return " ".join(words[start:end])
464
  except ValueError:
465
  return " ".join(words[:window_size])
466
 
467
 
468
+ def create_markdown(results): # TODO move to separate file
 
 
 
 
 
 
 
469
  rows = []
470
  for result in results:
471
  row = f"""# [{result['name']}]({result['repo_hub_url']})
 
507
  # for result in filtered_results:
508
  # result_text = httpx.get(result["search_result_file_url"]).text
509
  # result["text"] = find_context(result_text, query, 100)
 
510
  # final_results.append(result)
511
  final_results = thread_map(get_result_card_snippet, filtered_results)
512
  percent_of_original = round(
 
548
  [query, min_metadata_score, mim_model_card_length],
549
  [filter_results, results_markdown],
550
  )
551
+ # with gr.Tab("Scoring metadata quality"):
552
+ # with gr.Row():
553
+ # gr.Markdown(
554
+ # f"""
555
+ # # Metadata quality scoring
556
+ # ```
557
+ # {COMMON_SCORES}
558
+ # ```
559
+
560
+ # For example, `TASK_TYPES_WITH_LANGUAGES` defines all the tasks for which it
561
+ # is expected to have language metadata associated with the model.
562
+ # ```
563
+ # {TASK_TYPES_WITH_LANGUAGES}
564
+ # ```
565
+ # """
566
+ # )
 
567
 
568
  demo.launch()
569